Merge pull request #132 from barrett-ruth/fix/typing

Fix/typing
This commit is contained in:
Barrett Ruth 2025-10-02 20:30:21 +02:00 committed by GitHub
commit 3c0f8d7deb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 110 additions and 119 deletions

View file

@ -4,29 +4,34 @@
---@field problem_id? string ---@field problem_id? string
---@field language? string ---@field language? string
---@class ContestData
---@field problems Problem[]
---@field index_map table<string, number>
---@field name string
---@field display_name string
---@class ContestSummary
---@field display_name string
---@field name string
---@field id string
---@class CacheData ---@class CacheData
---@field [string] table<string, ContestData> ---@field [string] table<string, ContestData>
---@field file_states? table<string, FileState> ---@field file_states? table<string, FileState>
---@field contest_lists? table<string, ContestListData> ---@field contest_lists? table<string, ContestSummary>
---@class ContestListData
---@field contests table[]
---@class ContestData
---@field problems Problem[]
---@field test_cases? CachedTestCase[]
---@field timeout_ms? number
---@field memory_mb? number
---@field interactive? boolean
---@class Problem ---@class Problem
---@field id string ---@field id string
---@field name? string ---@field name? string
---@field interactive? boolean
---@field memory_mb? number
---@field timeout_ms? number
---@field test_cases TestCase[]
---@class CachedTestCase ---@class TestCase
---@field index? number ---@field index? number
---@field input string
---@field expected? string ---@field expected? string
---@field input? string
---@field output? string ---@field output? string
local M = {} local M = {}
@ -36,6 +41,8 @@ local cache_file = vim.fn.stdpath('data') .. '/cp-nvim.json'
local cache_data = {} local cache_data = {}
local loaded = false local loaded = false
--- Load the cache from disk if not done already
---@return nil
function M.load() function M.load()
if loaded then if loaded then
return return
@ -63,6 +70,8 @@ function M.load()
loaded = true loaded = true
end end
--- Save the cache to disk, overwriting existing contents
---@return nil
function M.save() function M.save()
vim.schedule(function() vim.schedule(function()
vim.fn.mkdir(vim.fn.fnamemodify(cache_file, ':h'), 'p') vim.fn.mkdir(vim.fn.fnamemodify(cache_file, ':h'), 'p')
@ -155,7 +164,7 @@ end
---@param platform string ---@param platform string
---@param contest_id string ---@param contest_id string
---@param problem_id? string ---@param problem_id? string
---@return CachedTestCase[]? ---@return TestCase[]
function M.get_test_cases(platform, contest_id, problem_id) function M.get_test_cases(platform, contest_id, problem_id)
vim.validate({ vim.validate({
platform = { platform, 'string' }, platform = { platform, 'string' },
@ -169,8 +178,7 @@ function M.get_test_cases(platform, contest_id, problem_id)
or not cache_data[platform][contest_id].problems or not cache_data[platform][contest_id].problems
or not cache_data[platform][contest_id].index_map or not cache_data[platform][contest_id].index_map
then then
print('bad, failing') return {}
return nil
end end
local index = cache_data[platform][contest_id].index_map[problem_id] local index = cache_data[platform][contest_id].index_map[problem_id]
@ -180,7 +188,7 @@ end
---@param platform string ---@param platform string
---@param contest_id string ---@param contest_id string
---@param problem_id string ---@param problem_id string
---@param test_cases CachedTestCase[] ---@param test_cases TestCase[]
---@param timeout_ms? number ---@param timeout_ms? number
---@param memory_mb? number ---@param memory_mb? number
---@param interactive? boolean ---@param interactive? boolean
@ -260,8 +268,8 @@ function M.set_file_state(file_path, platform, contest_id, problem_id, language)
end end
---@param platform string ---@param platform string
---@return table[] ---@return ContestSummary[]
function M.get_contest_list(platform) function M.get_contest_summaries(platform)
local contest_list = {} local contest_list = {}
for contest_id, contest_data in pairs(cache_data[platform] or {}) do for contest_id, contest_data in pairs(cache_data[platform] or {}) do
table.insert(contest_list, { table.insert(contest_list, {
@ -274,8 +282,8 @@ function M.get_contest_list(platform)
end end
---@param platform string ---@param platform string
---@param contests table[] ---@param contests ContestSummary[]
function M.set_contest_list(platform, contests) function M.set_contest_summaries(platform, contests)
cache_data[platform] = cache_data[platform] or {} cache_data[platform] = cache_data[platform] or {}
for _, contest in ipairs(contests) do for _, contest in ipairs(contests) do
cache_data[platform][contest.id] = cache_data[platform][contest] or {} cache_data[platform][contest.id] = cache_data[platform][contest] or {}

View file

@ -6,6 +6,9 @@ local logger = require('cp.log')
local platforms = constants.PLATFORMS local platforms = constants.PLATFORMS
--- Dispatch any `:CP cache ...` command
---@param cmd table
---@return nil
function M.handle_cache_command(cmd) function M.handle_cache_command(cmd)
if cmd.subcommand == 'read' then if cmd.subcommand == 'read' then
local data = cache.get_data_pretty() local data = cache.get_data_pretty()

View file

@ -7,6 +7,19 @@ local state = require('cp.state')
local platforms = constants.PLATFORMS local platforms = constants.PLATFORMS
local actions = constants.ACTIONS local actions = constants.ACTIONS
---@class ParsedCommand
---@field type string
---@field error string?
---@field language? string
---@field debug? boolean
---@field action? string
---@field message? string
---@field contest? string
---@field platform? string
--- Turn raw args into normalized structure to later dispatch
---@param args string[] The raw command-line mode args
---@return ParsedCommand
local function parse_command(args) local function parse_command(args)
if vim.tbl_isempty(args) then if vim.tbl_isempty(args) then
return { return {
@ -94,6 +107,8 @@ local function parse_command(args)
return { type = 'error', message = 'Unknown command or no contest context' } return { type = 'error', message = 'Unknown command or no contest context' }
end end
--- Core logic for handling `:CP ...` commands
---@return nil
function M.handle_command(opts) function M.handle_command(opts)
local cmd = parse_command(opts.fargs) local cmd = parse_command(opts.fargs)
@ -105,10 +120,7 @@ function M.handle_command(opts)
if cmd.type == 'restore_from_file' then if cmd.type == 'restore_from_file' then
local restore = require('cp.restore') local restore = require('cp.restore')
restore.restore_from_current_file() restore.restore_from_current_file()
return elseif cmd.type == 'action' then
end
if cmd.type == 'action' then
local setup = require('cp.setup') local setup = require('cp.setup')
local ui = require('cp.ui.panel') local ui = require('cp.ui.panel')
@ -124,16 +136,10 @@ function M.handle_command(opts)
local picker = require('cp.commands.picker') local picker = require('cp.commands.picker')
picker.handle_pick_action() picker.handle_pick_action()
end end
return elseif cmd.type == 'cache' then
end
if cmd.type == 'cache' then
local cache_commands = require('cp.commands.cache') local cache_commands = require('cp.commands.cache')
cache_commands.handle_cache_command(cmd) cache_commands.handle_cache_command(cmd)
return elseif cmd.type == 'contest_setup' then
end
if cmd.type == 'contest_setup' then
local setup = require('cp.setup') local setup = require('cp.setup')
if setup.set_platform(cmd.platform) then if setup.set_platform(cmd.platform) then
setup.setup_contest(cmd.platform, cmd.contest, cmd.language, nil) setup.setup_contest(cmd.platform, cmd.contest, cmd.language, nil)

View file

@ -3,6 +3,8 @@ local M = {}
local config_module = require('cp.config') local config_module = require('cp.config')
local logger = require('cp.log') local logger = require('cp.log')
--- Dispatch `:CP pick` to appropriate picker
---@return nil
function M.handle_pick_action() function M.handle_pick_action()
local config = config_module.get_config() local config = config_module.get_config()

View file

@ -6,23 +6,11 @@
---@field version? number Language version ---@field version? number Language version
---@field extension? string File extension ---@field extension? string File extension
---@class PartialLanguageConfig
---@field compile? string[] Compile command template
---@field test? string[] Test execution command template
---@field debug? string[] Debug command template
---@field executable? string Executable name
---@field extension? string File extension
---@class ContestConfig ---@class ContestConfig
---@field cpp LanguageConfig ---@field cpp LanguageConfig
---@field python LanguageConfig ---@field python LanguageConfig
---@field default_language? string ---@field default_language? string
---@class PartialContestConfig
---@field cpp? PartialLanguageConfig
---@field python? PartialLanguageConfig
---@field default_language? string
---@class Hooks ---@class Hooks
---@field before_run? fun(state: cp.State) ---@field before_run? fun(state: cp.State)
---@field before_debug? fun(state: cp.State) ---@field before_debug? fun(state: cp.State)
@ -43,25 +31,25 @@
---@class cp.Config ---@class cp.Config
---@field contests table<string, ContestConfig> ---@field contests table<string, ContestConfig>
---@field snippets table[] ---@field snippets any[]
---@field hooks Hooks ---@field hooks Hooks
---@field debug boolean ---@field debug boolean
---@field scrapers table<string, boolean> ---@field scrapers string[]
---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string ---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string
---@field run_panel RunPanelConfig ---@field run_panel RunPanelConfig
---@field diff DiffConfig ---@field diff DiffConfig
---@field picker "telescope"|"fzf-lua"|nil ---@field picker string|nil
---@class cp.UserConfig ---@class cp.PartialConfig
---@field contests? table<string, PartialContestConfig> ---@field contests? table<string, ContestConfig>
---@field snippets? table[] ---@field snippets? any[]
---@field hooks? Hooks ---@field hooks? Hooks
---@field debug? boolean ---@field debug? boolean
---@field scrapers? table<string, boolean> ---@field scrapers? string[]
---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string ---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string
---@field run_panel? RunPanelConfig ---@field run_panel? RunPanelConfig
---@field diff? DiffConfig ---@field diff? DiffConfig
---@field picker? "telescope"|"fzf-lua"|nil ---@field picker? string|nil
local M = {} local M = {}
local constants = require('cp.constants') local constants = require('cp.constants')
@ -112,7 +100,7 @@ M.defaults = {
picker = nil, picker = nil,
} }
---@param user_config cp.UserConfig|nil ---@param user_config cp.PartialConfig|nil
---@return cp.Config ---@return cp.Config
function M.setup(user_config) function M.setup(user_config)
vim.validate({ vim.validate({
@ -279,10 +267,14 @@ M.default_filename = default_filename
local current_config = nil local current_config = nil
--- Set the config
---@return nil
function M.set_current_config(config) function M.set_current_config(config)
current_config = config current_config = config
end end
--- Get the config
---@return cp.Config
function M.get_config() function M.get_config()
return current_config or M.defaults return current_config or M.defaults
end end

View file

@ -13,6 +13,8 @@ local user_config = {}
local config = config_module.setup(user_config) local config = config_module.setup(user_config)
local snippets_initialized = false local snippets_initialized = false
--- Root handler for all `:CP ...` commands
---@return nil
function M.handle_command(opts) function M.handle_command(opts)
local commands = require('cp.commands') local commands = require('cp.commands')
commands.handle_command(opts) commands.handle_command(opts)

View file

@ -49,13 +49,13 @@ function M.get_platform_contests(platform, refresh)
cache.load() cache.load()
local picker_contests = cache.get_contest_list(platform) local picker_contests = cache.get_contest_summaries(platform)
if refresh or vim.tbl_isempty(picker_contests) then if refresh or vim.tbl_isempty(picker_contests) then
logger.log(('Cache miss on %s contests'):format(platform)) logger.log(('Cache miss on %s contests'):format(platform))
local contests = scraper.scrape_contest_list(platform) local contests = scraper.scrape_contest_list(platform)
cache.set_contest_list(platform, contests) cache.set_contest_summaries(platform, contests)
end end
logger.log( logger.log(
@ -64,7 +64,7 @@ function M.get_platform_contests(platform, refresh)
true true
) )
picker_contests = cache.get_contest_list(platform) picker_contests = cache.get_contest_summaries(platform)
return picker_contests return picker_contests
end end

View file

@ -1,10 +1,10 @@
---@class TestCase ---@class RanTestCase
---@field index number ---@field index number
---@field input string ---@field input string
---@field expected string ---@field expected string
---@field status "pending"|"pass"|"fail"|"running"|"timeout" ---@field status "pending"|"pass"|"fail"|"running"|"timeout"
---@field actual string? ---@field actual string?
---@field actual_highlights table[]? ---@field actual_highlights? Highlight[]
---@field time_ms number? ---@field time_ms number?
---@field error string? ---@field error string?
---@field stderr string? ---@field stderr string?
@ -19,7 +19,7 @@
---@field memory_mb number ---@field memory_mb number
---@class RunPanelState ---@class RunPanelState
---@field test_cases TestCase[] ---@field test_cases RanTestCase[]
---@field current_index number ---@field current_index number
---@field buffer number? ---@field buffer number?
---@field namespace number? ---@field namespace number?
@ -45,7 +45,7 @@ local run_panel_state = {
---@param index number ---@param index number
---@param input string ---@param input string
---@param expected string ---@param expected string
---@return TestCase ---@return RanTestCase
local function create_test_case(index, input, expected) local function create_test_case(index, input, expected)
return { return {
index = index, index = index,
@ -62,7 +62,7 @@ end
---@param platform string ---@param platform string
---@param contest_id string ---@param contest_id string
---@param problem_id string? ---@param problem_id string?
---@return TestCase[] ---@return RanTestCase[]
local function parse_test_cases_from_cache(platform, contest_id, problem_id) local function parse_test_cases_from_cache(platform, contest_id, problem_id)
local cache = require('cp.cache') local cache = require('cp.cache')
cache.load() cache.load()
@ -103,13 +103,13 @@ local function load_constraints_from_cache(platform, contest_id, problem_id)
end end
---@param contest_config ContestConfig ---@param contest_config ContestConfig
---@param test_case TestCase ---@param test_case RanTestCase
---@return table ---@return table
local function run_single_test_case(contest_config, cp_config, test_case) local function run_single_test_case(contest_config, cp_config, test_case)
local state = require('cp.state') local state = require('cp.state')
local source_file = state.get_source_file() local source_file = state.get_source_file()
local language = vim.fn.fnamemodify(source_file, ':e') local language = vim.fn.fnamemodify(source_file or '', ':e')
local language_name = constants.filetype_to_language[language] or contest_config.default_language local language_name = constants.filetype_to_language[language] or contest_config.default_language
local language_config = contest_config[language_name] local language_config = contest_config[language_name]
@ -297,7 +297,7 @@ end
---@param contest_config ContestConfig ---@param contest_config ContestConfig
---@param cp_config cp.Config ---@param cp_config cp.Config
---@return TestCase[] ---@return RanTestCase[]
function M.run_all_test_cases(contest_config, cp_config) function M.run_all_test_cases(contest_config, cp_config)
local results = {} local results = {}
for i, _ in ipairs(run_panel_state.test_cases) do for i, _ in ipairs(run_panel_state.test_cases) do

View file

@ -23,22 +23,22 @@ local exit_code_names = {
[143] = 'SIGCHLD', [143] = 'SIGCHLD',
} }
---@param test_case TestCase ---@param ran_test_case RanTestCase
---@return StatusInfo ---@return StatusInfo
function M.get_status_info(test_case) function M.get_status_info(ran_test_case)
if test_case.status == 'pass' then if ran_test_case.status == 'pass' then
return { text = 'AC', highlight_group = 'CpTestAC' } return { text = 'AC', highlight_group = 'CpTestAC' }
elseif test_case.status == 'fail' then elseif ran_test_case.status == 'fail' then
if test_case.timed_out then if ran_test_case.timed_out then
return { text = 'TLE', highlight_group = 'CpTestTLE' } return { text = 'TLE', highlight_group = 'CpTestTLE' }
elseif test_case.code and test_case.code >= 128 then elseif ran_test_case.code and ran_test_case.code >= 128 then
return { text = 'RTE', highlight_group = 'CpTestRTE' } return { text = 'RTE', highlight_group = 'CpTestRTE' }
else else
return { text = 'WA', highlight_group = 'CpTestWA' } return { text = 'WA', highlight_group = 'CpTestWA' }
end end
elseif test_case.status == 'timeout' then elseif ran_test_case.status == 'timeout' then
return { text = 'TLE', highlight_group = 'CpTestTLE' } return { text = 'TLE', highlight_group = 'CpTestTLE' }
elseif test_case.status == 'running' then elseif ran_test_case.status == 'running' then
return { text = '...', highlight_group = 'CpTestPending' } return { text = '...', highlight_group = 'CpTestPending' }
else else
return { text = '', highlight_group = 'CpTestPending' } return { text = '', highlight_group = 'CpTestPending' }
@ -278,7 +278,7 @@ local function data_row(c, idx, tc, is_current, test_state)
end end
---@param test_state RunPanelState ---@param test_state RunPanelState
---@return string[], table[] lines and highlight positions ---@return string[], Highlight[] lines and highlight positions
function M.render_test_list(test_state) function M.render_test_list(test_state)
local lines, highlights = {}, {} local lines, highlights = {}, {}
local c = compute_cols(test_state) local c = compute_cols(test_state)
@ -332,18 +332,18 @@ function M.render_test_list(test_state)
return lines, highlights return lines, highlights
end end
---@param test_case TestCase? ---@param ran_test_case RanTestCase?
---@return string ---@return string
function M.render_status_bar(test_case) function M.render_status_bar(ran_test_case)
if not test_case then if not ran_test_case then
return '' return ''
end end
local parts = {} local parts = {}
if test_case.time_ms then if ran_test_case.time_ms then
table.insert(parts, string.format('%.2fms', test_case.time_ms)) table.insert(parts, string.format('%.2fms', ran_test_case.time_ms))
end end
if test_case.code then if ran_test_case.code then
table.insert(parts, string.format('Exit: %d', test_case.code)) table.insert(parts, string.format('Exit: %d', ran_test_case.code))
end end
return table.concat(parts, '') return table.concat(parts, '')
end end

View file

@ -7,11 +7,6 @@
---@field set_problem_id fun(problem_id: string) ---@field set_problem_id fun(problem_id: string)
---@field get_active_panel fun(): string? ---@field get_active_panel fun(): string?
---@field set_active_panel fun(): string? ---@field set_active_panel fun(): string?
---@field get_saved_session fun(): table?
---@field set_saved_session fun(session: table)
---@field get_context fun(): {platform: string?, contest_id: string?, problem_id: string?}
---@field has_context fun(): boolean
---@field reset fun()
---@field get_base_name fun(): string? ---@field get_base_name fun(): string?
---@field get_source_file fun(language?: string): string? ---@field get_source_file fun(language?: string): string?
---@field get_binary_file fun(): string? ---@field get_binary_file fun(): string?
@ -54,14 +49,6 @@ function M.set_problem_id(problem_id)
state.problem_id = problem_id state.problem_id = problem_id
end end
function M.get_saved_session()
return state.saved_session
end
function M.set_saved_session(session)
state.saved_session = session
end
function M.get_base_name() function M.get_base_name()
local platform, contest_id, problem_id = M.get_platform(), M.get_contest_id(), M.get_problem_id() local platform, contest_id, problem_id = M.get_platform(), M.get_contest_id(), M.get_problem_id()
if not platform or not contest_id or not problem_id then if not platform or not contest_id or not problem_id then
@ -78,14 +65,6 @@ function M.get_base_name()
end end
end end
function M.get_context()
return {
platform = state.platform,
contest_id = state.contest_id,
problem_id = state.problem_id,
}
end
function M.get_source_file(language) function M.get_source_file(language)
local base_name = M.get_base_name() local base_name = M.get_base_name()
if not base_name or not M.get_platform() then if not base_name or not M.get_platform() then
@ -127,10 +106,6 @@ function M.get_expected_file()
return base_name and ('io/%s.expected'):format(base_name) or nil return base_name and ('io/%s.expected'):format(base_name) or nil
end end
function M.has_context()
return state.platform and state.contest_id
end
function M.get_active_panel() function M.get_active_panel()
return state.active_panel return state.active_panel
end end
@ -139,13 +114,4 @@ function M.set_active_panel(panel)
state.active_panel = panel state.active_panel = panel
end end
function M.reset()
state.platform = nil
state.contest_id = nil
state.problem_id = nil
state.test_cases = nil
state.run_panel_active = false
state.saved_session = nil
end
return M return M

View file

@ -1,6 +1,12 @@
---@class AnsiParseResult ---@class AnsiParseResult
---@field lines string[] ---@field lines string[]
---@field highlights table[] ---@field highlights Highlight[]
---@class Highlight
---@field line number
---@field col_start number
---@field col_end number
---@field highlight_group string
local M = {} local M = {}

View file

@ -1,6 +1,6 @@
---@class DiffResult ---@class DiffResult
---@field content string[] ---@field content string[]
---@field highlights table[]? ---@field highlights Highlight[]?
---@field raw_diff string? ---@field raw_diff string?
---@class DiffBackend ---@class DiffBackend

View file

@ -62,7 +62,10 @@ function M.toggle_interactive()
local cache = require('cp.cache') local cache = require('cp.cache')
cache.load() cache.load()
local contest_data = cache.get_contest_data(platform, contest_id) local contest_data = cache.get_contest_data(platform, contest_id)
if contest_data and not contest_data.interactive then if
contest_data
and not contest_data.problems[contest_data.index_map[state.get_problem_id()]].interactive
then
logger.log('This is NOT an interactive problem. Use :CP run instead.', vim.log.levels.WARN) logger.log('This is NOT an interactive problem. Use :CP run instead.', vim.log.levels.WARN)
return return
end end
@ -154,7 +157,10 @@ function M.toggle_run_panel(is_debug)
local cache = require('cp.cache') local cache = require('cp.cache')
cache.load() cache.load()
local contest_data = cache.get_contest_data(platform, contest_id) local contest_data = cache.get_contest_data(platform, contest_id)
if contest_data and contest_data.interactive then if
contest_data
and contest_data.problems[contest_data.index_map[state.get_problem_id()]].interactive
then
logger.log('This is an interactive problem. Use :CP interact instead.', vim.log.levels.WARN) logger.log('This is an interactive problem. Use :CP interact instead.', vim.log.levels.WARN)
return return
end end