diff --git a/lua/cp/runner/execute.lua b/lua/cp/runner/execute.lua index a871d6f..8f004e5 100644 --- a/lua/cp/runner/execute.lua +++ b/lua/cp/runner/execute.lua @@ -164,7 +164,7 @@ function M.compile_problem(debug) local state = require('cp.state') local config = require('cp.config').get_config() local platform = state.get_platform() - local language = config.platforms[platform].default_language + local language = state.get_language() or config.platforms[platform].default_language local eff = config.runtime.effective[platform][language] local compile_config = (debug and eff.commands.debug) or eff.commands.build diff --git a/lua/cp/runner/run.lua b/lua/cp/runner/run.lua index c8fcd0c..eaec30b 100644 --- a/lua/cp/runner/run.lua +++ b/lua/cp/runner/run.lua @@ -109,7 +109,7 @@ local function run_single_test_case(test_case, debug) local substitutions = { source = source_file, binary = binary_file } local platform_config = config.platforms[state.get_platform() or ''] - local language = platform_config.default_language + local language = state.get_language() or platform_config.default_language local eff = config.runtime.effective[state.get_platform() or ''][language] local run_template = eff and eff.commands and eff.commands.run or {} local cmd = build_command(run_template, substitutions) diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index 6129614..20f852a 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -202,6 +202,8 @@ function M.setup_problem(problem_id, language) end end + state.set_language(lang) + local source_file = state.get_source_file(lang) if not source_file then return diff --git a/lua/cp/state.lua b/lua/cp/state.lua index caf5044..621b184 100644 --- a/lua/cp/state.lua +++ b/lua/cp/state.lua @@ -20,6 +20,8 @@ ---@field set_contest_id fun(contest_id: string) ---@field get_problem_id fun(): string? ---@field set_problem_id fun(problem_id: string) +---@field get_language fun(): string? +---@field set_language fun(language: string) ---@field get_active_panel fun(): string? ---@field set_active_panel fun(panel: string?) ---@field get_base_name fun(): string? @@ -42,6 +44,7 @@ local state = { platform = nil, contest_id = nil, problem_id = nil, + language = nil, test_cases = nil, saved_session = nil, active_panel = nil, @@ -80,6 +83,16 @@ function M.set_problem_id(problem_id) state.problem_id = problem_id end +---@return string? +function M.get_language() + return state.language +end + +---@param language string +function M.set_language(language) + state.language = language +end + ---@return string? function M.get_base_name() local platform, contest_id, problem_id = M.get_platform(), M.get_contest_id(), M.get_problem_id() @@ -112,7 +125,7 @@ function M.get_source_file(language) return nil end - local target_language = language or platform_cfg.default_language + local target_language = language or state.language or platform_cfg.default_language local eff = config.runtime.effective[plat] and config.runtime.effective[plat][target_language] or nil if not eff or not eff.extension then