diff --git a/after/ftplugin/cpin.lua b/after/ftplugin/cpin.lua new file mode 100644 index 0000000..76a9f86 --- /dev/null +++ b/after/ftplugin/cpin.lua @@ -0,0 +1,6 @@ +vim.opt_local.number = false +vim.opt_local.relativenumber = false +vim.opt_local.statuscolumn = "" +vim.opt_local.signcolumn = "no" +vim.opt_local.wrap = true +vim.opt_local.linebreak = true diff --git a/after/ftplugin/cpout.lua b/after/ftplugin/cpout.lua new file mode 100644 index 0000000..3f0bcdc --- /dev/null +++ b/after/ftplugin/cpout.lua @@ -0,0 +1,7 @@ +vim.opt_local.number = false +vim.opt_local.relativenumber = false +vim.opt_local.statuscolumn = "" +vim.opt_local.signcolumn = "no" +vim.opt_local.wrap = true +vim.opt_local.linebreak = true +vim.opt_local.modifiable = false diff --git a/doc/cp.txt b/doc/cp.txt index 7570cc1..306c70a 100644 --- a/doc/cp.txt +++ b/doc/cp.txt @@ -9,12 +9,13 @@ cp.nvim is a competitive programming plugin that automates problem setup, compilation, and testing workflow for online judges. Supported platforms: AtCoder, Codeforces, CSES +Supported languages: C++, Python REQUIREMENTS *cp-requirements* - Neovim 0.10.0+ - uv package manager (https://docs.astral.sh/uv/) -- C++ compiler (g++/clang++) +- Language runtime/compiler (g++, python3) Optional: - LuaSnip for template expansion (https://github.com/L3MON4D3/LuaSnip) diff --git a/ftdetect/cp.lua b/ftdetect/cp.lua new file mode 100644 index 0000000..2b6b593 --- /dev/null +++ b/ftdetect/cp.lua @@ -0,0 +1,6 @@ +vim.filetype.add({ + extension = { + cpin = "cpin", + cpout = "cpout", + }, +}) diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 70a55c3..565afc4 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -87,4 +87,26 @@ function M.clear_contest_data(platform, contest_id) end end +function M.get_test_cases(platform, contest_id, problem_id) + local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id + if not cache_data[platform] or not cache_data[platform][problem_key] then + return nil + end + return cache_data[platform][problem_key].test_cases +end + +function M.set_test_cases(platform, contest_id, problem_id, test_cases) + local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id + if not cache_data[platform] then + cache_data[platform] = {} + end + if not cache_data[platform][problem_key] then + cache_data[platform][problem_key] = {} + end + + cache_data[platform][problem_key].test_cases = test_cases + cache_data[platform][problem_key].test_cases_cached_at = os.time() + M.save() +end + return M diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 9f39c60..48a245e 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -12,19 +12,48 @@ local M = {} M.defaults = { contests = { default = { - cpp_version = 20, - compile_flags = { "-O2", "-DLOCAL", "-Wall", "-Wextra" }, - debug_flags = { "-g3", "-fsanitize=address,undefined", "-DLOCAL" }, + cpp = { + compile = { + "g++", + "-std=c++{version}", + "-O2", + "-DLOCAL", + "-Wall", + "-Wextra", + "{source}", + "-o", + "{binary}", + }, + run = { "{binary}" }, + debug = { + "g++", + "-std=c++{version}", + "-g3", + "-fsanitize=address,undefined", + "-DLOCAL", + "{source}", + "-o", + "{binary}", + }, + executable = nil, + version = 20, + }, + python = { + compile = nil, + run = { "{source}" }, + debug = { "{source}" }, + executable = "python3", + }, timeout_ms = 2000, }, atcoder = { - cpp_version = 23, + cpp = { version = 23 }, }, codeforces = { - cpp_version = 23, + cpp = { version = 23 }, }, cses = { - cpp_version = 20, + cpp = { version = 20 }, }, }, snippets = {}, @@ -42,11 +71,6 @@ M.defaults = { ---@return table local function extend_contest_config(base_config, contest_config) local result = vim.tbl_deep_extend("force", base_config, contest_config) - - local std_flag = ("-std=c++%d"):format(result.cpp_version) - result.compile_flags = vim.list_extend({ std_flag }, result.compile_flags) - result.debug_flags = vim.list_extend({ std_flag }, result.debug_flags) - return result end diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index 774346b..00e9332 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -1,4 +1,41 @@ local M = {} +local logger = require("cp.log") + +local filetype_to_language = { + cpp = "cpp", + cxx = "cpp", + cc = "cpp", + c = "cpp", + py = "python", + py3 = "python", +} + +local function get_language_from_file(source_file) + local extension = vim.fn.fnamemodify(source_file, ":e") + local language = filetype_to_language[extension] or "cpp" + logger.log(("detected language: %s (extension: %s)"):format(language, extension)) + return language +end + +local function substitute_template(cmd_template, substitutions) + local result = {} + for _, arg in ipairs(cmd_template) do + local substituted = arg + for key, value in pairs(substitutions) do + substituted = substituted:gsub("{" .. key .. "}", value) + end + table.insert(result, substituted) + end + return result +end + +local function build_command(cmd_template, executable, substitutions) + local cmd = substitute_template(cmd_template, substitutions) + if executable then + table.insert(cmd, 1, executable) + end + return cmd +end local signal_codes = { [128] = "SIGILL", @@ -22,15 +59,34 @@ local function ensure_directories() vim.system({ "mkdir", "-p", "build", "io" }):wait() end -local function compile_cpp(source_path, binary_path, flags) - local compile_cmd = { "g++", unpack(flags), source_path, "-o", binary_path } - return vim.system(compile_cmd, { text = true }):wait() +local function compile_generic(language_config, substitutions) + if not language_config.compile then + logger.log("no compilation step required") + return { code = 0, stderr = "" } + end + + local compile_cmd = substitute_template(language_config.compile, substitutions) + logger.log(("compiling: %s"):format(table.concat(compile_cmd, " "))) + + local start_time = vim.loop.hrtime() + local result = vim.system(compile_cmd, { text = true }):wait() + local compile_time = (vim.loop.hrtime() - start_time) / 1000000 + + if result.code == 0 then + logger.log(("compilation successful (%.1fms)"):format(compile_time)) + else + logger.log(("compilation failed (%.1fms): %s"):format(compile_time, result.stderr), vim.log.levels.WARN) + end + + return result end -local function execute_binary(binary_path, input_data, timeout_ms) +local function execute_command(cmd, input_data, timeout_ms) + logger.log(("executing: %s"):format(table.concat(cmd, " "))) + local start_time = vim.loop.hrtime() - local result = vim.system({ binary_path }, { + local result = vim.system(cmd, { stdin = input_data, timeout = timeout_ms, text = true, @@ -41,6 +97,14 @@ local function execute_binary(binary_path, input_data, timeout_ms) local actual_code = result.code or 0 + if result.code == 124 then + logger.log(("execution timed out after %.1fms"):format(execution_time), vim.log.levels.WARN) + elseif actual_code ~= 0 then + logger.log(("execution failed (exit code %d, %.1fms)"):format(actual_code, execution_time), vim.log.levels.WARN) + else + logger.log(("execution successful (%.1fms)"):format(execution_time)) + end + return { stdout = result.stdout or "", stderr = result.stderr or "", @@ -96,20 +160,36 @@ end function M.run_problem(ctx, contest_config, is_debug) ensure_directories() - local flags = is_debug and contest_config.debug_flags or contest_config.compile_flags + local language = get_language_from_file(ctx.source_file) + local language_config = contest_config[language] - local compile_result = compile_cpp(ctx.source_file, ctx.binary_file, flags) - if compile_result.code ~= 0 then - vim.fn.writefile({ compile_result.stderr }, ctx.output_file) + if not language_config then + vim.fn.writefile({ "Error: No configuration for language: " .. language }, ctx.output_file) return end + local substitutions = { + source = ctx.source_file, + binary = ctx.binary_file, + version = tostring(language_config.version or ""), + } + + local compile_cmd = is_debug and language_config.debug or language_config.compile + if compile_cmd then + local compile_result = compile_generic(language_config, substitutions) + if compile_result.code ~= 0 then + vim.fn.writefile({ compile_result.stderr }, ctx.output_file) + return + end + end + local input_data = "" if vim.fn.filereadable(ctx.input_file) == 1 then input_data = table.concat(vim.fn.readfile(ctx.input_file), "\n") .. "\n" end - local exec_result = execute_binary(ctx.binary_file, input_data, contest_config.timeout_ms) + local run_cmd = build_command(language_config.run, language_config.executable, substitutions) + local exec_result = execute_command(run_cmd, input_data, contest_config.timeout_ms) local formatted_output = format_output(exec_result, ctx.expected_file, is_debug) local output_buf = vim.fn.bufnr(ctx.output_file) diff --git a/lua/cp/health.lua b/lua/cp/health.lua index d902934..738e9e2 100644 --- a/lua/cp/health.lua +++ b/lua/cp/health.lua @@ -62,12 +62,14 @@ end local function check_config() vim.health.ok("Plugin ready") - if vim.g.cp and vim.g.cp.platform then - local info = vim.g.cp.platform - if vim.g.cp.contest_id then - info = info .. " " .. vim.g.cp.contest_id - if vim.g.cp.problem_id then - info = info .. " " .. vim.g.cp.problem_id + local cp = require("cp") + local context = cp.get_current_context() + if context.platform then + local info = context.platform + if context.contest_id then + info = info .. " " .. context.contest_id + if context.problem_id then + info = info .. " " .. context.problem_id end end vim.health.info("Current context: " .. info) diff --git a/lua/cp/init.lua b/lua/cp/init.lua index b86dc98..01a046a 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -14,7 +14,6 @@ if not vim.fn.has("nvim-0.10.0") then return {} end -vim.g.cp = vim.g.cp or {} local user_config = {} local config = config_module.setup(user_config) logger.set_config(config) @@ -28,6 +27,8 @@ local state = { saved_layout = nil, saved_session = nil, temp_output = nil, + test_cases = nil, + test_states = {}, } local platforms = { "atcoder", "codeforces", "cses" } @@ -53,6 +54,9 @@ local function setup_problem(contest_id, problem_id) return end + local problem_name = state.platform == "cses" and contest_id or (contest_id .. (problem_id or "")) + logger.log(("setting up problem: %s"):format(problem_name)) + local metadata_result = scrape.scrape_contest_metadata(state.platform, contest_id) if not metadata_result.success then logger.log( @@ -79,6 +83,11 @@ local function setup_problem(contest_id, problem_id) state.contest_id = contest_id state.problem_id = problem_id + local cached_test_cases = cache.get_test_cases(state.platform, contest_id, problem_id) + if cached_test_cases then + state.test_cases = cached_test_cases + end + local ctx = problem.create_context(state.platform, contest_id, problem_id, config) local scrape_result = scrape.scrape_problem(ctx) @@ -86,9 +95,15 @@ local function setup_problem(contest_id, problem_id) if not scrape_result.success then logger.log("scraping failed: " .. (scrape_result.error or "unknown error"), vim.log.levels.WARN) logger.log("you can manually add test cases to io/ directory", vim.log.levels.INFO) + state.test_cases = nil else local test_count = scrape_result.test_count or 0 logger.log(("scraped %d test case(s) for %s"):format(test_count, scrape_result.problem_id)) + state.test_cases = scrape_result.test_cases + + if scrape_result.test_cases then + cache.set_test_cases(state.platform, contest_id, problem_id, scrape_result.test_cases) + end end vim.cmd.e(ctx.source_file) @@ -144,6 +159,8 @@ local function run_problem() return end + logger.log(("running problem: %s"):format(problem_id)) + if config.hooks and config.hooks.before_run then config.hooks.before_run(problem_id) end @@ -188,34 +205,44 @@ end local function diff_problem() if state.diff_mode then - local tile_fn = config.tile or window.default_tile - window.restore_layout(state.saved_layout, tile_fn) + vim.cmd.diffoff() + if state.saved_session then + vim.fn.delete(state.saved_session) + state.saved_session = nil + end + if state.temp_output then + vim.fn.delete(state.temp_output) + state.temp_output = nil + end state.diff_mode = false - state.saved_layout = nil - logger.log("exited diff mode") - else - local problem_id = get_current_problem() - if not problem_id then - return - end - - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) - - if vim.fn.filereadable(ctx.expected_file) == 0 then - logger.log(("No expected output file found: %s"):format(ctx.expected_file), vim.log.levels.ERROR) - return - end - - state.saved_layout = window.save_layout() - - local result = vim.system({ "awk", "/^\\[[^]]*\\]:/ {exit} {print}", ctx.output_file }, { text = true }):wait() - local actual_output = result.stdout - - window.setup_diff_layout(actual_output, ctx.expected_file, ctx.input_file) - - state.diff_mode = true - logger.log("entered diff mode") + return end + + local problem_id = get_current_problem() + if not problem_id then + return + end + + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) + + if vim.fn.filereadable(ctx.expected_file) == 0 then + logger.log("no expected output file found", vim.log.levels.WARN) + return + end + + if vim.fn.filereadable(ctx.output_file) == 0 then + logger.log("no output file found. run the problem first", vim.log.levels.WARN) + return + end + + state.saved_session = vim.fn.tempname() + vim.cmd(("mksession! %s"):format(state.saved_session)) + + vim.cmd("silent only") + vim.cmd(("edit %s"):format(ctx.expected_file)) + vim.cmd.diffthis() + vim.cmd(("vertical diffsplit %s"):format(ctx.output_file)) + state.diff_mode = true end ---@param delta number 1 for next, -1 for prev @@ -411,6 +438,14 @@ function M.setup(opts) end end +function M.get_current_context() + return { + platform = state.platform, + contest_id = state.contest_id, + problem_id = state.problem_id, + } +end + function M.is_initialized() return true end diff --git a/lua/cp/problem.lua b/lua/cp/problem.lua index 4aa73d2..60406fd 100644 --- a/lua/cp/problem.lua +++ b/lua/cp/problem.lua @@ -27,8 +27,8 @@ function M.create_context(contest, contest_id, problem_id, config) problem_id = problem_id, source_file = source_file, binary_file = ("build/%s.run"):format(base_name), - input_file = ("io/%s.in"):format(base_name), - output_file = ("io/%s.out"):format(base_name), + input_file = ("io/%s.cpin"):format(base_name), + output_file = ("io/%s.cpout"):format(base_name), expected_file = ("io/%s.expected"):format(base_name), problem_name = base_name, } diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index 82e2e4f..bde861a 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -74,9 +74,9 @@ function M.scrape_contest_metadata(platform, contest_id) local args if platform == "cses" then - args = { "uv", "run", scraper_path, "metadata" } + args = { "uv", "run", "--directory", plugin_path, scraper_path, "metadata" } else - args = { "uv", "run", scraper_path, "metadata", contest_id } + args = { "uv", "run", "--directory", plugin_path, scraper_path, "metadata", contest_id } end local result = vim.system(args, { @@ -119,7 +119,7 @@ function M.scrape_contest_metadata(platform, contest_id) end ---@param ctx ProblemContext ----@return {success: boolean, problem_id: string, test_count?: number, url?: string, error?: string} +---@return {success: boolean, problem_id: string, test_count?: number, test_cases?: table[], url?: string, error?: string} function M.scrape_problem(ctx) ensure_io_directory() @@ -152,9 +152,9 @@ function M.scrape_problem(ctx) local args if ctx.contest == "cses" then - args = { "uv", "run", scraper_path, "tests", ctx.contest_id } + args = { "uv", "run", "--directory", plugin_path, scraper_path, "tests", ctx.contest_id } else - args = { "uv", "run", scraper_path, "tests", ctx.contest_id, ctx.problem_id } + args = { "uv", "run", "--directory", plugin_path, scraper_path, "tests", ctx.contest_id, ctx.problem_id } end local result = vim.system(args, { @@ -185,30 +185,18 @@ function M.scrape_problem(ctx) end if data.test_cases and #data.test_cases > 0 then - local all_inputs = {} - local all_outputs = {} + local combined_input = data.test_cases[1].input:gsub("\r", "") + local combined_output = data.test_cases[1].output:gsub("\r", "") - for _, test_case in ipairs(data.test_cases) do - local input_lines = vim.split(test_case.input:gsub("\r", ""):gsub("\n+$", ""), "\n") - local output_lines = vim.split(test_case.output:gsub("\r", ""):gsub("\n+$", ""), "\n") - - for _, line in ipairs(input_lines) do - table.insert(all_inputs, line) - end - - for _, line in ipairs(output_lines) do - table.insert(all_outputs, line) - end - end - - vim.fn.writefile(all_inputs, ctx.input_file) - vim.fn.writefile(all_outputs, ctx.expected_file) + vim.fn.writefile(vim.split(combined_input, "\n", true), ctx.input_file) + vim.fn.writefile(vim.split(combined_output, "\n", true), ctx.expected_file) end return { success = true, problem_id = ctx.problem_name, test_count = data.test_cases and #data.test_cases or 0, + test_cases = data.test_cases, url = data.url, } end diff --git a/plugin/cp.lua b/plugin/cp.lua index 690823a..83a2816 100644 --- a/plugin/cp.lua +++ b/plugin/cp.lua @@ -22,10 +22,12 @@ end, { local candidates = {} vim.list_extend(candidates, platforms) vim.list_extend(candidates, actions) - if vim.g.cp and vim.g.cp.platform and vim.g.cp.contest_id then + local cp = require("cp") + local context = cp.get_current_context() + if context.platform and context.contest_id then local cache = require("cp.cache") cache.load() - local contest_data = cache.get_contest_data(vim.g.cp.platform, vim.g.cp.contest_id) + local contest_data = cache.get_contest_data(context.platform, context.contest_id) if contest_data and contest_data.problems then for _, problem in ipairs(contest_data.problems) do table.insert(candidates, problem.id) diff --git a/readme.md b/readme.md index f6b43ed..8cafc5e 100644 --- a/readme.md +++ b/readme.md @@ -9,9 +9,10 @@ https://private-user-images.githubusercontent.com/62671086/489116291-391976d1-c2 ## Features - Support for multiple online judges ([AtCoder](https://atcoder.jp/), [Codeforces](https://codeforces.com/), [CSES](https://cses.fi)) +- Multi-language support (C++, Python) - Automatic problem scraping and test case management - Integrated build, run, and debug commands -- Diff mode for comparing output with expected results +- Enhanced test viewer with individual test case management - LuaSnip integration for contest-specific snippets ## Requirements @@ -56,9 +57,14 @@ follows: 4. Submit the problem (on the remote!) +## Similar Projects + +- [competitest.nvim](https://github.com/xeluxee/competitest.nvim) + ## TODO - finer-tuned problem limits (i.e. per-problem codeforces time, memory) - better highlighting - test case management - USACO support +- new video with functionality, notify discord members diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index 6343287..4610ae9 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -19,28 +19,35 @@ def scrape(url: str) -> list[tuple[str, str]]: input_sections = soup.find_all("div", class_="input") output_sections = soup.find_all("div", class_="output") - for inp_section, out_section in zip(input_sections, output_sections): + all_inputs = [] + all_outputs = [] + + for inp_section in input_sections: inp_pre = inp_section.find("pre") + if inp_pre: + divs = inp_pre.find_all("div") + if divs: + lines = [div.get_text().strip() for div in divs] + text = "\n".join(lines) + else: + text = inp_pre.get_text().replace("\r", "") + all_inputs.append(text) + + for out_section in output_sections: out_pre = out_section.find("pre") + if out_pre: + divs = out_pre.find_all("div") + if divs: + lines = [div.get_text().strip() for div in divs] + text = "\n".join(lines) + else: + text = out_pre.get_text().replace("\r", "") + all_outputs.append(text) - if inp_pre and out_pre: - input_lines: list[str] = [] - output_lines: list[str] = [] - - input_text_raw = inp_pre.get_text().strip().replace("\r", "") - input_lines = [ - line.strip() for line in input_text_raw.split("\n") if line.strip() - ] - - output_text_raw = out_pre.get_text().strip().replace("\r", "") - output_lines = [ - line.strip() for line in output_text_raw.split("\n") if line.strip() - ] - - if input_lines and output_lines: - input_text = "\n".join(input_lines) - output_text = "\n".join(output_lines) - tests.append((input_text, output_text)) + if all_inputs and all_outputs: + combined_input = "\n".join(all_inputs) + combined_output = "\n".join(all_outputs) + tests.append((combined_input, combined_output)) return tests @@ -112,7 +119,7 @@ def main() -> None: if mode == "metadata": if len(sys.argv) != 3: - result = { + result: dict[str, str | bool] = { "success": False, "error": "Usage: codeforces.py metadata ", } @@ -123,14 +130,14 @@ def main() -> None: problems: list[dict[str, str]] = scrape_contest_problems(contest_id) if not problems: - result = { + result: dict[str, str | bool] = { "success": False, "error": f"No problems found for contest {contest_id}", } print(json.dumps(result)) sys.exit(1) - result = { + result: dict[str, str | bool | list] = { "success": True, "contest_id": contest_id, "problems": problems, @@ -139,7 +146,7 @@ def main() -> None: elif mode == "tests": if len(sys.argv) != 4: - result = { + result: dict[str, str | bool] = { "success": False, "error": "Usage: codeforces.py tests ", } @@ -154,7 +161,7 @@ def main() -> None: tests: list[tuple[str, str]] = scrape_sample_tests(url) if not tests: - result = { + result: dict[str, str | bool] = { "success": False, "error": f"No tests found for {contest_id} {problem_letter}", "problem_id": problem_id, @@ -167,7 +174,7 @@ def main() -> None: for input_data, output_data in tests: test_cases.append({"input": input_data, "output": output_data}) - result = { + result: dict[str, str | bool | list] = { "success": True, "problem_id": problem_id, "url": url, @@ -176,7 +183,7 @@ def main() -> None: print(json.dumps(result)) else: - result = { + result: dict[str, str | bool] = { "success": False, "error": f"Unknown mode: {mode}. Use 'metadata' or 'tests'", }