From 1ef68a4847a6f1ed10d60b788bbc7b03e41b636d Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Mon, 15 Sep 2025 08:11:15 -0500 Subject: [PATCH] feat: first draft of arbitrary compile mode --- after/ftplugin/cp-test.lua | 45 +++++++++++++++ lua/cp/cache.lua | 22 ++++++++ lua/cp/config.lua | 46 ++++++++++++---- lua/cp/execute.lua | 109 +++++++++++++++++++++++++++++++++++-- lua/cp/init.lua | 5 +- lua/cp/scrape.lua | 1 + 6 files changed, 212 insertions(+), 16 deletions(-) create mode 100644 after/ftplugin/cp-test.lua diff --git a/after/ftplugin/cp-test.lua b/after/ftplugin/cp-test.lua new file mode 100644 index 0000000..faed7af --- /dev/null +++ b/after/ftplugin/cp-test.lua @@ -0,0 +1,45 @@ +vim.opt_local.number = false +vim.opt_local.relativenumber = false +vim.opt_local.statuscolumn = "" +vim.opt_local.signcolumn = "no" +vim.opt_local.wrap = false +vim.opt_local.linebreak = false +vim.opt_local.foldmethod = "marker" +vim.opt_local.foldmarker = "{{{,}}}" +vim.opt_local.foldlevel = 0 +vim.opt_local.foldtext = "" + +local function get_test_id_from_line() + local line = vim.api.nvim_get_current_line() + local test_id = line:match("%[.%] Test (%d+)") + return test_id and tonumber(test_id) +end + +local function toggle_test() + local test_id = get_test_id_from_line() + if not test_id then + return + end + + local cp = require("cp") + cp.toggle_test(test_id) +end + +local function run_single_test() + local test_id = get_test_id_from_line() + if not test_id then + return + end + + local cp = require("cp") + cp.run_single_test(test_id) +end + +local function run_all_enabled_tests() + local cp = require("cp") + cp.run_all_enabled_tests() +end + +vim.keymap.set("n", "t", toggle_test, { buffer = true, desc = "Toggle test enabled/disabled" }) +vim.keymap.set("n", "r", run_single_test, { buffer = true, desc = "Run single test" }) +vim.keymap.set("n", "R", run_all_enabled_tests, { buffer = true, desc = "Run all enabled tests" }) diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 70a55c3..565afc4 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -87,4 +87,26 @@ function M.clear_contest_data(platform, contest_id) end end +function M.get_test_cases(platform, contest_id, problem_id) + local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id + if not cache_data[platform] or not cache_data[platform][problem_key] then + return nil + end + return cache_data[platform][problem_key].test_cases +end + +function M.set_test_cases(platform, contest_id, problem_id, test_cases) + local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id + if not cache_data[platform] then + cache_data[platform] = {} + end + if not cache_data[platform][problem_key] then + cache_data[platform][problem_key] = {} + end + + cache_data[platform][problem_key].test_cases = test_cases + cache_data[platform][problem_key].test_cases_cached_at = os.time() + M.save() +end + return M diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 9f39c60..48a245e 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -12,19 +12,48 @@ local M = {} M.defaults = { contests = { default = { - cpp_version = 20, - compile_flags = { "-O2", "-DLOCAL", "-Wall", "-Wextra" }, - debug_flags = { "-g3", "-fsanitize=address,undefined", "-DLOCAL" }, + cpp = { + compile = { + "g++", + "-std=c++{version}", + "-O2", + "-DLOCAL", + "-Wall", + "-Wextra", + "{source}", + "-o", + "{binary}", + }, + run = { "{binary}" }, + debug = { + "g++", + "-std=c++{version}", + "-g3", + "-fsanitize=address,undefined", + "-DLOCAL", + "{source}", + "-o", + "{binary}", + }, + executable = nil, + version = 20, + }, + python = { + compile = nil, + run = { "{source}" }, + debug = { "{source}" }, + executable = "python3", + }, timeout_ms = 2000, }, atcoder = { - cpp_version = 23, + cpp = { version = 23 }, }, codeforces = { - cpp_version = 23, + cpp = { version = 23 }, }, cses = { - cpp_version = 20, + cpp = { version = 20 }, }, }, snippets = {}, @@ -42,11 +71,6 @@ M.defaults = { ---@return table local function extend_contest_config(base_config, contest_config) local result = vim.tbl_deep_extend("force", base_config, contest_config) - - local std_flag = ("-std=c++%d"):format(result.cpp_version) - result.compile_flags = vim.list_extend({ std_flag }, result.compile_flags) - result.debug_flags = vim.list_extend({ std_flag }, result.debug_flags) - return result end diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index 774346b..2182939 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -1,5 +1,43 @@ local M = {} +local filetype_to_language = { + cpp = "cpp", + cxx = "cpp", + cc = "cpp", + c = "cpp", + py = "python", + py3 = "python", + rs = "rust", + java = "java", + js = "javascript", + go = "go", +} + +local function get_language_from_file(source_file) + local extension = vim.fn.fnamemodify(source_file, ":e") + return filetype_to_language[extension] or "cpp" +end + +local function substitute_template(cmd_template, substitutions) + local result = {} + for _, arg in ipairs(cmd_template) do + local substituted = arg + for key, value in pairs(substitutions) do + substituted = substituted:gsub("{" .. key .. "}", value) + end + table.insert(result, substituted) + end + return result +end + +local function build_command(cmd_template, executable, substitutions) + local cmd = substitute_template(cmd_template, substitutions) + if executable then + table.insert(cmd, 1, executable) + end + return cmd +end + local signal_codes = { [128] = "SIGILL", [130] = "SIGINT", @@ -22,15 +60,19 @@ local function ensure_directories() vim.system({ "mkdir", "-p", "build", "io" }):wait() end -local function compile_cpp(source_path, binary_path, flags) - local compile_cmd = { "g++", unpack(flags), source_path, "-o", binary_path } +local function compile_generic(language_config, substitutions) + if not language_config.compile then + return { code = 0, stderr = "" } + end + + local compile_cmd = substitute_template(language_config.compile, substitutions) return vim.system(compile_cmd, { text = true }):wait() end -local function execute_binary(binary_path, input_data, timeout_ms) +local function execute_command(cmd, input_data, timeout_ms) local start_time = vim.loop.hrtime() - local result = vim.system({ binary_path }, { + local result = vim.system(cmd, { stdin = input_data, timeout = timeout_ms, text = true, @@ -123,4 +165,63 @@ function M.run_problem(ctx, contest_config, is_debug) end end +function M.run_individual_tests(ctx, test_cases, contest_config, is_debug) + ensure_directories() + + if not test_cases or #test_cases == 0 then + return {} + end + + local flags = is_debug and contest_config.debug_flags or contest_config.compile_flags + local compile_result = compile_cpp(ctx.source_file, ctx.binary_file, flags) + if compile_result.code ~= 0 then + return { + compile_error = compile_result.stderr, + results = {}, + } + end + + local results = {} + for i, test_case in ipairs(test_cases) do + local exec_result = execute_binary(ctx.binary_file, test_case.input, contest_config.timeout_ms) + + local actual_lines = vim.split(exec_result.stdout, "\n") + while #actual_lines > 0 and actual_lines[#actual_lines] == "" do + table.remove(actual_lines) + end + + local expected_lines = vim.split(test_case.output, "\n") + while #expected_lines > 0 and expected_lines[#expected_lines] == "" do + table.remove(expected_lines) + end + + local matches = #actual_lines == #expected_lines + if matches then + for j, line in ipairs(actual_lines) do + if line ~= expected_lines[j] then + matches = false + break + end + end + end + + table.insert(results, { + id = i, + status = exec_result.code == 0 and (matches and "PASS" or "FAIL") or "ERROR", + time_ms = exec_result.time_ms, + input = test_case.input, + expected = test_case.output, + actual = exec_result.stdout, + exit_code = exec_result.code, + timed_out = exec_result.timed_out, + enabled = true, + }) + end + + return { + compile_error = nil, + results = results, + } +end + return M diff --git a/lua/cp/init.lua b/lua/cp/init.lua index 92d7614..90abffb 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -270,7 +270,10 @@ local function test_problem() local enabled_icon = test_states[result.id] and "[x]" or "[ ]" local time_str = ("%.1fms"):format(result.time_ms) - table.insert(lines, ("%s Test %d %s %s (%s) {{{"):format(enabled_icon, result.id, status_icon, result.status, time_str)) + table.insert( + lines, + ("%s Test %d %s %s (%s) {{{"):format(enabled_icon, result.id, status_icon, result.status, time_str) + ) table.insert(lines, " Input:") for _, line in ipairs(vim.split(result.input, "\n")) do table.insert(lines, " " .. line) diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index 82e2e4f..138dacc 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -209,6 +209,7 @@ function M.scrape_problem(ctx) success = true, problem_id = ctx.problem_name, test_count = data.test_cases and #data.test_cases or 0, + test_cases = data.test_cases, url = data.url, } end