From e10eab22d6aeff0c6229fc6a3835eed9bd9d8cdc Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Mon, 15 Sep 2025 10:37:40 -0400 Subject: [PATCH] fix: revert multiple test cases --- after/ftplugin/cp-test.lua | 45 ---------- doc/cp.txt | 4 +- lua/cp/execute.lua | 91 ------------------- lua/cp/init.lua | 179 +++++++------------------------------ lua/cp/scrape.lua | 26 +----- plugin/cp.lua | 2 +- scrapers/codeforces.py | 33 +++---- 7 files changed, 51 insertions(+), 329 deletions(-) delete mode 100644 after/ftplugin/cp-test.lua diff --git a/after/ftplugin/cp-test.lua b/after/ftplugin/cp-test.lua deleted file mode 100644 index faed7af..0000000 --- a/after/ftplugin/cp-test.lua +++ /dev/null @@ -1,45 +0,0 @@ -vim.opt_local.number = false -vim.opt_local.relativenumber = false -vim.opt_local.statuscolumn = "" -vim.opt_local.signcolumn = "no" -vim.opt_local.wrap = false -vim.opt_local.linebreak = false -vim.opt_local.foldmethod = "marker" -vim.opt_local.foldmarker = "{{{,}}}" -vim.opt_local.foldlevel = 0 -vim.opt_local.foldtext = "" - -local function get_test_id_from_line() - local line = vim.api.nvim_get_current_line() - local test_id = line:match("%[.%] Test (%d+)") - return test_id and tonumber(test_id) -end - -local function toggle_test() - local test_id = get_test_id_from_line() - if not test_id then - return - end - - local cp = require("cp") - cp.toggle_test(test_id) -end - -local function run_single_test() - local test_id = get_test_id_from_line() - if not test_id then - return - end - - local cp = require("cp") - cp.run_single_test(test_id) -end - -local function run_all_enabled_tests() - local cp = require("cp") - cp.run_all_enabled_tests() -end - -vim.keymap.set("n", "t", toggle_test, { buffer = true, desc = "Toggle test enabled/disabled" }) -vim.keymap.set("n", "r", run_single_test, { buffer = true, desc = "Run single test" }) -vim.keymap.set("n", "R", run_all_enabled_tests, { buffer = true, desc = "Run all enabled tests" }) diff --git a/doc/cp.txt b/doc/cp.txt index 7b848ab..306c70a 100644 --- a/doc/cp.txt +++ b/doc/cp.txt @@ -52,8 +52,8 @@ Action Commands ~ :CP debug Compile with debug flags and run current problem. Includes sanitizers and debug symbols. -:CP test Open enhanced test viewer showing individual - test case results with pass/fail status. +:CP diff Enter diff mode to compare actual vs expected + output. Run again to exit diff mode. Navigation Commands ~ diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index 186dd8b..233363e 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -203,96 +203,5 @@ function M.run_problem(ctx, contest_config, is_debug) end end -function M.run_individual_tests(ctx, test_cases, contest_config, is_debug) - ensure_directories() - - if not test_cases or #test_cases == 0 then - return {} - end - - logger.log(("running %d individual tests"):format(#test_cases)) - - local language = get_language_from_file(ctx.source_file) - local language_config = contest_config[language] - - if not language_config then - return { - compile_error = "Error: No configuration for language: " .. language, - results = {}, - } - end - - local substitutions = { - source = ctx.source_file, - binary = ctx.binary_file, - version = tostring(language_config.version or ""), - } - - local compile_cmd = is_debug and language_config.debug or language_config.compile - if compile_cmd then - local compile_result = compile_generic(language_config, substitutions) - if compile_result.code ~= 0 then - return { - compile_error = compile_result.stderr, - results = {}, - } - end - end - - local run_cmd = build_command(language_config.run, language_config.executable, substitutions) - - local results = {} - for i, test_case in ipairs(test_cases) do - local exec_result = execute_command(run_cmd, test_case.input, contest_config.timeout_ms) - - local actual_lines = vim.split(exec_result.stdout, "\n") - while #actual_lines > 0 and actual_lines[#actual_lines] == "" do - table.remove(actual_lines) - end - - local expected_lines = vim.split(test_case.output, "\n") - while #expected_lines > 0 and expected_lines[#expected_lines] == "" do - table.remove(expected_lines) - end - - local matches = #actual_lines == #expected_lines - if matches then - for j, line in ipairs(actual_lines) do - if line ~= expected_lines[j] then - matches = false - break - end - end - end - - table.insert(results, { - id = i, - status = exec_result.code == 0 and (matches and "PASS" or "FAIL") or "ERROR", - time_ms = exec_result.time_ms, - input = test_case.input, - expected = test_case.output, - actual = exec_result.stdout, - exit_code = exec_result.code, - timed_out = exec_result.timed_out, - enabled = true, - }) - end - - local passed = 0 - local total_time = 0 - for _, result in ipairs(results) do - if result.status == "PASS" then - passed = passed + 1 - end - total_time = total_time + result.time_ms - end - - logger.log(("test results: %d/%d passed, total execution time %.1fms"):format(passed, #results, total_time)) - - return { - compile_error = nil, - results = results, - } -end return M diff --git a/lua/cp/init.lua b/lua/cp/init.lua index 8e1e884..149a879 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -32,7 +32,7 @@ local state = { } local platforms = { "atcoder", "codeforces", "cses" } -local actions = { "run", "debug", "test", "next", "prev" } +local actions = { "run", "debug", "diff", "next", "prev" } local function get_current_problem_key() if not state.platform or not state.contest_id then @@ -232,91 +232,46 @@ local function debug_problem() end) end -local function test_problem() +local function diff_problem() + if state.diff_mode then + vim.cmd.diffoff() + if state.saved_session then + vim.fn.delete(state.saved_session) + state.saved_session = nil + end + if state.temp_output then + vim.fn.delete(state.temp_output) + state.temp_output = nil + end + state.diff_mode = false + return + end + local problem_id = get_current_problem() if not problem_id then return end - logger.log(("opening test viewer for problem: %s"):format(problem_id)) - - if not state.test_cases then - logger.log("No test case data available. Try scraping the problem first.", vim.log.levels.ERROR) - return - end - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) - local contest_config = config.contests[state.platform] - local test_results = execute.run_individual_tests(ctx, state.test_cases, contest_config, false) - - if test_results.compile_error then - logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR) + if vim.fn.filereadable(ctx.expected_file) == 0 then + logger.log("no expected output file found", vim.log.levels.WARN) return end - local buf_name = ("cp-test://%s"):format(problem_id) - local existing_buf = vim.fn.bufnr(buf_name) - local buf - - if existing_buf ~= -1 then - buf = existing_buf - else - buf = vim.api.nvim_create_buf(false, true) - vim.api.nvim_buf_set_name(buf, buf_name) + if vim.fn.filereadable(ctx.output_file) == 0 then + logger.log("no output file found. run the problem first", vim.log.levels.WARN) + return end - local lines = {} - local passed = 0 - local total = #test_results.results + state.saved_session = vim.fn.tempname() + vim.cmd(("mksession! %s"):format(state.saved_session)) - local test_states = get_test_states() - - for _, result in ipairs(test_results.results) do - local status_icon = result.status == "PASS" and "✓" or "✗" - local enabled_icon = test_states[result.id] and "[x]" or "[ ]" - local time_str = ("%.1fms"):format(result.time_ms) - - table.insert( - lines, - ("%s Test %d %s %s (%s) {{{"):format(enabled_icon, result.id, status_icon, result.status, time_str) - ) - table.insert(lines, " Input:") - for _, line in ipairs(vim.split(result.input, "\n")) do - table.insert(lines, " " .. line) - end - - if result.status == "PASS" then - table.insert(lines, " Output:") - for _, line in ipairs(vim.split(result.actual, "\n")) do - table.insert(lines, " " .. line) - end - passed = passed + 1 - else - table.insert(lines, " Expected:") - for _, line in ipairs(vim.split(result.expected, "\n")) do - table.insert(lines, " " .. line) - end - table.insert(lines, " Got:") - for _, line in ipairs(vim.split(result.actual, "\n")) do - table.insert(lines, " " .. line) - end - end - - table.insert(lines, "}}}") - table.insert(lines, "") - end - - table.insert(lines, ("Summary: %d/%d passed"):format(passed, total)) - - vim.api.nvim_buf_set_lines(buf, 0, -1, false, lines) - vim.bo[buf].filetype = "cp-test" - vim.bo[buf].modifiable = false - - vim.cmd.split() - vim.api.nvim_set_current_buf(buf) - - logger.log(("Test results: %d/%d passed"):format(passed, total)) + vim.cmd("silent only") + vim.cmd(("edit %s"):format(ctx.expected_file)) + vim.cmd.diffthis() + vim.cmd(("vertical diffsplit %s"):format(ctx.output_file)) + state.diff_mode = true end ---@param delta number 1 for next, -1 for prev @@ -424,8 +379,8 @@ function M.handle_command(opts) run_problem() elseif cmd.action == "debug" then debug_problem() - elseif cmd.action == "test" then - test_problem() + elseif cmd.action == "diff" then + diff_problem() elseif cmd.action == "next" then navigate_problem(1) elseif cmd.action == "prev" then @@ -501,80 +456,6 @@ function M.handle_command(opts) end end -function M.toggle_test(test_id) - local test_states = get_test_states() - test_states[test_id] = not test_states[test_id] - - local problem_key = get_current_problem_key() - if problem_key then - state.test_states[problem_key] = test_states - end - - test_problem() -end - -function M.run_single_test(test_id) - if not state.test_cases or not state.test_cases[test_id] then - logger.log("Test case not found", vim.log.levels.ERROR) - return - end - - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) - local contest_config = config.contests[state.platform] - - local single_test = { state.test_cases[test_id] } - local test_results = execute.run_individual_tests(ctx, single_test, contest_config, false) - - if test_results.compile_error then - logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR) - return - end - - local result = test_results.results[1] - if result then - logger.log(("Test %d: %s (%.1fms)"):format(test_id, result.status, result.time_ms)) - end -end - -function M.run_all_enabled_tests() - if not state.test_cases then - logger.log("No test cases available", vim.log.levels.ERROR) - return - end - - local test_states = get_test_states() - local enabled_tests = {} - - for i, test_case in ipairs(state.test_cases) do - if test_states[i] then - table.insert(enabled_tests, test_case) - end - end - - if #enabled_tests == 0 then - logger.log("No tests enabled", vim.log.levels.WARN) - return - end - - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) - local contest_config = config.contests[state.platform] - - local test_results = execute.run_individual_tests(ctx, enabled_tests, contest_config, false) - - if test_results.compile_error then - logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR) - return - end - - local passed = 0 - for _, result in ipairs(test_results.results) do - if result.status == "PASS" then - passed = passed + 1 - end - end - - logger.log(("Enabled tests: %d/%d passed"):format(passed, #enabled_tests)) -end function M.setup(opts) opts = opts or {} diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index 8b99f3b..b1c85ca 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -185,29 +185,11 @@ function M.scrape_problem(ctx) end if data.test_cases and #data.test_cases > 0 then - local all_inputs = {} - local all_outputs = {} + local combined_input = data.test_cases[1].input:gsub("\r", "") + local combined_output = data.test_cases[1].output:gsub("\r", "") - for i, test_case in ipairs(data.test_cases) do - local input_lines = vim.split(test_case.input:gsub("\r", ""):gsub("\n+$", ""), "\n") - local output_lines = vim.split(test_case.output:gsub("\r", ""):gsub("\n+$", ""), "\n") - - for _, line in ipairs(input_lines) do - table.insert(all_inputs, line) - end - - for _, line in ipairs(output_lines) do - table.insert(all_outputs, line) - end - - if i < #data.test_cases then - table.insert(all_inputs, "") - table.insert(all_outputs, "") - end - end - - vim.fn.writefile(all_inputs, ctx.input_file) - vim.fn.writefile(all_outputs, ctx.expected_file) + vim.fn.writefile(vim.split(combined_input, "\n"), ctx.input_file) + vim.fn.writefile(vim.split(combined_output, "\n"), ctx.expected_file) end return { diff --git a/plugin/cp.lua b/plugin/cp.lua index fc57b96..83a2816 100644 --- a/plugin/cp.lua +++ b/plugin/cp.lua @@ -4,7 +4,7 @@ end vim.g.loaded_cp = 1 local platforms = { "atcoder", "codeforces", "cses" } -local actions = { "run", "debug", "test", "next", "prev" } +local actions = { "run", "debug", "diff", "next", "prev" } vim.api.nvim_create_user_command("CP", function(opts) local cp = require("cp") diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index 6343287..9b885ce 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -19,28 +19,23 @@ def scrape(url: str) -> list[tuple[str, str]]: input_sections = soup.find_all("div", class_="input") output_sections = soup.find_all("div", class_="output") - for inp_section, out_section in zip(input_sections, output_sections): + all_inputs = [] + all_outputs = [] + + for inp_section in input_sections: inp_pre = inp_section.find("pre") + if inp_pre: + all_inputs.append(inp_pre.get_text().strip().replace("\r", "")) + + for out_section in output_sections: out_pre = out_section.find("pre") + if out_pre: + all_outputs.append(out_pre.get_text().strip().replace("\r", "")) - if inp_pre and out_pre: - input_lines: list[str] = [] - output_lines: list[str] = [] - - input_text_raw = inp_pre.get_text().strip().replace("\r", "") - input_lines = [ - line.strip() for line in input_text_raw.split("\n") if line.strip() - ] - - output_text_raw = out_pre.get_text().strip().replace("\r", "") - output_lines = [ - line.strip() for line in output_text_raw.split("\n") if line.strip() - ] - - if input_lines and output_lines: - input_text = "\n".join(input_lines) - output_text = "\n".join(output_lines) - tests.append((input_text, output_text)) + if all_inputs and all_outputs: + combined_input = "\n".join(all_inputs) + combined_output = "\n".join(all_outputs) + tests.append((combined_input, combined_output)) return tests