diff --git a/after/ftplugin/cpout.lua b/after/ftplugin/cpout.lua index 3f0bcdc..857a799 100644 --- a/after/ftplugin/cpout.lua +++ b/after/ftplugin/cpout.lua @@ -4,4 +4,4 @@ vim.opt_local.statuscolumn = "" vim.opt_local.signcolumn = "no" vim.opt_local.wrap = true vim.opt_local.linebreak = true -vim.opt_local.modifiable = false +vim.opt_local.modifiable = true diff --git a/doc/cp.txt b/doc/cp.txt index 306c70a..ff67c4b 100644 --- a/doc/cp.txt +++ b/doc/cp.txt @@ -27,11 +27,12 @@ cp.nvim uses a single :CP command with intelligent argument parsing: Setup Commands ~ -:CP {platform} {contest_id} {problem_id} +:CP {platform} {contest_id} {problem_id} [--lang={language}] Full setup: set platform, load contest metadata, and set up specific problem. Scrapes test cases and creates source file. Example: :CP codeforces 1933 a + Example: :CP codeforces 1933 a --lang=python :CP {platform} {contest_id} Contest setup: set platform and load contest metadata for navigation. Caches problem list. @@ -40,9 +41,11 @@ Setup Commands ~ :CP {platform} Platform setup: set platform only. Example: :CP cses -:CP {problem_id} Problem switch: switch to different problem +:CP {problem_id} [--lang={language}] + Problem switch: switch to different problem within current contest context. Example: :CP b (switch to problem b) + Example: :CP b --lang=python Action Commands ~ @@ -75,21 +78,39 @@ Optional configuration with lazy.nvim: > debug = false, contests = { default = { - cpp_version = 20, - compile_flags = { "-O2", "-DLOCAL", "-Wall", "-Wextra" }, - debug_flags = { "-g3", "-fsanitize=address,undefined", "-DLOCAL" }, + cpp = { + compile = { + 'g++', '-std=c++{version}', '-O2', '-Wall', '-Wextra', + '-DLOCAL', '{source}', '-o', '{binary}', + }, + run = { '{binary}' }, + debug = { + 'g++', '-std=c++{version}', '-g3', + '-fsanitize=address,undefined', '-DLOCAL', + '{source}', '-o', '{binary}', + }, + version = 20, + extension = "cc", + }, + python = { + run = { 'python3', '{source}' }, + debug = { 'python3', '{source}' }, + extension = "py", + }, timeout_ms = 2000, }, - atcoder = { cpp_version = 23 }, + codeforces = { cpp = { version = 23 } }, }, hooks = { - before_run = function(problem_id) vim.cmd.w() end, - before_debug = function(problem_id) ... end, + before_run = function(ctx) vim.cmd.w() end, + before_debug = function(ctx) + -- ctx.problem_id, ctx.platform, ctx.source_file, etc. + vim.cmd.w() + end, }, - tile = function(source_buf, input_buf, output_buf) - end, - filename = function(contest, problem_id, problem_letter) - end, + snippets = { ... }, -- LuaSnip snippets + tile = function(source_buf, input_buf, output_buf) ... end, + filename = function(contest, problem_id, problem_letter) ... end, } } < @@ -98,19 +119,40 @@ Configuration options: contests Dictionary of contest configurations - each contest inherits from 'default'. - cpp_version c++ standard version (e.g. 20, 23) - compile_flags compiler flags for run builds - debug_flags compiler flags for debug builds - timeout_ms duration (ms) to run/debug before timeout + cpp C++ language configuration + compile Compile command template with {version}, {source}, {binary} placeholders + run Run command template with {binary} placeholder + debug Debug compile command template + version C++ standard version (e.g. 20, 23) + extension File extension for C++ files (default: "cc") + + python Python language configuration + run Run command template with {source} placeholder + debug Debug run command template + extension File extension for Python files (default: "py") + + default_language Default language when --lang not specified (default: "cpp") + + timeout_ms Duration (ms) to run/debug before timeout snippets LuaSnip snippets by contest type hooks Functions called at specific events before_run Called before :CP run - function(problem_id) + function(ctx) + ctx contains: + - problem_id: string + - platform: string (atcoder/codeforces/cses) + - contest_id: string + - source_file: string (path to source) + - input_file: string (path to .cpin) + - output_file: string (path to .cpout) + - expected_file: string (path to .expected) + - contest_config: table (language configs) (default: nil, do nothing) before_debug Called before :CP debug - function(problem_id) + function(ctx) + Same ctx as before_run (default: nil, do nothing) debug Show info messages during operation @@ -221,8 +263,8 @@ cp.nvim creates the following file structure upon problem setup: build/ {contest_id}{problem_id}.run " Compiled binary io/ - {contest_id}{problem_id}.in " Test input - {contest_id}{problem_id}.out " Program output + {contest_id}{problem_id}.cpin " Test input + {contest_id}{problem_id}.cpout " Program output {contest_id}{problem_id}.expected " Expected output The plugin automatically manages this structure and navigation between problems @@ -233,9 +275,17 @@ SNIPPETS *cp-snippets* cp.nvim integrates with LuaSnip for automatic template expansion. When you open a new problem file, type the contest name and press to expand. -Built-in snippets include basic C++ templates for each contest type. +Built-in snippets include basic C++ and Python templates for each contest type. Custom snippets can be added via configuration. +IMPORTANT: Snippet trigger names must exactly match the contest/platform names: +- "codeforces" for Codeforces problems +- "atcoder" for AtCoder problems +- "cses" for CSES problems + +The plugin automatically selects the appropriate template based on the file +extension (e.g., .cc files get C++ templates, .py files get Python templates). + HEALTH CHECK *cp-health* Run |:checkhealth| cp to verify your setup. diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 565afc4..516ddb3 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -1,16 +1,48 @@ +---@class CacheData +---@field [string] table + +---@class ContestData +---@field problems Problem[] +---@field scraped_at string +---@field expires_at? number +---@field test_cases? TestCase[] +---@field test_cases_cached_at? number + +---@class Problem +---@field id string +---@field name? string + +---@class TestCase +---@field input string +---@field output string + local M = {} local cache_file = vim.fn.stdpath("data") .. "/cp-nvim.json" local cache_data = {} +---@param platform string +---@return number? local function get_expiry_date(platform) + vim.validate({ + platform = { platform, "string" }, + }) + if platform == "cses" then return os.time() + (30 * 24 * 60 * 60) end return nil end +---@param contest_data ContestData +---@param platform string +---@return boolean local function is_cache_valid(contest_data, platform) + vim.validate({ + contest_data = { contest_data, "table" }, + platform = { platform, "string" }, + }) + if platform ~= "cses" then return true end @@ -49,7 +81,15 @@ function M.save() vim.fn.writefile(vim.split(encoded, "\n"), cache_file) end +---@param platform string +---@param contest_id string +---@return ContestData? function M.get_contest_data(platform, contest_id) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + }) + if not cache_data[platform] then return nil end @@ -66,7 +106,16 @@ function M.get_contest_data(platform, contest_id) return contest_data end +---@param platform string +---@param contest_id string +---@param problems Problem[] function M.set_contest_data(platform, contest_id, problems) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + problems = { problems, "table" }, + }) + if not cache_data[platform] then cache_data[platform] = {} end @@ -80,14 +129,31 @@ function M.set_contest_data(platform, contest_id, problems) M.save() end +---@param platform string +---@param contest_id string function M.clear_contest_data(platform, contest_id) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + }) + if cache_data[platform] and cache_data[platform][contest_id] then cache_data[platform][contest_id] = nil M.save() end end +---@param platform string +---@param contest_id string +---@param problem_id? string +---@return TestCase[]? function M.get_test_cases(platform, contest_id, problem_id) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + problem_id = { problem_id, { "string", "nil" }, true }, + }) + local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id if not cache_data[platform] or not cache_data[platform][problem_key] then return nil @@ -95,7 +161,18 @@ function M.get_test_cases(platform, contest_id, problem_id) return cache_data[platform][problem_key].test_cases end +---@param platform string +---@param contest_id string +---@param problem_id? string +---@param test_cases TestCase[] function M.set_test_cases(platform, contest_id, problem_id, test_cases) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + problem_id = { problem_id, { "string", "nil" }, true }, + test_cases = { test_cases, "table" }, + }) + local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id if not cache_data[platform] then cache_data[platform] = {} diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 48a245e..0404ab9 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -1,13 +1,70 @@ +---@class LanguageConfig +---@field compile? string[] Compile command template +---@field run string[] Run command template +---@field debug? string[] Debug command template +---@field executable? string Executable name +---@field version? number Language version +---@field extension string File extension + +---@class PartialLanguageConfig +---@field compile? string[] Compile command template +---@field run? string[] Run command template +---@field debug? string[] Debug command template +---@field executable? string Executable name +---@field version? number Language version +---@field extension? string File extension + +---@class ContestConfig +---@field cpp LanguageConfig +---@field python LanguageConfig +---@field default_language string +---@field timeout_ms number + +---@class PartialContestConfig +---@field cpp? PartialLanguageConfig +---@field python? PartialLanguageConfig +---@field default_language? string +---@field timeout_ms? number + +---@class HookContext +---@field problem_id string +---@field platform string +---@field contest_id string +---@field source_file string +---@field input_file string +---@field output_file string +---@field expected_file string +---@field contest_config table + +---@class Hooks +---@field before_run? fun(ctx: HookContext) +---@field before_debug? fun(ctx: HookContext) + ---@class cp.Config ----@field contests table ----@field snippets table ----@field hooks table +---@field contests table +---@field snippets table[] +---@field hooks Hooks ---@field debug boolean ---@field tile? fun(source_buf: number, input_buf: number, output_buf: number) ----@field filename? fun(contest: string, problem_id: string, problem_letter?: string): string +---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string + +---@class cp.UserConfig +---@field contests? table +---@field snippets? table[] +---@field hooks? Hooks +---@field debug? boolean +---@field tile? fun(source_buf: number, input_buf: number, output_buf: number) +---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string local M = {} +local filetype_to_language = { + cc = "cpp", + c = "cpp", + py = "python", + py3 = "python", +} + ---@type cp.Config M.defaults = { contests = { @@ -37,22 +94,31 @@ M.defaults = { }, executable = nil, version = 20, + extension = "cc", }, python = { compile = nil, run = { "{source}" }, debug = { "{source}" }, executable = "python3", + extension = "py", }, + default_language = "cpp", timeout_ms = 2000, }, + ---@type PartialContestConfig atcoder = { + ---@type PartialLanguageConfig cpp = { version = 23 }, }, + ---@type PartialContestConfig codeforces = { + ---@type PartialLanguageConfig cpp = { version = 23 }, }, + ---@type PartialContestConfig cses = { + ---@type PartialLanguageConfig cpp = { version = 20 }, }, }, @@ -74,8 +140,8 @@ local function extend_contest_config(base_config, contest_config) return result end ----@param user_config table|nil ----@return table +---@param user_config cp.UserConfig|nil +---@return cp.Config function M.setup(user_config) vim.validate({ user_config = { user_config, { "table", "nil" }, true }, @@ -97,6 +163,25 @@ function M.setup(user_config) before_debug = { user_config.hooks.before_debug, { "function", "nil" }, true }, }) end + + if user_config.contests then + for contest_name, contest_config in pairs(user_config.contests) do + for lang_name, lang_config in pairs(contest_config) do + if type(lang_config) == "table" and lang_config.extension then + if not vim.tbl_contains(vim.tbl_keys(filetype_to_language), lang_config.extension) then + error( + ("Invalid extension '%s' for language '%s' in contest '%s'. Valid extensions: %s"):format( + lang_config.extension, + lang_name, + contest_name, + table.concat(vim.tbl_keys(filetype_to_language), ", ") + ) + ) + end + end + end + end + end end local config = vim.tbl_deep_extend("force", M.defaults, user_config or {}) @@ -111,14 +196,32 @@ function M.setup(user_config) return config end -local function default_filename(contest, contest_id, problem_id) +---@param contest string +---@param contest_id string +---@param problem_id? string +---@param config cp.Config +---@param language? string +---@return string +local function default_filename(contest, contest_id, problem_id, config, language) + vim.validate({ + contest = { contest, "string" }, + contest_id = { contest_id, "string" }, + problem_id = { problem_id, { "string", "nil" }, true }, + config = { config, "table" }, + language = { language, { "string", "nil" }, true }, + }) + local full_problem_id = contest_id:lower() if contest == "atcoder" or contest == "codeforces" then if problem_id then full_problem_id = full_problem_id .. problem_id:lower() end end - return full_problem_id .. ".cc" + + local contest_config = config.contests[contest] or config.contests.default + local target_language = language or contest_config.default_language + local language_config = contest_config[target_language] + return full_problem_id .. "." .. language_config.extension end M.default_filename = default_filename diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index 00e9332..a0874f1 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -1,23 +1,40 @@ +---@class ExecuteResult +---@field stdout string +---@field stderr string +---@field code integer +---@field time_ms number +---@field timed_out boolean + local M = {} local logger = require("cp.log") -local filetype_to_language = { - cpp = "cpp", - cxx = "cpp", - cc = "cpp", - c = "cpp", - py = "python", - py3 = "python", -} +local languages = require("cp.languages") +local filetype_to_language = languages.filetype_to_language + +---@param source_file string +---@param contest_config table +---@return string +local function get_language_from_file(source_file, contest_config) + vim.validate({ + source_file = { source_file, "string" }, + contest_config = { contest_config, "table" }, + }) -local function get_language_from_file(source_file) local extension = vim.fn.fnamemodify(source_file, ":e") - local language = filetype_to_language[extension] or "cpp" + local language = filetype_to_language[extension] or contest_config.default_language logger.log(("detected language: %s (extension: %s)"):format(language, extension)) return language end +---@param cmd_template string[] +---@param substitutions table +---@return string[] local function substitute_template(cmd_template, substitutions) + vim.validate({ + cmd_template = { cmd_template, "table" }, + substitutions = { substitutions, "table" }, + }) + local result = {} for _, arg in ipairs(cmd_template) do local substituted = arg @@ -29,7 +46,17 @@ local function substitute_template(cmd_template, substitutions) return result end +---@param cmd_template string[] +---@param executable? string +---@param substitutions table +---@return string[] local function build_command(cmd_template, executable, substitutions) + vim.validate({ + cmd_template = { cmd_template, "table" }, + executable = { executable, { "string", "nil" }, true }, + substitutions = { substitutions, "table" }, + }) + local cmd = substitute_template(cmd_template, substitutions) if executable then table.insert(cmd, 1, executable) @@ -59,7 +86,15 @@ local function ensure_directories() vim.system({ "mkdir", "-p", "build", "io" }):wait() end +---@param language_config table +---@param substitutions table +---@return {code: integer, stderr: string} local function compile_generic(language_config, substitutions) + vim.validate({ + language_config = { language_config, "table" }, + substitutions = { substitutions, "table" }, + }) + if not language_config.compile then logger.log("no compilation step required") return { code = 0, stderr = "" } @@ -81,7 +116,17 @@ local function compile_generic(language_config, substitutions) return result end +---@param cmd string[] +---@param input_data string +---@param timeout_ms integer +---@return ExecuteResult local function execute_command(cmd, input_data, timeout_ms) + vim.validate({ + cmd = { cmd, "table" }, + input_data = { input_data, "string" }, + timeout_ms = { timeout_ms, "number" }, + }) + logger.log(("executing: %s"):format(table.concat(cmd, " "))) local start_time = vim.loop.hrtime() @@ -114,7 +159,17 @@ local function execute_command(cmd, input_data, timeout_ms) } end +---@param exec_result ExecuteResult +---@param expected_file string +---@param is_debug boolean +---@return string local function format_output(exec_result, expected_file, is_debug) + vim.validate({ + exec_result = { exec_result, "table" }, + expected_file = { expected_file, "string" }, + is_debug = { is_debug, "boolean" }, + }) + local output_lines = { exec_result.stdout } local metadata_lines = {} @@ -158,9 +213,15 @@ end ---@param contest_config table ---@param is_debug boolean function M.run_problem(ctx, contest_config, is_debug) + vim.validate({ + ctx = { ctx, "table" }, + contest_config = { contest_config, "table" }, + is_debug = { is_debug, "boolean" }, + }) + ensure_directories() - local language = get_language_from_file(ctx.source_file) + local language = get_language_from_file(ctx.source_file, contest_config) local language_config = contest_config[language] if not language_config then @@ -171,7 +232,7 @@ function M.run_problem(ctx, contest_config, is_debug) local substitutions = { source = ctx.source_file, binary = ctx.binary_file, - version = tostring(language_config.version or ""), + version = tostring(language_config.version), } local compile_cmd = is_debug and language_config.debug or language_config.compile diff --git a/lua/cp/init.lua b/lua/cp/init.lua index 01a046a..72d0213 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -48,7 +48,8 @@ end ---@param contest_id string ---@param problem_id? string -local function setup_problem(contest_id, problem_id) +---@param language? string +local function setup_problem(contest_id, problem_id, language) if not state.platform then logger.log("no platform set. run :CP first", vim.log.levels.ERROR) return @@ -88,7 +89,7 @@ local function setup_problem(contest_id, problem_id) state.test_cases = cached_test_cases end - local ctx = problem.create_context(state.platform, contest_id, problem_id, config) + local ctx = problem.create_context(state.platform, contest_id, problem_id, config, language) local scrape_result = scrape.scrape_problem(ctx) @@ -116,6 +117,14 @@ local function setup_problem(contest_id, problem_id) vim.cmd.startinsert({ bang = true }) vim.schedule(function() + print( + "Debug: platform=" + .. state.platform + .. ", filetype=" + .. vim.bo.filetype + .. ", expandable=" + .. tostring(luasnip.expandable()) + ) if luasnip.expandable() then luasnip.expand() end @@ -161,19 +170,28 @@ local function run_problem() logger.log(("running problem: %s"):format(problem_id)) - if config.hooks and config.hooks.before_run then - config.hooks.before_run(problem_id) - end - if not state.platform then logger.log("no platform set", vim.log.levels.ERROR) return end local contest_config = config.contests[state.platform] + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) + + if config.hooks and config.hooks.before_run then + config.hooks.before_run({ + problem_id = problem_id, + platform = state.platform, + contest_id = state.contest_id, + source_file = ctx.source_file, + input_file = ctx.input_file, + output_file = ctx.output_file, + expected_file = ctx.expected_file, + contest_config = contest_config, + }) + end vim.schedule(function() - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) execute.run_problem(ctx, contest_config, false) vim.cmd.checktime() end) @@ -185,19 +203,28 @@ local function debug_problem() return end - if config.hooks and config.hooks.before_debug then - config.hooks.before_debug(problem_id) - end - if not state.platform then logger.log("no platform set", vim.log.levels.ERROR) return end local contest_config = config.contests[state.platform] + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) + + if config.hooks and config.hooks.before_debug then + config.hooks.before_debug({ + problem_id = problem_id, + platform = state.platform, + contest_id = state.contest_id, + source_file = ctx.source_file, + input_file = ctx.input_file, + output_file = ctx.output_file, + expected_file = ctx.expected_file, + contest_config = contest_config, + }) + end vim.schedule(function() - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) execute.run_problem(ctx, contest_config, true) vim.cmd.checktime() end) @@ -246,7 +273,8 @@ local function diff_problem() end ---@param delta number 1 for next, -1 for prev -local function navigate_problem(delta) +---@param language? string +local function navigate_problem(delta, language) if not state.platform or not state.contest_id then logger.log("no contest set. run :CP first", vim.log.levels.ERROR) return @@ -297,41 +325,69 @@ local function navigate_problem(delta) local new_problem = problems[new_index] if state.platform == "cses" then - setup_problem(new_problem.id) + setup_problem(new_problem.id, nil, language) else - setup_problem(state.contest_id, new_problem.id) + setup_problem(state.contest_id, new_problem.id, language) end end local function parse_command(args) if #args == 0 then - return { type = "error", message = "Usage: :CP [problem] | :CP | :CP " } + return { + type = "error", + message = "Usage: :CP [problem] [--lang=] | :CP | :CP ", + } end - local first = args[1] + local language = nil + + for i, arg in ipairs(args) do + local lang_match = arg:match("^--lang=(.+)$") + if lang_match then + language = lang_match + elseif arg == "--lang" then + if i + 1 <= #args then + language = args[i + 1] + else + return { type = "error", message = "--lang requires a value" } + end + end + end + + local filtered_args = vim.tbl_filter(function(arg) + return not (arg:match("^--lang") or arg == language) + end, args) + + local first = filtered_args[1] if vim.tbl_contains(actions, first) then - return { type = "action", action = first } + return { type = "action", action = first, language = language } end if vim.tbl_contains(platforms, first) then - if #args == 1 then - return { type = "platform_only", platform = first } - elseif #args == 2 then + if #filtered_args == 1 then + return { type = "platform_only", platform = first, language = language } + elseif #filtered_args == 2 then if first == "cses" then - return { type = "cses_problem", platform = first, problem = args[2] } + return { type = "cses_problem", platform = first, problem = filtered_args[2], language = language } else - return { type = "contest_setup", platform = first, contest = args[2] } + return { type = "contest_setup", platform = first, contest = filtered_args[2], language = language } end - elseif #args == 3 then - return { type = "full_setup", platform = first, contest = args[2], problem = args[3] } + elseif #filtered_args == 3 then + return { + type = "full_setup", + platform = first, + contest = filtered_args[2], + problem = filtered_args[3], + language = language, + } else return { type = "error", message = "Too many arguments" } end end if state.platform and state.contest_id then - return { type = "problem_switch", problem = first } + return { type = "problem_switch", problem = first, language = language } end return { type = "error", message = "Unknown command or no contest context" } @@ -353,9 +409,9 @@ function M.handle_command(opts) elseif cmd.action == "diff" then diff_problem() elseif cmd.action == "next" then - navigate_problem(1) + navigate_problem(1, cmd.language) elseif cmd.action == "prev" then - navigate_problem(-1) + navigate_problem(-1, cmd.language) end return end @@ -398,7 +454,7 @@ function M.handle_command(opts) ) end - setup_problem(cmd.contest, cmd.problem) + setup_problem(cmd.contest, cmd.problem, cmd.language) end return end @@ -412,16 +468,16 @@ function M.handle_command(opts) vim.log.levels.WARN ) end - setup_problem(cmd.problem) + setup_problem(cmd.problem, nil, cmd.language) end return end if cmd.type == "problem_switch" then if state.platform == "cses" then - setup_problem(cmd.problem) + setup_problem(cmd.problem, nil, cmd.language) else - setup_problem(state.contest_id, cmd.problem) + setup_problem(state.contest_id, cmd.problem, cmd.language) end return end diff --git a/lua/cp/languages.lua b/lua/cp/languages.lua new file mode 100644 index 0000000..6884287 --- /dev/null +++ b/lua/cp/languages.lua @@ -0,0 +1,22 @@ +local M = {} + +M.CPP = "cpp" +M.PYTHON = "python" + +---@type table +M.filetype_to_language = { + cc = M.CPP, + cxx = M.CPP, + cpp = M.CPP, + c = M.CPP, + py = M.PYTHON, + py3 = M.PYTHON, +} + +---@type table +M.canonical_filetypes = { + [M.CPP] = "cpp", + [M.PYTHON] = "python", +} + +return M diff --git a/lua/cp/problem.lua b/lua/cp/problem.lua index 60406fd..caeac2c 100644 --- a/lua/cp/problem.lua +++ b/lua/cp/problem.lua @@ -15,10 +15,19 @@ local M = {} ---@param contest_id string ---@param problem_id? string ---@param config cp.Config +---@param language? string ---@return ProblemContext -function M.create_context(contest, contest_id, problem_id, config) +function M.create_context(contest, contest_id, problem_id, config, language) + vim.validate({ + contest = { contest, "string" }, + contest_id = { contest_id, "string" }, + problem_id = { problem_id, { "string", "nil" }, true }, + config = { config, "table" }, + language = { language, { "string", "nil" }, true }, + }) + local filename_fn = config.filename or require("cp.config").default_filename - local source_file = filename_fn(contest, contest_id, problem_id) + local source_file = filename_fn(contest, contest_id, problem_id, config, language) local base_name = vim.fn.fnamemodify(source_file, ":t:r") return { diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index bde861a..bf40059 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -45,6 +45,11 @@ end ---@param contest_id string ---@return {success: boolean, problems?: table[], error?: string} function M.scrape_contest_metadata(platform, contest_id) + vim.validate({ + platform = { platform, "string" }, + contest_id = { contest_id, "string" }, + }) + cache.load() local cached_data = cache.get_contest_data(platform, contest_id) @@ -121,6 +126,10 @@ end ---@param ctx ProblemContext ---@return {success: boolean, problem_id: string, test_count?: number, test_cases?: table[], url?: string, error?: string} function M.scrape_problem(ctx) + vim.validate({ + ctx = { ctx, "table" }, + }) + ensure_io_directory() if vim.fn.filereadable(ctx.input_file) == 1 and vim.fn.filereadable(ctx.expected_file) == 1 then diff --git a/lua/cp/snippets.lua b/lua/cp/snippets.lua index fbddfff..4906faa 100644 --- a/lua/cp/snippets.lua +++ b/lua/cp/snippets.lua @@ -10,11 +10,19 @@ function M.setup(config) local s, i, fmt = ls.snippet, ls.insert_node, require("luasnip.extras.fmt").fmt - local default_snippets = { - s( - "codeforces", - fmt( - [[#include + local languages = require("cp.languages") + local filetype_to_language = languages.filetype_to_language + + local language_to_filetype = {} + for ext, lang in pairs(filetype_to_language) do + if not language_to_filetype[lang] then + language_to_filetype[lang] = ext + end + end + + local template_definitions = { + cpp = { + codeforces = [[#include using namespace std; @@ -34,14 +42,8 @@ int main() {{ return 0; }}]], - { i(1) } - ) - ), - s( - "atcoder", - fmt( - [[#include + atcoder = [[#include using namespace std; @@ -65,14 +67,8 @@ int main() {{ return 0; }}]], - { i(1) } - ) - ), - s( - "cses", - fmt( - [[#include + cses = [[#include using namespace std; @@ -83,29 +79,41 @@ int main() {{ return 0; }}]], - { i(1) } - ) - ), + }, + + python = { + codeforces = [[def solve(): + {} + +if __name__ == "__main__": + tc = int(input()) + for _ in range(tc): + solve()]], + + atcoder = [[def solve(): + {} + +if __name__ == "__main__": + solve()]], + + cses = [[{}]], + }, } - local default_map = {} - for _, snippet in pairs(default_snippets) do - default_map[snippet.trigger] = snippet + for language, template_set in pairs(template_definitions) do + local snippets = {} + local filetype = languages.canonical_filetypes[language] + + for contest, template in pairs(template_set) do + table.insert(snippets, s(contest, fmt(template, { i(1) }))) + end + + for _, snippet in ipairs(config.snippets or {}) do + table.insert(snippets, snippet) + end + + ls.add_snippets(filetype, snippets) end - - local user_map = {} - for _, snippet in pairs(config.snippets or {}) do - user_map[snippet.trigger] = snippet - end - - local merged_map = vim.tbl_extend("force", default_map, user_map) - - local all_snippets = {} - for _, snippet in pairs(merged_map) do - table.insert(all_snippets, snippet) - end - - ls.add_snippets("cpp", all_snippets) end return M diff --git a/lua/cp/window.lua b/lua/cp/window.lua index a62465d..72ebe8b 100644 --- a/lua/cp/window.lua +++ b/lua/cp/window.lua @@ -1,3 +1,14 @@ +---@class WindowState +---@field windows table +---@field current_win integer +---@field layout string + +---@class WindowData +---@field bufnr integer +---@field view table +---@field width integer +---@field height integer + local M = {} function M.clearcol() @@ -8,6 +19,7 @@ function M.clearcol() vim.api.nvim_set_option_value("foldcolumn", "0", { scope = "local" }) end +---@return WindowState function M.save_layout() local windows = {} for _, win in ipairs(vim.api.nvim_list_wins()) do @@ -29,7 +41,14 @@ function M.save_layout() } end +---@param state? WindowState +---@param tile_fn? fun(source_buf: integer, input_buf: integer, output_buf: integer) function M.restore_layout(state, tile_fn) + vim.validate({ + state = { state, { "table", "nil" }, true }, + tile_fn = { tile_fn, { "function", "nil" }, true }, + }) + if not state then return end @@ -56,7 +75,21 @@ function M.restore_layout(state, tile_fn) local input_file = ("%s/io/%s.in"):format(base_fp, problem_id) local output_file = ("%s/io/%s.out"):format(base_fp, problem_id) local source_files = vim.fn.glob(problem_id .. ".*") - local source_file = source_files ~= "" and vim.split(source_files, "\n")[1] or (problem_id .. ".cc") + local source_file + if source_files ~= "" then + local files = vim.split(source_files, "\n") + local valid_extensions = { "cc", "cpp", "cxx", "c", "py", "py3" } + for _, file in ipairs(files) do + local ext = vim.fn.fnamemodify(file, ":e") + if vim.tbl_contains(valid_extensions, ext) then + source_file = file + break + end + end + source_file = source_file or files[1] + else + source_file = problem_id .. ".cc" + end if vim.fn.filereadable(source_file) == 0 then return @@ -90,7 +123,16 @@ function M.restore_layout(state, tile_fn) end end +---@param actual_output string +---@param expected_output string +---@param input_file string function M.setup_diff_layout(actual_output, expected_output, input_file) + vim.validate({ + actual_output = { actual_output, "string" }, + expected_output = { expected_output, "string" }, + input_file = { input_file, "string" }, + }) + vim.cmd.diffoff() vim.cmd("silent only") @@ -117,7 +159,16 @@ function M.setup_diff_layout(actual_output, expected_output, input_file) vim.cmd.wincmd("k") end +---@param source_buf integer +---@param input_buf integer +---@param output_buf integer local function default_tile(source_buf, input_buf, output_buf) + vim.validate({ + source_buf = { source_buf, "number" }, + input_buf = { input_buf, "number" }, + output_buf = { output_buf, "number" }, + }) + vim.api.nvim_set_current_buf(source_buf) vim.cmd.vsplit() vim.api.nvim_set_current_buf(output_buf) diff --git a/plugin/cp.lua b/plugin/cp.lua index 83a2816..f96db7f 100644 --- a/plugin/cp.lua +++ b/plugin/cp.lua @@ -11,15 +11,38 @@ vim.api.nvim_create_user_command("CP", function(opts) cp.handle_command(opts) end, { nargs = "*", + desc = "Competitive programming helper", complete = function(ArgLead, CmdLine, _) + local languages_module = require("cp.languages") + local languages = vim.tbl_keys(languages_module.canonical_filetypes) + + if ArgLead:match("^--lang=") then + local lang_completions = {} + for _, lang in ipairs(languages) do + table.insert(lang_completions, "--lang=" .. lang) + end + return vim.tbl_filter(function(completion) + return completion:find(ArgLead, 1, true) == 1 + end, lang_completions) + end + + if ArgLead == "--lang" then + return { "--lang" } + end + local args = vim.split(vim.trim(CmdLine), "%s+") local num_args = #args if CmdLine:sub(-1) == " " then num_args = num_args + 1 end + local lang_flag_present = vim.tbl_contains(args, "--lang") + or vim.iter(args):any(function(arg) + return arg:match("^--lang=") + end) + if num_args == 2 then - local candidates = {} + local candidates = { "--lang" } vim.list_extend(candidates, platforms) vim.list_extend(candidates, actions) local cp = require("cp") @@ -37,13 +60,17 @@ end, { return vim.tbl_filter(function(cmd) return cmd:find(ArgLead, 1, true) == 1 end, candidates) - elseif num_args == 4 then + elseif args[#args - 1] == "--lang" then + return vim.tbl_filter(function(lang) + return lang:find(ArgLead, 1, true) == 1 + end, languages) + elseif num_args == 4 and not lang_flag_present then if vim.tbl_contains(platforms, args[2]) then local cache = require("cp.cache") cache.load() local contest_data = cache.get_contest_data(args[2], args[3]) if contest_data and contest_data.problems then - local candidates = {} + local candidates = { "--lang" } for _, problem in ipairs(contest_data.problems) do table.insert(candidates, problem.id) end