diff --git a/doc/cp.txt b/doc/cp.txt index 375adbd..ff67c4b 100644 --- a/doc/cp.txt +++ b/doc/cp.txt @@ -102,8 +102,11 @@ Optional configuration with lazy.nvim: > codeforces = { cpp = { version = 23 } }, }, hooks = { - before_run = function(problem_id) vim.cmd.w() end, - before_debug = function(problem_id) ... end, + before_run = function(ctx) vim.cmd.w() end, + before_debug = function(ctx) + -- ctx.problem_id, ctx.platform, ctx.source_file, etc. + vim.cmd.w() + end, }, snippets = { ... }, -- LuaSnip snippets tile = function(source_buf, input_buf, output_buf) ... end, @@ -136,10 +139,20 @@ snippets LuaSnip snippets by contest type hooks Functions called at specific events before_run Called before :CP run - function(problem_id) + function(ctx) + ctx contains: + - problem_id: string + - platform: string (atcoder/codeforces/cses) + - contest_id: string + - source_file: string (path to source) + - input_file: string (path to .cpin) + - output_file: string (path to .cpout) + - expected_file: string (path to .expected) + - contest_config: table (language configs) (default: nil, do nothing) before_debug Called before :CP debug - function(problem_id) + function(ctx) + Same ctx as before_run (default: nil, do nothing) debug Show info messages during operation diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 565afc4..516ddb3 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -1,16 +1,48 @@ +---@class CacheData +---@field [string] table + +---@class ContestData +---@field problems Problem[] +---@field scraped_at string +---@field expires_at? number +---@field test_cases? TestCase[] +---@field test_cases_cached_at? number + +---@class Problem +---@field id string +---@field name? string + +---@class TestCase +---@field input string +---@field output string + local M = {} local cache_file = vim.fn.stdpath("data") .. "/cp-nvim.json" local cache_data = {} +---@param platform string +---@return number? local function get_expiry_date(platform) + vim.validate({ + platform = { platform, "string" }, + }) + if platform == "cses" then return os.time() + (30 * 24 * 60 * 60) end return nil end +---@param contest_data ContestData +---@param platform string +---@return boolean local function is_cache_valid(contest_data, platform) + vim.validate({ + contest_data = { contest_data, "table" }, + platform = { platform, "string" }, + }) + if platform ~= "cses" then return true end @@ -49,7 +81,15 @@ function M.save() vim.fn.writefile(vim.split(encoded, "\n"), cache_file) end +---@param platform string +---@param contest_id string +---@return ContestData? function M.get_contest_data(platform, contest_id) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + }) + if not cache_data[platform] then return nil end @@ -66,7 +106,16 @@ function M.get_contest_data(platform, contest_id) return contest_data end +---@param platform string +---@param contest_id string +---@param problems Problem[] function M.set_contest_data(platform, contest_id, problems) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + problems = { problems, "table" }, + }) + if not cache_data[platform] then cache_data[platform] = {} end @@ -80,14 +129,31 @@ function M.set_contest_data(platform, contest_id, problems) M.save() end +---@param platform string +---@param contest_id string function M.clear_contest_data(platform, contest_id) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + }) + if cache_data[platform] and cache_data[platform][contest_id] then cache_data[platform][contest_id] = nil M.save() end end +---@param platform string +---@param contest_id string +---@param problem_id? string +---@return TestCase[]? function M.get_test_cases(platform, contest_id, problem_id) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + problem_id = { problem_id, { "string", "nil" }, true }, + }) + local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id if not cache_data[platform] or not cache_data[platform][problem_key] then return nil @@ -95,7 +161,18 @@ function M.get_test_cases(platform, contest_id, problem_id) return cache_data[platform][problem_key].test_cases end +---@param platform string +---@param contest_id string +---@param problem_id? string +---@param test_cases TestCase[] function M.set_test_cases(platform, contest_id, problem_id, test_cases) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + problem_id = { problem_id, { "string", "nil" }, true }, + test_cases = { test_cases, "table" }, + }) + local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id if not cache_data[platform] then cache_data[platform] = {} diff --git a/lua/cp/config.lua b/lua/cp/config.lua index ff6cdcd..a05fc5c 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -12,10 +12,24 @@ ---@field default_language string ---@field timeout_ms number +---@class HookContext +---@field problem_id string +---@field platform string +---@field contest_id string +---@field source_file string +---@field input_file string +---@field output_file string +---@field expected_file string +---@field contest_config table + +---@class Hooks +---@field before_run? fun(ctx: HookContext) +---@field before_debug? fun(ctx: HookContext) + ---@class cp.Config ---@field contests table ---@field snippets table[] ----@field hooks table +---@field hooks Hooks ---@field debug boolean ---@field tile? fun(source_buf: number, input_buf: number, output_buf: number) ---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index 00e9332..14724db 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -1,3 +1,10 @@ +---@class ExecuteResult +---@field stdout string +---@field stderr string +---@field code integer +---@field time_ms number +---@field timed_out boolean + local M = {} local logger = require("cp.log") @@ -10,14 +17,30 @@ local filetype_to_language = { py3 = "python", } -local function get_language_from_file(source_file) +---@param source_file string +---@param contest_config table +---@return string +local function get_language_from_file(source_file, contest_config) + vim.validate({ + source_file = { source_file, "string" }, + contest_config = { contest_config, "table" }, + }) + local extension = vim.fn.fnamemodify(source_file, ":e") - local language = filetype_to_language[extension] or "cpp" + local language = filetype_to_language[extension] or contest_config.default_language logger.log(("detected language: %s (extension: %s)"):format(language, extension)) return language end +---@param cmd_template string[] +---@param substitutions table +---@return string[] local function substitute_template(cmd_template, substitutions) + vim.validate({ + cmd_template = { cmd_template, "table" }, + substitutions = { substitutions, "table" }, + }) + local result = {} for _, arg in ipairs(cmd_template) do local substituted = arg @@ -29,7 +52,17 @@ local function substitute_template(cmd_template, substitutions) return result end +---@param cmd_template string[] +---@param executable? string +---@param substitutions table +---@return string[] local function build_command(cmd_template, executable, substitutions) + vim.validate({ + cmd_template = { cmd_template, "table" }, + executable = { executable, { "string", "nil" }, true }, + substitutions = { substitutions, "table" }, + }) + local cmd = substitute_template(cmd_template, substitutions) if executable then table.insert(cmd, 1, executable) @@ -59,7 +92,15 @@ local function ensure_directories() vim.system({ "mkdir", "-p", "build", "io" }):wait() end +---@param language_config table +---@param substitutions table +---@return {code: integer, stderr: string} local function compile_generic(language_config, substitutions) + vim.validate({ + language_config = { language_config, "table" }, + substitutions = { substitutions, "table" }, + }) + if not language_config.compile then logger.log("no compilation step required") return { code = 0, stderr = "" } @@ -81,7 +122,17 @@ local function compile_generic(language_config, substitutions) return result end +---@param cmd string[] +---@param input_data string +---@param timeout_ms integer +---@return ExecuteResult local function execute_command(cmd, input_data, timeout_ms) + vim.validate({ + cmd = { cmd, "table" }, + input_data = { input_data, "string" }, + timeout_ms = { timeout_ms, "number" }, + }) + logger.log(("executing: %s"):format(table.concat(cmd, " "))) local start_time = vim.loop.hrtime() @@ -114,7 +165,17 @@ local function execute_command(cmd, input_data, timeout_ms) } end +---@param exec_result ExecuteResult +---@param expected_file string +---@param is_debug boolean +---@return string local function format_output(exec_result, expected_file, is_debug) + vim.validate({ + exec_result = { exec_result, "table" }, + expected_file = { expected_file, "string" }, + is_debug = { is_debug, "boolean" }, + }) + local output_lines = { exec_result.stdout } local metadata_lines = {} @@ -158,9 +219,15 @@ end ---@param contest_config table ---@param is_debug boolean function M.run_problem(ctx, contest_config, is_debug) + vim.validate({ + ctx = { ctx, "table" }, + contest_config = { contest_config, "table" }, + is_debug = { is_debug, "boolean" }, + }) + ensure_directories() - local language = get_language_from_file(ctx.source_file) + local language = get_language_from_file(ctx.source_file, contest_config) local language_config = contest_config[language] if not language_config then @@ -171,7 +238,7 @@ function M.run_problem(ctx, contest_config, is_debug) local substitutions = { source = ctx.source_file, binary = ctx.binary_file, - version = tostring(language_config.version or ""), + version = tostring(language_config.version), } local compile_cmd = is_debug and language_config.debug or language_config.compile diff --git a/lua/cp/init.lua b/lua/cp/init.lua index f6fe548..3375f00 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -162,19 +162,28 @@ local function run_problem() logger.log(("running problem: %s"):format(problem_id)) - if config.hooks and config.hooks.before_run then - config.hooks.before_run(problem_id) - end - if not state.platform then logger.log("no platform set", vim.log.levels.ERROR) return end local contest_config = config.contests[state.platform] + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) + + if config.hooks and config.hooks.before_run then + config.hooks.before_run({ + problem_id = problem_id, + platform = state.platform, + contest_id = state.contest_id, + source_file = ctx.source_file, + input_file = ctx.input_file, + output_file = ctx.output_file, + expected_file = ctx.expected_file, + contest_config = contest_config, + }) + end vim.schedule(function() - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) execute.run_problem(ctx, contest_config, false) vim.cmd.checktime() end) @@ -186,19 +195,28 @@ local function debug_problem() return end - if config.hooks and config.hooks.before_debug then - config.hooks.before_debug(problem_id) - end - if not state.platform then logger.log("no platform set", vim.log.levels.ERROR) return end local contest_config = config.contests[state.platform] + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) + + if config.hooks and config.hooks.before_debug then + config.hooks.before_debug({ + problem_id = problem_id, + platform = state.platform, + contest_id = state.contest_id, + source_file = ctx.source_file, + input_file = ctx.input_file, + output_file = ctx.output_file, + expected_file = ctx.expected_file, + contest_config = contest_config, + }) + end vim.schedule(function() - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) execute.run_problem(ctx, contest_config, true) vim.cmd.checktime() end) diff --git a/lua/cp/problem.lua b/lua/cp/problem.lua index 6904d27..caeac2c 100644 --- a/lua/cp/problem.lua +++ b/lua/cp/problem.lua @@ -18,6 +18,14 @@ local M = {} ---@param language? string ---@return ProblemContext function M.create_context(contest, contest_id, problem_id, config, language) + vim.validate({ + contest = { contest, "string" }, + contest_id = { contest_id, "string" }, + problem_id = { problem_id, { "string", "nil" }, true }, + config = { config, "table" }, + language = { language, { "string", "nil" }, true }, + }) + local filename_fn = config.filename or require("cp.config").default_filename local source_file = filename_fn(contest, contest_id, problem_id, config, language) local base_name = vim.fn.fnamemodify(source_file, ":t:r") diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index bde861a..bf40059 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -45,6 +45,11 @@ end ---@param contest_id string ---@return {success: boolean, problems?: table[], error?: string} function M.scrape_contest_metadata(platform, contest_id) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + }) + cache.load() local cached_data = cache.get_contest_data(platform, contest_id) @@ -121,6 +126,10 @@ end ---@param ctx ProblemContext ---@return {success: boolean, problem_id: string, test_count?: number, test_cases?: table[], url?: string, error?: string} function M.scrape_problem(ctx) + vim.validate({ + ctx = { ctx, "table" }, + }) + ensure_io_directory() if vim.fn.filereadable(ctx.input_file) == 1 and vim.fn.filereadable(ctx.expected_file) == 1 then diff --git a/lua/cp/window.lua b/lua/cp/window.lua index a62465d..72ebe8b 100644 --- a/lua/cp/window.lua +++ b/lua/cp/window.lua @@ -1,3 +1,14 @@ +---@class WindowState +---@field windows table +---@field current_win integer +---@field layout string + +---@class WindowData +---@field bufnr integer +---@field view table +---@field width integer +---@field height integer + local M = {} function M.clearcol() @@ -8,6 +19,7 @@ function M.clearcol() vim.api.nvim_set_option_value("foldcolumn", "0", { scope = "local" }) end +---@return WindowState function M.save_layout() local windows = {} for _, win in ipairs(vim.api.nvim_list_wins()) do @@ -29,7 +41,14 @@ function M.save_layout() } end +---@param state? WindowState +---@param tile_fn? fun(source_buf: integer, input_buf: integer, output_buf: integer) function M.restore_layout(state, tile_fn) + vim.validate({ + state = { state, { "table", "nil" }, true }, + tile_fn = { tile_fn, { "function", "nil" }, true }, + }) + if not state then return end @@ -56,7 +75,21 @@ function M.restore_layout(state, tile_fn) local input_file = ("%s/io/%s.in"):format(base_fp, problem_id) local output_file = ("%s/io/%s.out"):format(base_fp, problem_id) local source_files = vim.fn.glob(problem_id .. ".*") - local source_file = source_files ~= "" and vim.split(source_files, "\n")[1] or (problem_id .. ".cc") + local source_file + if source_files ~= "" then + local files = vim.split(source_files, "\n") + local valid_extensions = { "cc", "cpp", "cxx", "c", "py", "py3" } + for _, file in ipairs(files) do + local ext = vim.fn.fnamemodify(file, ":e") + if vim.tbl_contains(valid_extensions, ext) then + source_file = file + break + end + end + source_file = source_file or files[1] + else + source_file = problem_id .. ".cc" + end if vim.fn.filereadable(source_file) == 0 then return @@ -90,7 +123,16 @@ function M.restore_layout(state, tile_fn) end end +---@param actual_output string +---@param expected_output string +---@param input_file string function M.setup_diff_layout(actual_output, expected_output, input_file) + vim.validate({ + actual_output = { actual_output, "string" }, + expected_output = { expected_output, "string" }, + input_file = { input_file, "string" }, + }) + vim.cmd.diffoff() vim.cmd("silent only") @@ -117,7 +159,16 @@ function M.setup_diff_layout(actual_output, expected_output, input_file) vim.cmd.wincmd("k") end +---@param source_buf integer +---@param input_buf integer +---@param output_buf integer local function default_tile(source_buf, input_buf, output_buf) + vim.validate({ + source_buf = { source_buf, "number" }, + input_buf = { input_buf, "number" }, + output_buf = { output_buf, "number" }, + }) + vim.api.nvim_set_current_buf(source_buf) vim.cmd.vsplit() vim.api.nvim_set_current_buf(output_buf)