diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 5860363..1bbda23 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -20,15 +20,7 @@ M.defaults = { } local function extend_contest_config(base_config, contest_config) - local result = vim.deepcopy(base_config) - - for key, value in pairs(contest_config) do - if key == "compile_flags" or key == "debug_flags" then - vim.list_extend(result[key], value) - else - result[key] = value - end - end + local result = vim.tbl_deep_extend("force", base_config, contest_config) local std_flag = ("-std=c++%d"):format(result.cpp_version) table.insert(result.compile_flags, 1, std_flag) diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua new file mode 100644 index 0000000..f4d0d78 --- /dev/null +++ b/lua/cp/execute.lua @@ -0,0 +1,111 @@ +local M = {} + +local signal_codes = { + [128] = "SIGILL", + [130] = "SIGABRT", + [131] = "SIGBUS", + [136] = "SIGFPE", + [135] = "SIGSEGV", + [137] = "SIGPIPE", + [139] = "SIGTERM", +} + +local function get_paths(problem_id) + return { + source = ("%s.cc"):format(problem_id), + binary = ("build/%s"):format(problem_id), + input = ("io/%s.in"):format(problem_id), + output = ("io/%s.out"):format(problem_id), + expected = ("io/%s.expected"):format(problem_id), + } +end + +local function ensure_directories() + vim.system({ "mkdir", "-p", "build", "io" }):wait() +end + +local function compile_cpp(source_path, binary_path, flags) + local compile_cmd = { "g++", unpack(flags), source_path, "-o", binary_path } + return vim.system(compile_cmd, { text = true }):wait() +end + +local function execute_binary(binary_path, input_data, timeout_ms) + local start_time = vim.loop.hrtime() + + local result = vim.system({ binary_path }, { + stdin = input_data, + timeout = timeout_ms, + text = true, + }):wait() + + local end_time = vim.loop.hrtime() + local execution_time = (end_time - start_time) / 1000000 + + return { + stdout = result.stdout or "", + stderr = result.stderr or "", + code = result.code, + time_ms = execution_time, + timed_out = result.code == 124, + } +end + +local function format_output(exec_result, expected_file) + local lines = { exec_result.stdout } + + if exec_result.timed_out then + table.insert(lines, "\n[code]: 124 (TIMEOUT)") + elseif exec_result.code >= 128 then + local signal_name = signal_codes[exec_result.code] or "SIGNAL" + table.insert(lines, ("\n[code]: %d (%s)"):format(exec_result.code, signal_name)) + else + table.insert(lines, ("\n[code]: %d"):format(exec_result.code)) + end + + table.insert(lines, ("\n[time]: %.2f ms"):format(exec_result.time_ms)) + table.insert(lines, "\n[debug]: false") + + if vim.fn.filereadable(expected_file) == 1 and exec_result.code == 0 then + local expected_content = vim.fn.readfile(expected_file) + local actual_lines = vim.split(exec_result.stdout, "\n") + + local matches = #actual_lines == #expected_content + if matches then + for i, line in ipairs(actual_lines) do + if line ~= expected_content[i] then + matches = false + break + end + end + end + + table.insert(lines, ("\n[matches]: %s"):format(matches and "true" or "false")) + end + + return table.concat(lines, "") +end + +function M.run_problem(problem_id, contest_config, is_debug) + ensure_directories() + + local paths = get_paths(problem_id) + local flags = is_debug and contest_config.debug_flags or contest_config.compile_flags + + local compile_result = compile_cpp(paths.source, paths.binary, flags) + if compile_result.code ~= 0 then + vim.fn.writefile({ compile_result.stderr }, paths.output) + return + end + + local input_data = "" + if vim.fn.filereadable(paths.input) == 1 then + input_data = table.concat(vim.fn.readfile(paths.input), "\n") .. "\n" + end + + local exec_result = execute_binary(paths.binary, input_data, contest_config.timeout_ms) + local formatted_output = format_output(exec_result, paths.expected) + + vim.fn.writefile(vim.split(formatted_output, "\n"), paths.output) +end + +return M diff --git a/lua/cp/init.lua b/lua/cp/init.lua index fb73e63..09a1dc4 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -1,5 +1,7 @@ local config_module = require("cp.config") local snippets = require("cp.snippets") +local execute = require("cp.execute") +local scrape = require("cp.scrape") local M = {} local config = {} @@ -40,9 +42,9 @@ local function setup_python_env() if vim.fn.isdirectory(venv_dir) == 0 then log("setting up Python environment for scrapers...") - local result = vim.fn.system(("cd %s && uv sync"):format(vim.fn.shellescape(plugin_path))) - if vim.v.shell_error ~= 0 then - log("failed to setup Python environment: " .. result, vim.log.levels.ERROR) + local result = vim.system({ "uv", "sync" }, { cwd = plugin_path, text = true }):wait() + if result.code ~= 0 then + log("failed to setup Python environment: " .. result.stderr, vim.log.levels.ERROR) return false end log("python environment setup complete") @@ -63,8 +65,8 @@ local function setup_contest(contest_type) end vim.g.cp_contest = contest_type - vim.fn.system(("cp -fr %s/* ."):format(config.template_dir)) - vim.fn.system(("make setup VERSION=%s"):format(config.contests[contest_type].cpp_version)) + vim.fn.mkdir("build", "p") + vim.fn.mkdir("io", "p") log(("set up %s contest environment"):format(contest_type)) end @@ -89,17 +91,24 @@ local function setup_problem(problem_id, problem_letter) vim.cmd.only() - local filename, full_problem_id - if (vim.g.cp_contest == "atcoder" or vim.g.cp_contest == "codeforces") and problem_letter then - full_problem_id = problem_id .. problem_letter - filename = full_problem_id .. ".cc" - vim.fn.system(("make scrape %s %s %s"):format(vim.g.cp_contest, problem_id, problem_letter)) + local scrape_result = scrape.scrape_problem(vim.g.cp_contest, problem_id, problem_letter) + + if not scrape_result.success then + log("scraping failed: " .. scrape_result.error, vim.log.levels.WARN) + log("you can manually add test cases to io/ directory", vim.log.levels.INFO) else - full_problem_id = problem_id - filename = problem_id .. ".cc" - vim.fn.system(("make scrape %s %s"):format(vim.g.cp_contest, problem_id)) + log(("scraped %d test case(s) for %s"):format(scrape_result.test_count, scrape_result.problem_id)) end + local full_problem_id = scrape_result.success and scrape_result.problem_id + or ( + (vim.g.cp_contest == "atcoder" or vim.g.cp_contest == "codeforces") + and problem_letter + and problem_id .. problem_letter:upper() + or problem_id + ) + local filename = full_problem_id .. ".cc" + vim.cmd.e(filename) if vim.api.nvim_buf_get_lines(0, 0, -1, true)[1] == "" then @@ -150,10 +159,16 @@ local function run_problem() lsp.lsp_format({ async = true }) end - vim.system({ "make", "run", vim.fn.expand("%:t") }, {}, function() - vim.schedule(function() - vim.cmd.checktime() - end) + if not vim.g.cp_contest then + log("no contest mode set", vim.log.levels.ERROR) + return + end + + local contest_config = config.contests[vim.g.cp_contest] + + vim.schedule(function() + execute.run_problem(problem_id, contest_config, false) + vim.cmd.checktime() end) end @@ -168,10 +183,16 @@ local function debug_problem() lsp.lsp_format({ async = true }) end - vim.system({ "make", "debug", vim.fn.expand("%:t") }, {}, function() - vim.schedule(function() - vim.cmd.checktime() - end) + if not vim.g.cp_contest then + log("no contest mode set", vim.log.levels.ERROR) + return + end + + local contest_config = config.contests[vim.g.cp_contest] + + vim.schedule(function() + execute.run_problem(problem_id, contest_config, true) + vim.cmd.checktime() end) end @@ -202,7 +223,8 @@ local function diff_problem() end local temp_output = vim.fn.tempname() - vim.fn.system(("awk '/^\\[[^]]*\\]:/ {exit} {print}' %s > %s"):format(vim.fn.shellescape(output), temp_output)) + local result = vim.system({ "awk", "/^\\[[^]]*\\]:/ {exit} {print}", output }, { text = true }):wait() + vim.fn.writefile(vim.split(result.stdout, "\n"), temp_output) local session_file = vim.fn.tempname() .. ".vim" vim.cmd(("silent! mksession! %s"):format(session_file)) @@ -239,7 +261,6 @@ function M.setup(user_config) config = config_module.setup(user_config) local plugin_path = get_plugin_path() - config.template_dir = plugin_path .. "/templates" config.snippets.path = plugin_path .. "/templates/snippets" snippets.setup(config) diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua new file mode 100644 index 0000000..dd0bb79 --- /dev/null +++ b/lua/cp/scrape.lua @@ -0,0 +1,68 @@ +local M = {} + +local function get_plugin_path() + local plugin_path = debug.getinfo(1, "S").source:sub(2) + return vim.fn.fnamemodify(plugin_path, ":h:h:h") +end + +local function ensure_io_directory() + vim.fn.mkdir("io", "p") +end + +function M.scrape_problem(contest, problem_id, problem_letter) + ensure_io_directory() + + local plugin_path = get_plugin_path() + local scraper_path = plugin_path .. "/templates/scrapers/" .. contest .. ".py" + + local args + if contest == "cses" then + args = { "uv", "run", scraper_path, problem_id } + else + args = { "uv", "run", scraper_path, problem_id, problem_letter } + end + + local result = vim.system(args, { + cwd = plugin_path, + text = true, + timeout = 30000, + }):wait() + + if result.code ~= 0 then + return { + success = false, + error = "Failed to run scraper: " .. (result.stderr or "Unknown error"), + } + end + + local ok, data = pcall(vim.json.decode, result.stdout) + if not ok then + return { + success = false, + error = "Failed to parse scraper output: " .. tostring(data), + } + end + + if not data.success then + return data + end + + local full_problem_id = data.problem_id + local input_file = "io/" .. full_problem_id .. ".in" + local expected_file = "io/" .. full_problem_id .. ".expected" + + if #data.test_cases > 0 then + local first_test = data.test_cases[1] + vim.fn.writefile(vim.split(first_test.input, "\n"), input_file) + vim.fn.writefile(vim.split(first_test.output, "\n"), expected_file) + end + + return { + success = true, + problem_id = full_problem_id, + test_count = #data.test_cases, + url = data.url, + } +end + +return M diff --git a/readme.md b/readme.md index 15405b5..cda6d7c 100644 --- a/readme.md +++ b/readme.md @@ -19,8 +19,6 @@ neovim plugin for competitive programming. - `make` - [uv](https://docs.astral.sh/uv/): problem scraping (optional) - [LuaSnip](https://github.com/L3MON4D3/LuaSnip): contest-specific snippets (optional) -- [vim-zoom](https://github.com/dhruvasagar/vim-zoom): better diff view - (optional) ## Installation diff --git a/templates/compile_flags.txt b/templates/compile_flags.txt deleted file mode 100644 index 04b1b00..0000000 --- a/templates/compile_flags.txt +++ /dev/null @@ -1,2 +0,0 @@ --O2 --DLOCAL diff --git a/templates/debug_flags.txt b/templates/debug_flags.txt deleted file mode 100644 index a2c29c5..0000000 --- a/templates/debug_flags.txt +++ /dev/null @@ -1,3 +0,0 @@ --g3 --fsanitize=address,undefined --DLOCAL diff --git a/templates/scrapers/atcoder.py b/templates/scrapers/atcoder.py index 374ffdb..788a573 100644 --- a/templates/scrapers/atcoder.py +++ b/templates/scrapers/atcoder.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import json import sys import requests @@ -57,12 +58,17 @@ def scrape(url: str) -> list[tuple[str, str]]: def main(): if len(sys.argv) != 3: - print("Usage: atcoder.py ", file=sys.stderr) - print("Example: atcoder.py abc042 a", file=sys.stderr) + result = { + "success": False, + "error": "Usage: atcoder.py ", + "problem_id": None, + } + print(json.dumps(result)) sys.exit(1) contest_id = sys.argv[1] problem_letter = sys.argv[2] + problem_id = contest_id + problem_letter url = parse_problem_url(contest_id, problem_letter) print(f"Scraping: {url}", file=sys.stderr) @@ -70,17 +76,27 @@ def main(): tests = scrape(url) if not tests: - print(f"No tests found for {contest_id} {problem_letter}", file=sys.stderr) + result = { + "success": False, + "error": f"No tests found for {contest_id} {problem_letter}", + "problem_id": problem_id, + "url": url, + } + print(json.dumps(result)) sys.exit(1) - print("---INPUT---") - print(len(tests)) + test_cases = [] for input_data, output_data in tests: - print(input_data) - print("---OUTPUT---") - for input_data, output_data in tests: - print(output_data) - print("---END---") + test_cases.append({"input": input_data, "output": output_data}) + + result = { + "success": True, + "problem_id": problem_id, + "url": url, + "test_cases": test_cases, + } + + print(json.dumps(result)) if __name__ == "__main__": diff --git a/templates/scrapers/codeforces.py b/templates/scrapers/codeforces.py index ed31990..d1c24fa 100644 --- a/templates/scrapers/codeforces.py +++ b/templates/scrapers/codeforces.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import json import sys import cloudscraper @@ -73,31 +74,43 @@ def scrape_sample_tests(url: str): def main(): if len(sys.argv) != 3: - print("Usage: codeforces.py ", file=sys.stderr) - print("Example: codeforces.py 1234 A", file=sys.stderr) + result = { + "success": False, + "error": "Usage: codeforces.py ", + "problem_id": None, + } + print(json.dumps(result)) sys.exit(1) contest_id = sys.argv[1] problem_letter = sys.argv[2] + problem_id = contest_id + problem_letter.upper() url = parse_problem_url(contest_id, problem_letter) tests = scrape_sample_tests(url) if not tests: - print(f"No tests found for {contest_id} {problem_letter}", file=sys.stderr) - print( - "Consider adding test cases manually to the io/ directory", file=sys.stderr - ) + result = { + "success": False, + "error": f"No tests found for {contest_id} {problem_letter}", + "problem_id": problem_id, + "url": url, + } + print(json.dumps(result)) sys.exit(1) - print("---INPUT---") - print(len(tests)) + test_cases = [] for input_data, output_data in tests: - print(input_data) - print("---OUTPUT---") - for input_data, output_data in tests: - print(output_data) - print("---END---") + test_cases.append({"input": input_data, "output": output_data}) + + result = { + "success": True, + "problem_id": problem_id, + "url": url, + "test_cases": test_cases, + } + + print(json.dumps(result)) if __name__ == "__main__": diff --git a/templates/scrapers/cses.py b/templates/scrapers/cses.py index 38d43aa..8da2ba6 100755 --- a/templates/scrapers/cses.py +++ b/templates/scrapers/cses.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import json import sys import requests @@ -57,31 +58,54 @@ def scrape(url: str) -> list[tuple[str, str]]: def main(): if len(sys.argv) != 2: - print("Usage: cses.py ", file=sys.stderr) + result = { + "success": False, + "error": "Usage: cses.py ", + "problem_id": None, + } + print(json.dumps(result)) sys.exit(1) problem_input = sys.argv[1] url = parse_problem_url(problem_input) if not url: - print(f"Invalid problem input: {problem_input}", file=sys.stderr) - print("Use either problem ID (e.g., 1068) or full URL", file=sys.stderr) + result = { + "success": False, + "error": f"Invalid problem input: {problem_input}. Use either problem ID (e.g., 1068) or full URL", + "problem_id": problem_input if problem_input.isdigit() else None, + } + print(json.dumps(result)) sys.exit(1) tests = scrape(url) + problem_id = ( + problem_input if problem_input.isdigit() else problem_input.split("/")[-1] + ) + if not tests: - print(f"No tests found for {problem_input}", file=sys.stderr) + result = { + "success": False, + "error": f"No tests found for {problem_input}", + "problem_id": problem_id, + "url": url, + } + print(json.dumps(result)) sys.exit(1) - print("---INPUT---") - print(len(tests)) + test_cases = [] for input_data, output_data in tests: - print(input_data) - print("---OUTPUT---") - for input_data, output_data in tests: - print(output_data) - print("---END---") + test_cases.append({"input": input_data, "output": output_data}) + + result = { + "success": True, + "problem_id": problem_id, + "url": url, + "test_cases": test_cases, + } + + print(json.dumps(result)) if __name__ == "__main__":