fix(lua): bunch of typing

This commit is contained in:
Barrett Ruth 2025-10-02 14:18:26 -04:00
parent 057b0890c2
commit 1974addbd2
10 changed files with 79 additions and 75 deletions

View file

@ -6,10 +6,9 @@
---@class ContestData
---@field problems Problem[]
---@field test_cases? CachedTestCase[]
---@field timeout_ms? number
---@field memory_mb? number
---@field interactive? boolean
---@field index_map table<string, number>
---@field name string
---@field display_name string
---@class ContestSummary
---@field display_name string
@ -19,19 +18,20 @@
---@class CacheData
---@field [string] table<string, ContestData>
---@field file_states? table<string, FileState>
---@field contest_lists? table<string, ContestListData>
---@class ContestListData
---@field contests table[]
---@field contest_lists? table<string, ContestSummary>
---@class Problem
---@field id string
---@field name? string
---@field interactive? boolean
---@field memory_mb? number
---@field timeout_ms? number
---@field test_cases TestCase[]
---@class CachedTestCase
---@class TestCase
---@field index? number
---@field input string
---@field expected? string
---@field input? string
---@field output? string
local M = {}
@ -164,7 +164,7 @@ end
---@param platform string
---@param contest_id string
---@param problem_id? string
---@return CachedTestCase[]?
---@return TestCase[]
function M.get_test_cases(platform, contest_id, problem_id)
vim.validate({
platform = { platform, 'string' },
@ -178,8 +178,7 @@ function M.get_test_cases(platform, contest_id, problem_id)
or not cache_data[platform][contest_id].problems
or not cache_data[platform][contest_id].index_map
then
print('bad, failing')
return nil
return {}
end
local index = cache_data[platform][contest_id].index_map[problem_id]
@ -189,7 +188,7 @@ end
---@param platform string
---@param contest_id string
---@param problem_id string
---@param test_cases CachedTestCase[]
---@param test_cases TestCase[]
---@param timeout_ms? number
---@param memory_mb? number
---@param interactive? boolean
@ -269,7 +268,7 @@ function M.set_file_state(file_path, platform, contest_id, problem_id, language)
end
---@param platform string
---@return table[ContestSummary]
---@return ContestSummary[]
function M.get_contest_summaries(platform)
local contest_list = {}
for contest_id, contest_data in pairs(cache_data[platform] or {}) do
@ -283,7 +282,7 @@ function M.get_contest_summaries(platform)
end
---@param platform string
---@param contests table[ContestSummary]
---@param contests ContestSummary[]
function M.set_contest_summaries(platform, contests)
cache_data[platform] = cache_data[platform] or {}
for _, contest in ipairs(contests) do

View file

@ -6,6 +6,9 @@ local logger = require('cp.log')
local platforms = constants.PLATFORMS
--- Dispatch any `:CP cache ...` command
---@param cmd table
---@return nil
function M.handle_cache_command(cmd)
if cmd.subcommand == 'read' then
local data = cache.get_data_pretty()

View file

@ -7,6 +7,19 @@ local state = require('cp.state')
local platforms = constants.PLATFORMS
local actions = constants.ACTIONS
---@class ParsedCommand
---@field type string
---@field error string?
---@field language? string
---@field debug? boolean
---@field action? string
---@field message? string
---@field contest? string
---@field platform? string
--- Turn raw args into normalized structure to later dispatch
---@param args string[] The raw command-line mode args
---@return ParsedCommand
local function parse_command(args)
if vim.tbl_isempty(args) then
return {
@ -94,6 +107,8 @@ local function parse_command(args)
return { type = 'error', message = 'Unknown command or no contest context' }
end
--- Core logic for handling `:CP ...` commands
---@return nil
function M.handle_command(opts)
local cmd = parse_command(opts.fargs)
@ -105,10 +120,7 @@ function M.handle_command(opts)
if cmd.type == 'restore_from_file' then
local restore = require('cp.restore')
restore.restore_from_current_file()
return
end
if cmd.type == 'action' then
elseif cmd.type == 'action' then
local setup = require('cp.setup')
local ui = require('cp.ui.panel')
@ -124,16 +136,10 @@ function M.handle_command(opts)
local picker = require('cp.commands.picker')
picker.handle_pick_action()
end
return
end
if cmd.type == 'cache' then
elseif cmd.type == 'cache' then
local cache_commands = require('cp.commands.cache')
cache_commands.handle_cache_command(cmd)
return
end
if cmd.type == 'contest_setup' then
elseif cmd.type == 'contest_setup' then
local setup = require('cp.setup')
if setup.set_platform(cmd.platform) then
setup.setup_contest(cmd.platform, cmd.contest, cmd.language, nil)

View file

@ -3,6 +3,8 @@ local M = {}
local config_module = require('cp.config')
local logger = require('cp.log')
--- Dispatch `:CP pick` to appropriate picker
---@return nil
function M.handle_pick_action()
local config = config_module.get_config()

View file

@ -6,23 +6,11 @@
---@field version? number Language version
---@field extension? string File extension
---@class PartialLanguageConfig
---@field compile? string[] Compile command template
---@field test? string[] Test execution command template
---@field debug? string[] Debug command template
---@field executable? string Executable name
---@field extension? string File extension
---@class ContestConfig
---@field cpp LanguageConfig
---@field python LanguageConfig
---@field default_language? string
---@class PartialContestConfig
---@field cpp? PartialLanguageConfig
---@field python? PartialLanguageConfig
---@field default_language? string
---@class Hooks
---@field before_run? fun(state: cp.State)
---@field before_debug? fun(state: cp.State)
@ -43,25 +31,25 @@
---@class cp.Config
---@field contests table<string, ContestConfig>
---@field snippets table[]
---@field snippets any[]
---@field hooks Hooks
---@field debug boolean
---@field scrapers table<string, boolean>
---@field scrapers string[]
---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string
---@field run_panel RunPanelConfig
---@field diff DiffConfig
---@field picker "telescope"|"fzf-lua"|nil
---@field picker string|nil
---@class cp.UserConfig
---@field contests? table<string, PartialContestConfig>
---@field snippets? table[]
---@class cp.PartialConfig
---@field contests? table<string, ContestConfig>
---@field snippets? any[]
---@field hooks? Hooks
---@field debug? boolean
---@field scrapers? table<string, boolean>
---@field scrapers? string[]
---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string
---@field run_panel? RunPanelConfig
---@field diff? DiffConfig
---@field picker? "telescope"|"fzf-lua"|nil
---@field picker? string|nil
local M = {}
local constants = require('cp.constants')
@ -112,7 +100,7 @@ M.defaults = {
picker = nil,
}
---@param user_config cp.UserConfig|nil
---@param user_config cp.PartialConfig|nil
---@return cp.Config
function M.setup(user_config)
vim.validate({
@ -279,10 +267,14 @@ M.default_filename = default_filename
local current_config = nil
--- Set the config
---@return nil
function M.set_current_config(config)
current_config = config
end
--- Get the config
---@return cp.Config
function M.get_config()
return current_config or M.defaults
end

View file

@ -13,6 +13,8 @@ local user_config = {}
local config = config_module.setup(user_config)
local snippets_initialized = false
--- Root handler for all `:CP ...` commands
---@return nil
function M.handle_command(opts)
local commands = require('cp.commands')
commands.handle_command(opts)

View file

@ -1,10 +1,10 @@
---@class TestCase
---@class RanTestCase
---@field index number
---@field input string
---@field expected string
---@field status "pending"|"pass"|"fail"|"running"|"timeout"
---@field actual string?
---@field actual_highlights table[]?
---@field actual_highlights? any[]
---@field time_ms number?
---@field error string?
---@field stderr string?
@ -19,7 +19,7 @@
---@field memory_mb number
---@class RunPanelState
---@field test_cases TestCase[]
---@field test_cases RanTestCase[]
---@field current_index number
---@field buffer number?
---@field namespace number?
@ -45,7 +45,7 @@ local run_panel_state = {
---@param index number
---@param input string
---@param expected string
---@return TestCase
---@return RanTestCase
local function create_test_case(index, input, expected)
return {
index = index,
@ -62,7 +62,7 @@ end
---@param platform string
---@param contest_id string
---@param problem_id string?
---@return TestCase[]
---@return RanTestCase[]
local function parse_test_cases_from_cache(platform, contest_id, problem_id)
local cache = require('cp.cache')
cache.load()
@ -103,13 +103,13 @@ local function load_constraints_from_cache(platform, contest_id, problem_id)
end
---@param contest_config ContestConfig
---@param test_case TestCase
---@param test_case RanTestCase
---@return table
local function run_single_test_case(contest_config, cp_config, test_case)
local state = require('cp.state')
local source_file = state.get_source_file()
local language = vim.fn.fnamemodify(source_file, ':e')
local language = vim.fn.fnamemodify(source_file or '', ':e')
local language_name = constants.filetype_to_language[language] or contest_config.default_language
local language_config = contest_config[language_name]
@ -297,7 +297,7 @@ end
---@param contest_config ContestConfig
---@param cp_config cp.Config
---@return TestCase[]
---@return RanTestCase[]
function M.run_all_test_cases(contest_config, cp_config)
local results = {}
for i, _ in ipairs(run_panel_state.test_cases) do

View file

@ -23,22 +23,22 @@ local exit_code_names = {
[143] = 'SIGCHLD',
}
---@param test_case TestCase
---@param ran_test_case RanTestCase
---@return StatusInfo
function M.get_status_info(test_case)
if test_case.status == 'pass' then
function M.get_status_info(ran_test_case)
if ran_test_case.status == 'pass' then
return { text = 'AC', highlight_group = 'CpTestAC' }
elseif test_case.status == 'fail' then
if test_case.timed_out then
elseif ran_test_case.status == 'fail' then
if ran_test_case.timed_out then
return { text = 'TLE', highlight_group = 'CpTestTLE' }
elseif test_case.code and test_case.code >= 128 then
elseif ran_test_case.code and ran_test_case.code >= 128 then
return { text = 'RTE', highlight_group = 'CpTestRTE' }
else
return { text = 'WA', highlight_group = 'CpTestWA' }
end
elseif test_case.status == 'timeout' then
elseif ran_test_case.status == 'timeout' then
return { text = 'TLE', highlight_group = 'CpTestTLE' }
elseif test_case.status == 'running' then
elseif ran_test_case.status == 'running' then
return { text = '...', highlight_group = 'CpTestPending' }
else
return { text = '', highlight_group = 'CpTestPending' }
@ -278,7 +278,7 @@ local function data_row(c, idx, tc, is_current, test_state)
end
---@param test_state RunPanelState
---@return string[], table[] lines and highlight positions
---@return string[], any[] lines and highlight positions
function M.render_test_list(test_state)
local lines, highlights = {}, {}
local c = compute_cols(test_state)
@ -332,18 +332,18 @@ function M.render_test_list(test_state)
return lines, highlights
end
---@param test_case TestCase?
---@param ran_test_case RanTestCase?
---@return string
function M.render_status_bar(test_case)
if not test_case then
function M.render_status_bar(ran_test_case)
if not ran_test_case then
return ''
end
local parts = {}
if test_case.time_ms then
table.insert(parts, string.format('%.2fms', test_case.time_ms))
if ran_test_case.time_ms then
table.insert(parts, string.format('%.2fms', ran_test_case.time_ms))
end
if test_case.code then
table.insert(parts, string.format('Exit: %d', test_case.code))
if ran_test_case.code then
table.insert(parts, string.format('Exit: %d', ran_test_case.code))
end
return table.concat(parts, '')
end

View file

@ -1,6 +1,6 @@
---@class AnsiParseResult
---@field lines string[]
---@field highlights table[]
---@field highlights any[]
local M = {}

View file

@ -1,6 +1,6 @@
---@class DiffResult
---@field content string[]
---@field highlights table[]?
---@field highlights any[]?
---@field raw_diff string?
---@class DiffBackend