From d5b3c9a881ba00d3e4a5588bb3d28e04d92b0d38 Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Mon, 15 Sep 2025 22:39:44 -0400 Subject: [PATCH] feat: basic test mode runner --- lua/cp/execute.lua | 40 +++++++++++-- lua/cp/init.lua | 137 +++++++++++++++++++++++++++------------------ lua/cp/scrape.lua | 60 +++++++++++++++++--- lua/cp/test.lua | 57 +++++++++++++++++-- readme.md | 2 - 5 files changed, 223 insertions(+), 73 deletions(-) diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index 1b52a71..d341489 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -89,7 +89,7 @@ end ---@param language_config table ---@param substitutions table ---@return {code: integer, stderr: string} -local function compile_generic(language_config, substitutions) +function M.compile_generic(language_config, substitutions) vim.validate({ language_config = { language_config, "table" }, substitutions = { substitutions, "table" }, @@ -210,8 +210,40 @@ local function format_output(exec_result, expected_file, is_debug) end ---@param ctx ProblemContext ----@param contest_config table ----@param is_debug boolean +---@param contest_config ContestConfig +---@return boolean success +function M.compile_problem(ctx, contest_config) + vim.validate({ + ctx = { ctx, "table" }, + contest_config = { contest_config, "table" }, + }) + + local language = get_language_from_file(ctx.source_file, contest_config) + local language_config = contest_config[language] + + if not language_config then + logger.log("No configuration for language: " .. language, vim.log.levels.ERROR) + return false + end + + local substitutions = { + source = ctx.source_file, + binary = ctx.binary_file, + version = tostring(language_config.version), + } + + if language_config.compile then + local compile_result = M.compile_generic(language_config, substitutions) + if compile_result.code ~= 0 then + logger.log("compilation failed: " .. (compile_result.stderr or "unknown error"), vim.log.levels.ERROR) + return false + end + logger.log("compilation successful") + end + + return true +end + function M.run_problem(ctx, contest_config, is_debug) vim.validate({ ctx = { ctx, "table" }, @@ -237,7 +269,7 @@ function M.run_problem(ctx, contest_config, is_debug) local compile_cmd = is_debug and language_config.debug or language_config.compile if compile_cmd then - local compile_result = compile_generic(language_config, substitutions) + local compile_result = M.compile_generic(language_config, substitutions) if compile_result.code ~= 0 then vim.fn.writefile({ compile_result.stderr }, ctx.output_file) return diff --git a/lua/cp/init.lua b/lua/cp/init.lua index c72ad4c..eb78916 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -228,6 +228,12 @@ local function toggle_test_panel() return end + local execute = require("cp.execute") + local contest_config = config.contests[state.platform] + if not execute.compile_problem(ctx, contest_config) then + return + end + state.saved_session = vim.fn.tempname() vim.cmd(("mksession! %s"):format(state.saved_session)) @@ -239,8 +245,10 @@ local function toggle_test_panel() vim.bo.bufhidden = "wipe" local function navigate_test(delta) + logger.log(("navigating test by %d"):format(delta)) local test_state = test_module.get_test_panel_state() local new_index = test_state.current_index + delta + logger.log(("current: %d, new: %d, total: %d"):format(test_state.current_index, new_index, #test_state.test_cases)) if new_index >= 1 and new_index <= #test_state.test_cases then test_state.current_index = new_index toggle_test_panel() @@ -248,74 +256,97 @@ local function toggle_test_panel() end end - local function run_current_test() + local function refresh_test_panel() + if not test_buf or not vim.api.nvim_buf_is_valid(test_buf) then + return + end + + local test_state = test_module.get_test_panel_state() + local test_lines = {} + + for i, test_case in ipairs(test_state.test_cases) do + local status_text = string.upper(test_case.status) + if test_case.status == "timeout" then + status_text = "TIMEOUT" + end + local prefix = i == test_state.current_index and "> " or " " + local line = string.format("%s%d %s", prefix, i, status_text) + table.insert(test_lines, line) + end + + if test_state.test_cases[test_state.current_index] then + local current_test = test_state.test_cases[test_state.current_index] + table.insert(test_lines, "") + table.insert(test_lines, string.format("── Test %d ──", test_state.current_index)) + + table.insert(test_lines, "Input:") + for _, line in ipairs(vim.split(current_test.input, "\n", { plain = true, trimempty = true })) do + table.insert(test_lines, line) + end + + table.insert(test_lines, "Expected:") + for _, line in ipairs(vim.split(current_test.expected, "\n", { plain = true, trimempty = true })) do + table.insert(test_lines, line) + end + + if current_test.actual then + table.insert(test_lines, "Actual:") + for _, line in ipairs(vim.split(current_test.actual, "\n", { plain = true, trimempty = true })) do + table.insert(test_lines, line) + end + end + end + + table.insert(test_lines, "") + table.insert(test_lines, "[j/k] Navigate [Enter] Run all tests [q] Close") + + vim.api.nvim_buf_set_lines(test_buf, 0, -1, false, test_lines) + end + + local function navigate_test_case(delta) + local test_state = test_module.get_test_panel_state() + if #test_state.test_cases == 0 then + return + end + + test_state.current_index = test_state.current_index + delta + if test_state.current_index < 1 then + test_state.current_index = #test_state.test_cases + elseif test_state.current_index > #test_state.test_cases then + test_state.current_index = 1 + end + + refresh_test_panel() + end + + local function run_all_tests() local problem_ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) local contest_config = config.contests[state.platform] local test_state = test_module.get_test_panel_state() - test_module.run_test_case(problem_ctx, contest_config, test_state.current_index) - toggle_test_panel() - toggle_test_panel() + + if test_state.test_cases and #test_state.test_cases > 0 then + test_module.run_all_test_cases(problem_ctx, contest_config) + refresh_test_panel() + end end vim.keymap.set("n", "j", function() - navigate_test(1) + navigate_test_case(1) end, { buffer = test_buf, silent = true }) vim.keymap.set("n", "k", function() - navigate_test(-1) + navigate_test_case(-1) + end, { buffer = test_buf, silent = true }) + vim.keymap.set("n", "", function() + run_all_tests() end, { buffer = test_buf, silent = true }) - vim.keymap.set("n", "", run_current_test, { buffer = test_buf, silent = true }) vim.keymap.set("n", "q", function() toggle_test_panel() end, { buffer = test_buf, silent = true }) - local test_state = test_module.get_test_panel_state() - local test_lines = {} - - for i, test_case in ipairs(test_state.test_cases) do - local status_icon = "?" - local status_text = "PENDING" - - if test_case.status == "pass" then - status_icon = "✓" - status_text = "PASS" - elseif test_case.status == "fail" then - status_icon = "✗" - status_text = "FAIL" - end - - local time_text = test_case.time_ms and string.format("%.0fms", test_case.time_ms) or "" - local prefix = i == test_state.current_index and "> " or " " - - table.insert(test_lines, string.format("%s%d %s %s %s", prefix, i, status_icon, status_text, time_text)) - end - - table.insert(test_lines, "") - - local current_test = test_state.test_cases[test_state.current_index] - if current_test then - table.insert(test_lines, string.format("── Test %d ──", test_state.current_index)) - table.insert(test_lines, "Input:") - for _, line in ipairs(vim.split(current_test.input, "\n", { plain = true, trimempty = true })) do - table.insert(test_lines, line) - end - - table.insert(test_lines, "Expected:") - for _, line in ipairs(vim.split(current_test.expected, "\n", { plain = true, trimempty = true })) do - table.insert(test_lines, line) - end - - if current_test.actual then - table.insert(test_lines, "Actual:") - for _, line in ipairs(vim.split(current_test.actual, "\n", { plain = true, trimempty = true })) do - table.insert(test_lines, line) - end - end - end - - vim.api.nvim_buf_set_lines(test_buf, 0, -1, false, test_lines) - vim.bo.modifiable = false + refresh_test_panel() state.test_panel_active = true + local test_state = test_module.get_test_panel_state() logger.log(string.format("test panel opened (%d test cases)", #test_state.test_cases)) end diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index 291a64d..9352a69 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -148,10 +148,34 @@ function M.scrape_problem(ctx) ensure_io_directory() if vim.fn.filereadable(ctx.input_file) == 1 and vim.fn.filereadable(ctx.expected_file) == 1 then + local base_name = vim.fn.fnamemodify(ctx.input_file, ":r") + local test_cases = {} + local i = 1 + + while true do + local input_file = base_name .. "." .. i .. ".cpin" + local expected_file = base_name .. "." .. i .. ".cpout" + + if vim.fn.filereadable(input_file) == 1 and vim.fn.filereadable(expected_file) == 1 then + local input_content = table.concat(vim.fn.readfile(input_file), "\n") + local expected_content = table.concat(vim.fn.readfile(expected_file), "\n") + + table.insert(test_cases, { + index = i, + input = input_content, + output = expected_content + }) + i = i + 1 + else + break + end + end + return { success = true, problem_id = ctx.problem_name, - test_count = 1, + test_count = #test_cases, + test_cases = test_cases, } end @@ -204,6 +228,7 @@ function M.scrape_problem(ctx) timeout = 30000, }):wait() + if result.code ~= 0 then return { success = false, @@ -221,24 +246,41 @@ function M.scrape_problem(ctx) } end + if not data.success then return data end - if data.combined then - local combined_input = data.combined.input:gsub("\r", "") - local combined_output = data.combined.output:gsub("\r", "") + if data.test_cases and #data.test_cases > 0 then + local base_name = vim.fn.fnamemodify(ctx.input_file, ":r") - vim.fn.writefile(vim.split(combined_input, "\n", true), ctx.input_file) - vim.fn.writefile(vim.split(combined_output, "\n", true), ctx.expected_file) - elseif data.test_cases and #data.test_cases > 0 then - local combined_input = data.test_cases[1].input:gsub("\r", "") - local combined_output = data.test_cases[1].output:gsub("\r", "") + for i, test_case in ipairs(data.test_cases) do + local input_file = base_name .. "." .. i .. ".cpin" + local expected_file = base_name .. "." .. i .. ".cpout" + + local input_content = test_case.input:gsub("\r", "") + local expected_content = test_case.output:gsub("\r", "") + + if ctx.contest == "atcoder" then + input_content = "1\n" .. input_content + end + + vim.fn.writefile(vim.split(input_content, "\n", true), input_file) + vim.fn.writefile(vim.split(expected_content, "\n", true), expected_file) + end + + local combined_input = data.combined and data.combined.input:gsub("\r", "") or table.concat(vim.tbl_map(function(tc) return tc.input end, data.test_cases), "\n") + local combined_output = data.combined and data.combined.output:gsub("\r", "") or table.concat(vim.tbl_map(function(tc) return tc.output end, data.test_cases), "\n") + + if ctx.contest == "atcoder" then + combined_input = tostring(#data.test_cases) .. "\n" .. combined_input + end vim.fn.writefile(vim.split(combined_input, "\n", true), ctx.input_file) vim.fn.writefile(vim.split(combined_output, "\n", true), ctx.expected_file) end + return { success = true, problem_id = ctx.problem_name, diff --git a/lua/cp/test.lua b/lua/cp/test.lua index 5bc546a..8468717 100644 --- a/lua/cp/test.lua +++ b/lua/cp/test.lua @@ -2,10 +2,11 @@ ---@field index number ---@field input string ---@field expected string ----@field status "pending"|"pass"|"fail"|"running" +---@field status "pending"|"pass"|"fail"|"running"|"timeout" ---@field actual string? ---@field time_ms number? ---@field error string? +---@field selected boolean ---@class TestPanelState ---@field test_cases TestCase[] @@ -41,6 +42,7 @@ local function create_test_case(index, input, expected) actual = nil, time_ms = nil, error = nil, + selected = true, } end @@ -75,10 +77,32 @@ local function parse_test_cases_from_files(input_file, expected_file) return {} end - local input_content = table.concat(vim.fn.readfile(input_file), "\n") - local expected_content = table.concat(vim.fn.readfile(expected_file), "\n") + local base_name = vim.fn.fnamemodify(input_file, ":r") + local test_cases = {} + local i = 1 - return { create_test_case(1, input_content, expected_content) } + while true do + local individual_input_file = base_name .. "." .. i .. ".cpin" + local individual_expected_file = base_name .. "." .. i .. ".cpout" + + if vim.fn.filereadable(individual_input_file) == 1 and vim.fn.filereadable(individual_expected_file) == 1 then + local input_content = table.concat(vim.fn.readfile(individual_input_file), "\n") + local expected_content = table.concat(vim.fn.readfile(individual_expected_file), "\n") + + table.insert(test_cases, create_test_case(i, input_content, expected_content)) + i = i + 1 + else + break + end + end + + if #test_cases == 0 then + local input_content = table.concat(vim.fn.readfile(input_file), "\n") + local expected_content = table.concat(vim.fn.readfile(expected_file), "\n") + return { create_test_case(1, input_content, expected_content) } + end + + return test_cases end ---@param ctx ProblemContext @@ -126,6 +150,20 @@ local function run_single_test_case(ctx, contest_config, test_case) version = tostring(language_config.version or ""), } + if language_config.compile and vim.fn.filereadable(ctx.binary_file) == 0 then + logger.log("binary not found, compiling first...") + local compile_cmd = substitute_template(language_config.compile, substitutions) + local compile_result = vim.system(compile_cmd, { text = true }):wait() + if compile_result.code ~= 0 then + return { + status = "fail", + actual = "", + error = "Compilation failed: " .. (compile_result.stderr or "Unknown error"), + time_ms = 0, + } + end + end + local run_cmd = build_command(language_config.run, language_config.executable, substitutions) local start_time = vim.uv.hrtime() @@ -140,8 +178,17 @@ local function run_single_test_case(ctx, contest_config, test_case) local expected_output = test_case.expected:gsub("\n$", "") local matches = actual_output == expected_output + local status + if result.code == 143 or result.code == 124 then + status = "timeout" + elseif result.code == 0 and matches then + status = "pass" + else + status = "fail" + end + return { - status = result.code == 0 and matches and "pass" or "fail", + status = status, actual = actual_output, error = result.code ~= 0 and result.stderr or nil, time_ms = execution_time, diff --git a/readme.md b/readme.md index 6c17b50..121fd4b 100644 --- a/readme.md +++ b/readme.md @@ -71,5 +71,3 @@ follows: - test case management - new video with functionality, notify discord members - note that codeforces support is scuffed: https://codeforces.com/blog/entry/146423 -- codeforces: use round number & api not the contest id - - problems: api config