diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index e6f9d21..991242e 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -9,7 +9,7 @@ local scraper = require('cp.scraper') local state = require('cp.state') ---Get the language of the current file from cache ----@return string|nil +---@return string? local function get_current_file_language() local current_file = vim.fn.expand('%:p') if current_file == '' then @@ -20,6 +20,34 @@ local function get_current_file_language() return file_state and file_state.language or nil end +---Check if a problem file exists for any enabled language +---@param platform string +---@param contest_id string +---@param problem_id string +---@return string? +local function get_existing_problem_language(platform, contest_id, problem_id) + local config = config_module.get_config() + local platform_config = config.platforms[platform] + if not platform_config then + return nil + end + + for _, lang_id in ipairs(platform_config.enabled_languages) do + local effective = config.runtime.effective[platform][lang_id] + if effective and effective.extension then + local basename = config.filename + and config.filename(platform, contest_id, problem_id, config, lang_id) + or config_module.default_filename(contest_id, problem_id) + local filepath = basename .. '.' .. effective.extension + if vim.fn.filereadable(filepath) == 1 then + return lang_id + end + end + end + + return nil +end + ---@class TestCaseLite ---@field input string ---@field expected string @@ -306,19 +334,28 @@ function M.navigate_problem(direction, language) require('cp.ui.views').disable() end + local lang = nil + if language then local lang_result = config_module.get_language_for_platform(platform, language) if not lang_result.valid then logger.log(lang_result.error, vim.log.levels.ERROR) return end - end - - local lang = language or get_current_file_language() - if lang and not language then - local lang_result = config_module.get_language_for_platform(platform, lang) - if not lang_result.valid then - lang = nil + lang = language + else + local existing_lang = + get_existing_problem_language(platform, contest_id, problems[new_index].id) + if existing_lang then + lang = existing_lang + else + lang = get_current_file_language() + if lang then + local lang_result = config_module.get_language_for_platform(platform, lang) + if not lang_result.valid then + lang = nil + end + end end end