From 67d2a8054cbccde8356b4fc7728eb9455070b566 Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Mon, 15 Sep 2025 07:05:31 -0500 Subject: [PATCH] feat: local state over vim.g --- lua/cp/health.lua | 14 +-- lua/cp/init.lua | 240 +++++++++++++++++++++++++++++++++++++++------- plugin/cp.lua | 8 +- readme.md | 4 + 4 files changed, 224 insertions(+), 42 deletions(-) diff --git a/lua/cp/health.lua b/lua/cp/health.lua index d902934..738e9e2 100644 --- a/lua/cp/health.lua +++ b/lua/cp/health.lua @@ -62,12 +62,14 @@ end local function check_config() vim.health.ok("Plugin ready") - if vim.g.cp and vim.g.cp.platform then - local info = vim.g.cp.platform - if vim.g.cp.contest_id then - info = info .. " " .. vim.g.cp.contest_id - if vim.g.cp.problem_id then - info = info .. " " .. vim.g.cp.problem_id + local cp = require("cp") + local context = cp.get_current_context() + if context.platform then + local info = context.platform + if context.contest_id then + info = info .. " " .. context.contest_id + if context.problem_id then + info = info .. " " .. context.problem_id end end vim.health.info("Current context: " .. info) diff --git a/lua/cp/init.lua b/lua/cp/init.lua index b86dc98..92d7614 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -14,7 +14,6 @@ if not vim.fn.has("nvim-0.10.0") then return {} end -vim.g.cp = vim.g.cp or {} local user_config = {} local config = config_module.setup(user_config) logger.set_config(config) @@ -28,10 +27,41 @@ local state = { saved_layout = nil, saved_session = nil, temp_output = nil, + test_cases = nil, + test_states = {}, } local platforms = { "atcoder", "codeforces", "cses" } -local actions = { "run", "debug", "diff", "next", "prev" } +local actions = { "run", "debug", "test", "next", "prev" } + +local function get_current_problem_key() + if not state.platform or not state.contest_id then + return nil + end + if state.platform == "cses" then + return state.contest_id + else + return state.contest_id .. "_" .. (state.problem_id or "") + end +end + +local function get_test_states() + local problem_key = get_current_problem_key() + if not problem_key then + return {} + end + + if not state.test_states[problem_key] then + state.test_states[problem_key] = {} + if state.test_cases then + for i = 1, #state.test_cases do + state.test_states[problem_key][i] = true + end + end + end + + return state.test_states[problem_key] +end local function set_platform(platform) if not vim.tbl_contains(platforms, platform) then @@ -79,6 +109,11 @@ local function setup_problem(contest_id, problem_id) state.contest_id = contest_id state.problem_id = problem_id + local cached_test_cases = cache.get_test_cases(state.platform, contest_id, problem_id) + if cached_test_cases then + state.test_cases = cached_test_cases + end + local ctx = problem.create_context(state.platform, contest_id, problem_id, config) local scrape_result = scrape.scrape_problem(ctx) @@ -86,9 +121,15 @@ local function setup_problem(contest_id, problem_id) if not scrape_result.success then logger.log("scraping failed: " .. (scrape_result.error or "unknown error"), vim.log.levels.WARN) logger.log("you can manually add test cases to io/ directory", vim.log.levels.INFO) + state.test_cases = nil else local test_count = scrape_result.test_count or 0 logger.log(("scraped %d test case(s) for %s"):format(test_count, scrape_result.problem_id)) + state.test_cases = scrape_result.test_cases + + if scrape_result.test_cases then + cache.set_test_cases(state.platform, contest_id, problem_id, scrape_result.test_cases) + end end vim.cmd.e(ctx.source_file) @@ -186,36 +227,86 @@ local function debug_problem() end) end -local function diff_problem() - if state.diff_mode then - local tile_fn = config.tile or window.default_tile - window.restore_layout(state.saved_layout, tile_fn) - state.diff_mode = false - state.saved_layout = nil - logger.log("exited diff mode") - else - local problem_id = get_current_problem() - if not problem_id then - return - end - - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) - - if vim.fn.filereadable(ctx.expected_file) == 0 then - logger.log(("No expected output file found: %s"):format(ctx.expected_file), vim.log.levels.ERROR) - return - end - - state.saved_layout = window.save_layout() - - local result = vim.system({ "awk", "/^\\[[^]]*\\]:/ {exit} {print}", ctx.output_file }, { text = true }):wait() - local actual_output = result.stdout - - window.setup_diff_layout(actual_output, ctx.expected_file, ctx.input_file) - - state.diff_mode = true - logger.log("entered diff mode") +local function test_problem() + local problem_id = get_current_problem() + if not problem_id then + return end + + if not state.test_cases then + logger.log("No test case data available. Try scraping the problem first.", vim.log.levels.ERROR) + return + end + + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) + local contest_config = config.contests[state.platform] + + local test_results = execute.run_individual_tests(ctx, state.test_cases, contest_config, false) + + if test_results.compile_error then + logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR) + return + end + + local buf_name = ("cp-test://%s"):format(problem_id) + local existing_buf = vim.fn.bufnr(buf_name) + local buf + + if existing_buf ~= -1 then + buf = existing_buf + else + buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_name(buf, buf_name) + end + + local lines = {} + local passed = 0 + local total = #test_results.results + + local test_states = get_test_states() + + for _, result in ipairs(test_results.results) do + local status_icon = result.status == "PASS" and "✓" or "✗" + local enabled_icon = test_states[result.id] and "[x]" or "[ ]" + local time_str = ("%.1fms"):format(result.time_ms) + + table.insert(lines, ("%s Test %d %s %s (%s) {{{"):format(enabled_icon, result.id, status_icon, result.status, time_str)) + table.insert(lines, " Input:") + for _, line in ipairs(vim.split(result.input, "\n")) do + table.insert(lines, " " .. line) + end + + if result.status == "PASS" then + table.insert(lines, " Output:") + for _, line in ipairs(vim.split(result.actual, "\n")) do + table.insert(lines, " " .. line) + end + passed = passed + 1 + else + table.insert(lines, " Expected:") + for _, line in ipairs(vim.split(result.expected, "\n")) do + table.insert(lines, " " .. line) + end + table.insert(lines, " Got:") + for _, line in ipairs(vim.split(result.actual, "\n")) do + table.insert(lines, " " .. line) + end + end + + table.insert(lines, "}}}") + table.insert(lines, "") + end + + table.insert(lines, ("Summary: %d/%d passed"):format(passed, total)) + + vim.api.nvim_buf_set_lines(buf, 0, -1, false, lines) + vim.bo[buf].filetype = "cp-test" + vim.bo[buf].modifiable = false + + vim.cmd.split() + vim.api.nvim_set_current_buf(buf) + + logger.log(("Test results: %d/%d passed"):format(passed, total)) end ---@param delta number 1 for next, -1 for prev @@ -323,8 +414,8 @@ function M.handle_command(opts) run_problem() elseif cmd.action == "debug" then debug_problem() - elseif cmd.action == "diff" then - diff_problem() + elseif cmd.action == "test" then + test_problem() elseif cmd.action == "next" then navigate_problem(1) elseif cmd.action == "prev" then @@ -400,6 +491,81 @@ function M.handle_command(opts) end end +function M.toggle_test(test_id) + local test_states = get_test_states() + test_states[test_id] = not test_states[test_id] + + local problem_key = get_current_problem_key() + if problem_key then + state.test_states[problem_key] = test_states + end + + test_problem() +end + +function M.run_single_test(test_id) + if not state.test_cases or not state.test_cases[test_id] then + logger.log("Test case not found", vim.log.levels.ERROR) + return + end + + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) + local contest_config = config.contests[state.platform] + + local single_test = { state.test_cases[test_id] } + local test_results = execute.run_individual_tests(ctx, single_test, contest_config, false) + + if test_results.compile_error then + logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR) + return + end + + local result = test_results.results[1] + if result then + logger.log(("Test %d: %s (%.1fms)"):format(test_id, result.status, result.time_ms)) + end +end + +function M.run_all_enabled_tests() + if not state.test_cases then + logger.log("No test cases available", vim.log.levels.ERROR) + return + end + + local test_states = get_test_states() + local enabled_tests = {} + + for i, test_case in ipairs(state.test_cases) do + if test_states[i] then + table.insert(enabled_tests, test_case) + end + end + + if #enabled_tests == 0 then + logger.log("No tests enabled", vim.log.levels.WARN) + return + end + + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) + local contest_config = config.contests[state.platform] + + local test_results = execute.run_individual_tests(ctx, enabled_tests, contest_config, false) + + if test_results.compile_error then + logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR) + return + end + + local passed = 0 + for _, result in ipairs(test_results.results) do + if result.status == "PASS" then + passed = passed + 1 + end + end + + logger.log(("Enabled tests: %d/%d passed"):format(passed, #enabled_tests)) +end + function M.setup(opts) opts = opts or {} user_config = opts @@ -411,6 +577,14 @@ function M.setup(opts) end end +function M.get_current_context() + return { + platform = state.platform, + contest_id = state.contest_id, + problem_id = state.problem_id, + } +end + function M.is_initialized() return true end diff --git a/plugin/cp.lua b/plugin/cp.lua index 690823a..fc57b96 100644 --- a/plugin/cp.lua +++ b/plugin/cp.lua @@ -4,7 +4,7 @@ end vim.g.loaded_cp = 1 local platforms = { "atcoder", "codeforces", "cses" } -local actions = { "run", "debug", "diff", "next", "prev" } +local actions = { "run", "debug", "test", "next", "prev" } vim.api.nvim_create_user_command("CP", function(opts) local cp = require("cp") @@ -22,10 +22,12 @@ end, { local candidates = {} vim.list_extend(candidates, platforms) vim.list_extend(candidates, actions) - if vim.g.cp and vim.g.cp.platform and vim.g.cp.contest_id then + local cp = require("cp") + local context = cp.get_current_context() + if context.platform and context.contest_id then local cache = require("cp.cache") cache.load() - local contest_data = cache.get_contest_data(vim.g.cp.platform, vim.g.cp.contest_id) + local contest_data = cache.get_contest_data(context.platform, context.contest_id) if contest_data and contest_data.problems then for _, problem in ipairs(contest_data.problems) do table.insert(candidates, problem.id) diff --git a/readme.md b/readme.md index f6b43ed..5adcaec 100644 --- a/readme.md +++ b/readme.md @@ -56,6 +56,10 @@ follows: 4. Submit the problem (on the remote!) +## Similar Projects + +- [competitest.nvim](https://github.com/xeluxee/competitest.nvim) + ## TODO - finer-tuned problem limits (i.e. per-problem codeforces time, memory)