diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index ede6eaa..8a45294 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -6,10 +6,9 @@ ---@class ContestData ---@field problems Problem[] ----@field test_cases? CachedTestCase[] ----@field timeout_ms? number ----@field memory_mb? number ----@field interactive? boolean +---@field index_map table +---@field name string +---@field display_name string ---@class ContestSummary ---@field display_name string @@ -19,19 +18,20 @@ ---@class CacheData ---@field [string] table ---@field file_states? table ----@field contest_lists? table - ----@class ContestListData ----@field contests table[] +---@field contest_lists? table ---@class Problem ---@field id string ---@field name? string +---@field interactive? boolean +---@field memory_mb? number +---@field timeout_ms? number +---@field test_cases TestCase[] ----@class CachedTestCase +---@class TestCase ---@field index? number ----@field input string ---@field expected? string +---@field input? string ---@field output? string local M = {} @@ -164,7 +164,7 @@ end ---@param platform string ---@param contest_id string ---@param problem_id? string ----@return CachedTestCase[]? +---@return TestCase[] function M.get_test_cases(platform, contest_id, problem_id) vim.validate({ platform = { platform, 'string' }, @@ -178,8 +178,7 @@ function M.get_test_cases(platform, contest_id, problem_id) or not cache_data[platform][contest_id].problems or not cache_data[platform][contest_id].index_map then - print('bad, failing') - return nil + return {} end local index = cache_data[platform][contest_id].index_map[problem_id] @@ -189,7 +188,7 @@ end ---@param platform string ---@param contest_id string ---@param problem_id string ----@param test_cases CachedTestCase[] +---@param test_cases TestCase[] ---@param timeout_ms? number ---@param memory_mb? number ---@param interactive? boolean @@ -269,7 +268,7 @@ function M.set_file_state(file_path, platform, contest_id, problem_id, language) end ---@param platform string ----@return table[ContestSummary] +---@return ContestSummary[] function M.get_contest_summaries(platform) local contest_list = {} for contest_id, contest_data in pairs(cache_data[platform] or {}) do @@ -283,7 +282,7 @@ function M.get_contest_summaries(platform) end ---@param platform string ----@param contests table[ContestSummary] +---@param contests ContestSummary[] function M.set_contest_summaries(platform, contests) cache_data[platform] = cache_data[platform] or {} for _, contest in ipairs(contests) do diff --git a/lua/cp/commands/cache.lua b/lua/cp/commands/cache.lua index 8af5ff7..e7a2c1e 100644 --- a/lua/cp/commands/cache.lua +++ b/lua/cp/commands/cache.lua @@ -6,6 +6,9 @@ local logger = require('cp.log') local platforms = constants.PLATFORMS +--- Dispatch any `:CP cache ...` command +---@param cmd table +---@return nil function M.handle_cache_command(cmd) if cmd.subcommand == 'read' then local data = cache.get_data_pretty() diff --git a/lua/cp/commands/init.lua b/lua/cp/commands/init.lua index b4ce499..6ad48c2 100644 --- a/lua/cp/commands/init.lua +++ b/lua/cp/commands/init.lua @@ -7,6 +7,19 @@ local state = require('cp.state') local platforms = constants.PLATFORMS local actions = constants.ACTIONS +---@class ParsedCommand +---@field type string +---@field error string? +---@field language? string +---@field debug? boolean +---@field action? string +---@field message? string +---@field contest? string +---@field platform? string + +--- Turn raw args into normalized structure to later dispatch +---@param args string[] The raw command-line mode args +---@return ParsedCommand local function parse_command(args) if vim.tbl_isempty(args) then return { @@ -94,6 +107,8 @@ local function parse_command(args) return { type = 'error', message = 'Unknown command or no contest context' } end +--- Core logic for handling `:CP ...` commands +---@return nil function M.handle_command(opts) local cmd = parse_command(opts.fargs) @@ -105,10 +120,7 @@ function M.handle_command(opts) if cmd.type == 'restore_from_file' then local restore = require('cp.restore') restore.restore_from_current_file() - return - end - - if cmd.type == 'action' then + elseif cmd.type == 'action' then local setup = require('cp.setup') local ui = require('cp.ui.panel') @@ -124,16 +136,10 @@ function M.handle_command(opts) local picker = require('cp.commands.picker') picker.handle_pick_action() end - return - end - - if cmd.type == 'cache' then + elseif cmd.type == 'cache' then local cache_commands = require('cp.commands.cache') cache_commands.handle_cache_command(cmd) - return - end - - if cmd.type == 'contest_setup' then + elseif cmd.type == 'contest_setup' then local setup = require('cp.setup') if setup.set_platform(cmd.platform) then setup.setup_contest(cmd.platform, cmd.contest, cmd.language, nil) diff --git a/lua/cp/commands/picker.lua b/lua/cp/commands/picker.lua index 80d79be..f41c9b3 100644 --- a/lua/cp/commands/picker.lua +++ b/lua/cp/commands/picker.lua @@ -3,6 +3,8 @@ local M = {} local config_module = require('cp.config') local logger = require('cp.log') +--- Dispatch `:CP pick` to appropriate picker +---@return nil function M.handle_pick_action() local config = config_module.get_config() diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 6d96793..4788832 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -6,23 +6,11 @@ ---@field version? number Language version ---@field extension? string File extension ----@class PartialLanguageConfig ----@field compile? string[] Compile command template ----@field test? string[] Test execution command template ----@field debug? string[] Debug command template ----@field executable? string Executable name ----@field extension? string File extension - ---@class ContestConfig ---@field cpp LanguageConfig ---@field python LanguageConfig ---@field default_language? string ----@class PartialContestConfig ----@field cpp? PartialLanguageConfig ----@field python? PartialLanguageConfig ----@field default_language? string - ---@class Hooks ---@field before_run? fun(state: cp.State) ---@field before_debug? fun(state: cp.State) @@ -43,25 +31,25 @@ ---@class cp.Config ---@field contests table ----@field snippets table[] +---@field snippets any[] ---@field hooks Hooks ---@field debug boolean ----@field scrapers table +---@field scrapers string[] ---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string ---@field run_panel RunPanelConfig ---@field diff DiffConfig ----@field picker "telescope"|"fzf-lua"|nil +---@field picker string|nil ----@class cp.UserConfig ----@field contests? table ----@field snippets? table[] +---@class cp.PartialConfig +---@field contests? table +---@field snippets? any[] ---@field hooks? Hooks ---@field debug? boolean ----@field scrapers? table +---@field scrapers? string[] ---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string ---@field run_panel? RunPanelConfig ---@field diff? DiffConfig ----@field picker? "telescope"|"fzf-lua"|nil +---@field picker? string|nil local M = {} local constants = require('cp.constants') @@ -112,7 +100,7 @@ M.defaults = { picker = nil, } ----@param user_config cp.UserConfig|nil +---@param user_config cp.PartialConfig|nil ---@return cp.Config function M.setup(user_config) vim.validate({ @@ -279,10 +267,14 @@ M.default_filename = default_filename local current_config = nil +--- Set the config +---@return nil function M.set_current_config(config) current_config = config end +--- Get the config +---@return cp.Config function M.get_config() return current_config or M.defaults end diff --git a/lua/cp/init.lua b/lua/cp/init.lua index b2881a9..a6f70a1 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -13,6 +13,8 @@ local user_config = {} local config = config_module.setup(user_config) local snippets_initialized = false +--- Root handler for all `:CP ...` commands +---@return nil function M.handle_command(opts) local commands = require('cp.commands') commands.handle_command(opts) diff --git a/lua/cp/runner/run.lua b/lua/cp/runner/run.lua index 4462d0c..9480ca3 100644 --- a/lua/cp/runner/run.lua +++ b/lua/cp/runner/run.lua @@ -1,10 +1,10 @@ ----@class TestCase +---@class RanTestCase ---@field index number ---@field input string ---@field expected string ---@field status "pending"|"pass"|"fail"|"running"|"timeout" ---@field actual string? ----@field actual_highlights table[]? +---@field actual_highlights? any[] ---@field time_ms number? ---@field error string? ---@field stderr string? @@ -19,7 +19,7 @@ ---@field memory_mb number ---@class RunPanelState ----@field test_cases TestCase[] +---@field test_cases RanTestCase[] ---@field current_index number ---@field buffer number? ---@field namespace number? @@ -45,7 +45,7 @@ local run_panel_state = { ---@param index number ---@param input string ---@param expected string ----@return TestCase +---@return RanTestCase local function create_test_case(index, input, expected) return { index = index, @@ -62,7 +62,7 @@ end ---@param platform string ---@param contest_id string ---@param problem_id string? ----@return TestCase[] +---@return RanTestCase[] local function parse_test_cases_from_cache(platform, contest_id, problem_id) local cache = require('cp.cache') cache.load() @@ -103,13 +103,13 @@ local function load_constraints_from_cache(platform, contest_id, problem_id) end ---@param contest_config ContestConfig ----@param test_case TestCase +---@param test_case RanTestCase ---@return table local function run_single_test_case(contest_config, cp_config, test_case) local state = require('cp.state') local source_file = state.get_source_file() - local language = vim.fn.fnamemodify(source_file, ':e') + local language = vim.fn.fnamemodify(source_file or '', ':e') local language_name = constants.filetype_to_language[language] or contest_config.default_language local language_config = contest_config[language_name] @@ -297,7 +297,7 @@ end ---@param contest_config ContestConfig ---@param cp_config cp.Config ----@return TestCase[] +---@return RanTestCase[] function M.run_all_test_cases(contest_config, cp_config) local results = {} for i, _ in ipairs(run_panel_state.test_cases) do diff --git a/lua/cp/runner/run_render.lua b/lua/cp/runner/run_render.lua index f8ad0a5..33eb3f6 100644 --- a/lua/cp/runner/run_render.lua +++ b/lua/cp/runner/run_render.lua @@ -23,22 +23,22 @@ local exit_code_names = { [143] = 'SIGCHLD', } ----@param test_case TestCase +---@param ran_test_case RanTestCase ---@return StatusInfo -function M.get_status_info(test_case) - if test_case.status == 'pass' then +function M.get_status_info(ran_test_case) + if ran_test_case.status == 'pass' then return { text = 'AC', highlight_group = 'CpTestAC' } - elseif test_case.status == 'fail' then - if test_case.timed_out then + elseif ran_test_case.status == 'fail' then + if ran_test_case.timed_out then return { text = 'TLE', highlight_group = 'CpTestTLE' } - elseif test_case.code and test_case.code >= 128 then + elseif ran_test_case.code and ran_test_case.code >= 128 then return { text = 'RTE', highlight_group = 'CpTestRTE' } else return { text = 'WA', highlight_group = 'CpTestWA' } end - elseif test_case.status == 'timeout' then + elseif ran_test_case.status == 'timeout' then return { text = 'TLE', highlight_group = 'CpTestTLE' } - elseif test_case.status == 'running' then + elseif ran_test_case.status == 'running' then return { text = '...', highlight_group = 'CpTestPending' } else return { text = '', highlight_group = 'CpTestPending' } @@ -278,7 +278,7 @@ local function data_row(c, idx, tc, is_current, test_state) end ---@param test_state RunPanelState ----@return string[], table[] lines and highlight positions +---@return string[], any[] lines and highlight positions function M.render_test_list(test_state) local lines, highlights = {}, {} local c = compute_cols(test_state) @@ -332,18 +332,18 @@ function M.render_test_list(test_state) return lines, highlights end ----@param test_case TestCase? +---@param ran_test_case RanTestCase? ---@return string -function M.render_status_bar(test_case) - if not test_case then +function M.render_status_bar(ran_test_case) + if not ran_test_case then return '' end local parts = {} - if test_case.time_ms then - table.insert(parts, string.format('%.2fms', test_case.time_ms)) + if ran_test_case.time_ms then + table.insert(parts, string.format('%.2fms', ran_test_case.time_ms)) end - if test_case.code then - table.insert(parts, string.format('Exit: %d', test_case.code)) + if ran_test_case.code then + table.insert(parts, string.format('Exit: %d', ran_test_case.code)) end return table.concat(parts, ' │ ') end diff --git a/lua/cp/ui/ansi.lua b/lua/cp/ui/ansi.lua index 642b624..893b0db 100644 --- a/lua/cp/ui/ansi.lua +++ b/lua/cp/ui/ansi.lua @@ -1,6 +1,6 @@ ---@class AnsiParseResult ---@field lines string[] ----@field highlights table[] +---@field highlights any[] local M = {} diff --git a/lua/cp/ui/diff.lua b/lua/cp/ui/diff.lua index 16dff5a..819fa3f 100644 --- a/lua/cp/ui/diff.lua +++ b/lua/cp/ui/diff.lua @@ -1,6 +1,6 @@ ---@class DiffResult ---@field content string[] ----@field highlights table[]? +---@field highlights any[]? ---@field raw_diff string? ---@class DiffBackend