Merge pull request #32 from barrett-ruth/feat/testmode

onslaught of features
This commit is contained in:
Barrett Ruth 2025-09-15 16:59:28 +02:00 committed by GitHub
commit e806b23020
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 294 additions and 108 deletions

6
after/ftplugin/cpin.lua Normal file
View file

@ -0,0 +1,6 @@
vim.opt_local.number = false
vim.opt_local.relativenumber = false
vim.opt_local.statuscolumn = ""
vim.opt_local.signcolumn = "no"
vim.opt_local.wrap = true
vim.opt_local.linebreak = true

7
after/ftplugin/cpout.lua Normal file
View file

@ -0,0 +1,7 @@
vim.opt_local.number = false
vim.opt_local.relativenumber = false
vim.opt_local.statuscolumn = ""
vim.opt_local.signcolumn = "no"
vim.opt_local.wrap = true
vim.opt_local.linebreak = true
vim.opt_local.modifiable = false

View file

@ -9,12 +9,13 @@ cp.nvim is a competitive programming plugin that automates problem setup,
compilation, and testing workflow for online judges.
Supported platforms: AtCoder, Codeforces, CSES
Supported languages: C++, Python
REQUIREMENTS *cp-requirements*
- Neovim 0.10.0+
- uv package manager (https://docs.astral.sh/uv/)
- C++ compiler (g++/clang++)
- Language runtime/compiler (g++, python3)
Optional:
- LuaSnip for template expansion (https://github.com/L3MON4D3/LuaSnip)

6
ftdetect/cp.lua Normal file
View file

@ -0,0 +1,6 @@
vim.filetype.add({
extension = {
cpin = "cpin",
cpout = "cpout",
},
})

View file

@ -87,4 +87,26 @@ function M.clear_contest_data(platform, contest_id)
end
end
function M.get_test_cases(platform, contest_id, problem_id)
local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id
if not cache_data[platform] or not cache_data[platform][problem_key] then
return nil
end
return cache_data[platform][problem_key].test_cases
end
function M.set_test_cases(platform, contest_id, problem_id, test_cases)
local problem_key = problem_id and (contest_id .. "_" .. problem_id) or contest_id
if not cache_data[platform] then
cache_data[platform] = {}
end
if not cache_data[platform][problem_key] then
cache_data[platform][problem_key] = {}
end
cache_data[platform][problem_key].test_cases = test_cases
cache_data[platform][problem_key].test_cases_cached_at = os.time()
M.save()
end
return M

View file

@ -12,19 +12,48 @@ local M = {}
M.defaults = {
contests = {
default = {
cpp_version = 20,
compile_flags = { "-O2", "-DLOCAL", "-Wall", "-Wextra" },
debug_flags = { "-g3", "-fsanitize=address,undefined", "-DLOCAL" },
cpp = {
compile = {
"g++",
"-std=c++{version}",
"-O2",
"-DLOCAL",
"-Wall",
"-Wextra",
"{source}",
"-o",
"{binary}",
},
run = { "{binary}" },
debug = {
"g++",
"-std=c++{version}",
"-g3",
"-fsanitize=address,undefined",
"-DLOCAL",
"{source}",
"-o",
"{binary}",
},
executable = nil,
version = 20,
},
python = {
compile = nil,
run = { "{source}" },
debug = { "{source}" },
executable = "python3",
},
timeout_ms = 2000,
},
atcoder = {
cpp_version = 23,
cpp = { version = 23 },
},
codeforces = {
cpp_version = 23,
cpp = { version = 23 },
},
cses = {
cpp_version = 20,
cpp = { version = 20 },
},
},
snippets = {},
@ -42,11 +71,6 @@ M.defaults = {
---@return table
local function extend_contest_config(base_config, contest_config)
local result = vim.tbl_deep_extend("force", base_config, contest_config)
local std_flag = ("-std=c++%d"):format(result.cpp_version)
result.compile_flags = vim.list_extend({ std_flag }, result.compile_flags)
result.debug_flags = vim.list_extend({ std_flag }, result.debug_flags)
return result
end

View file

@ -1,4 +1,41 @@
local M = {}
local logger = require("cp.log")
local filetype_to_language = {
cpp = "cpp",
cxx = "cpp",
cc = "cpp",
c = "cpp",
py = "python",
py3 = "python",
}
local function get_language_from_file(source_file)
local extension = vim.fn.fnamemodify(source_file, ":e")
local language = filetype_to_language[extension] or "cpp"
logger.log(("detected language: %s (extension: %s)"):format(language, extension))
return language
end
local function substitute_template(cmd_template, substitutions)
local result = {}
for _, arg in ipairs(cmd_template) do
local substituted = arg
for key, value in pairs(substitutions) do
substituted = substituted:gsub("{" .. key .. "}", value)
end
table.insert(result, substituted)
end
return result
end
local function build_command(cmd_template, executable, substitutions)
local cmd = substitute_template(cmd_template, substitutions)
if executable then
table.insert(cmd, 1, executable)
end
return cmd
end
local signal_codes = {
[128] = "SIGILL",
@ -22,15 +59,34 @@ local function ensure_directories()
vim.system({ "mkdir", "-p", "build", "io" }):wait()
end
local function compile_cpp(source_path, binary_path, flags)
local compile_cmd = { "g++", unpack(flags), source_path, "-o", binary_path }
return vim.system(compile_cmd, { text = true }):wait()
local function compile_generic(language_config, substitutions)
if not language_config.compile then
logger.log("no compilation step required")
return { code = 0, stderr = "" }
end
local compile_cmd = substitute_template(language_config.compile, substitutions)
logger.log(("compiling: %s"):format(table.concat(compile_cmd, " ")))
local start_time = vim.loop.hrtime()
local result = vim.system(compile_cmd, { text = true }):wait()
local compile_time = (vim.loop.hrtime() - start_time) / 1000000
if result.code == 0 then
logger.log(("compilation successful (%.1fms)"):format(compile_time))
else
logger.log(("compilation failed (%.1fms): %s"):format(compile_time, result.stderr), vim.log.levels.WARN)
end
return result
end
local function execute_binary(binary_path, input_data, timeout_ms)
local function execute_command(cmd, input_data, timeout_ms)
logger.log(("executing: %s"):format(table.concat(cmd, " ")))
local start_time = vim.loop.hrtime()
local result = vim.system({ binary_path }, {
local result = vim.system(cmd, {
stdin = input_data,
timeout = timeout_ms,
text = true,
@ -41,6 +97,14 @@ local function execute_binary(binary_path, input_data, timeout_ms)
local actual_code = result.code or 0
if result.code == 124 then
logger.log(("execution timed out after %.1fms"):format(execution_time), vim.log.levels.WARN)
elseif actual_code ~= 0 then
logger.log(("execution failed (exit code %d, %.1fms)"):format(actual_code, execution_time), vim.log.levels.WARN)
else
logger.log(("execution successful (%.1fms)"):format(execution_time))
end
return {
stdout = result.stdout or "",
stderr = result.stderr or "",
@ -96,20 +160,36 @@ end
function M.run_problem(ctx, contest_config, is_debug)
ensure_directories()
local flags = is_debug and contest_config.debug_flags or contest_config.compile_flags
local language = get_language_from_file(ctx.source_file)
local language_config = contest_config[language]
local compile_result = compile_cpp(ctx.source_file, ctx.binary_file, flags)
if compile_result.code ~= 0 then
vim.fn.writefile({ compile_result.stderr }, ctx.output_file)
if not language_config then
vim.fn.writefile({ "Error: No configuration for language: " .. language }, ctx.output_file)
return
end
local substitutions = {
source = ctx.source_file,
binary = ctx.binary_file,
version = tostring(language_config.version or ""),
}
local compile_cmd = is_debug and language_config.debug or language_config.compile
if compile_cmd then
local compile_result = compile_generic(language_config, substitutions)
if compile_result.code ~= 0 then
vim.fn.writefile({ compile_result.stderr }, ctx.output_file)
return
end
end
local input_data = ""
if vim.fn.filereadable(ctx.input_file) == 1 then
input_data = table.concat(vim.fn.readfile(ctx.input_file), "\n") .. "\n"
end
local exec_result = execute_binary(ctx.binary_file, input_data, contest_config.timeout_ms)
local run_cmd = build_command(language_config.run, language_config.executable, substitutions)
local exec_result = execute_command(run_cmd, input_data, contest_config.timeout_ms)
local formatted_output = format_output(exec_result, ctx.expected_file, is_debug)
local output_buf = vim.fn.bufnr(ctx.output_file)

View file

@ -62,12 +62,14 @@ end
local function check_config()
vim.health.ok("Plugin ready")
if vim.g.cp and vim.g.cp.platform then
local info = vim.g.cp.platform
if vim.g.cp.contest_id then
info = info .. " " .. vim.g.cp.contest_id
if vim.g.cp.problem_id then
info = info .. " " .. vim.g.cp.problem_id
local cp = require("cp")
local context = cp.get_current_context()
if context.platform then
local info = context.platform
if context.contest_id then
info = info .. " " .. context.contest_id
if context.problem_id then
info = info .. " " .. context.problem_id
end
end
vim.health.info("Current context: " .. info)

View file

@ -14,7 +14,6 @@ if not vim.fn.has("nvim-0.10.0") then
return {}
end
vim.g.cp = vim.g.cp or {}
local user_config = {}
local config = config_module.setup(user_config)
logger.set_config(config)
@ -28,6 +27,8 @@ local state = {
saved_layout = nil,
saved_session = nil,
temp_output = nil,
test_cases = nil,
test_states = {},
}
local platforms = { "atcoder", "codeforces", "cses" }
@ -53,6 +54,9 @@ local function setup_problem(contest_id, problem_id)
return
end
local problem_name = state.platform == "cses" and contest_id or (contest_id .. (problem_id or ""))
logger.log(("setting up problem: %s"):format(problem_name))
local metadata_result = scrape.scrape_contest_metadata(state.platform, contest_id)
if not metadata_result.success then
logger.log(
@ -79,6 +83,11 @@ local function setup_problem(contest_id, problem_id)
state.contest_id = contest_id
state.problem_id = problem_id
local cached_test_cases = cache.get_test_cases(state.platform, contest_id, problem_id)
if cached_test_cases then
state.test_cases = cached_test_cases
end
local ctx = problem.create_context(state.platform, contest_id, problem_id, config)
local scrape_result = scrape.scrape_problem(ctx)
@ -86,9 +95,15 @@ local function setup_problem(contest_id, problem_id)
if not scrape_result.success then
logger.log("scraping failed: " .. (scrape_result.error or "unknown error"), vim.log.levels.WARN)
logger.log("you can manually add test cases to io/ directory", vim.log.levels.INFO)
state.test_cases = nil
else
local test_count = scrape_result.test_count or 0
logger.log(("scraped %d test case(s) for %s"):format(test_count, scrape_result.problem_id))
state.test_cases = scrape_result.test_cases
if scrape_result.test_cases then
cache.set_test_cases(state.platform, contest_id, problem_id, scrape_result.test_cases)
end
end
vim.cmd.e(ctx.source_file)
@ -144,6 +159,8 @@ local function run_problem()
return
end
logger.log(("running problem: %s"):format(problem_id))
if config.hooks and config.hooks.before_run then
config.hooks.before_run(problem_id)
end
@ -188,34 +205,44 @@ end
local function diff_problem()
if state.diff_mode then
local tile_fn = config.tile or window.default_tile
window.restore_layout(state.saved_layout, tile_fn)
vim.cmd.diffoff()
if state.saved_session then
vim.fn.delete(state.saved_session)
state.saved_session = nil
end
if state.temp_output then
vim.fn.delete(state.temp_output)
state.temp_output = nil
end
state.diff_mode = false
state.saved_layout = nil
logger.log("exited diff mode")
else
local problem_id = get_current_problem()
if not problem_id then
return
end
local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config)
if vim.fn.filereadable(ctx.expected_file) == 0 then
logger.log(("No expected output file found: %s"):format(ctx.expected_file), vim.log.levels.ERROR)
return
end
state.saved_layout = window.save_layout()
local result = vim.system({ "awk", "/^\\[[^]]*\\]:/ {exit} {print}", ctx.output_file }, { text = true }):wait()
local actual_output = result.stdout
window.setup_diff_layout(actual_output, ctx.expected_file, ctx.input_file)
state.diff_mode = true
logger.log("entered diff mode")
return
end
local problem_id = get_current_problem()
if not problem_id then
return
end
local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config)
if vim.fn.filereadable(ctx.expected_file) == 0 then
logger.log("no expected output file found", vim.log.levels.WARN)
return
end
if vim.fn.filereadable(ctx.output_file) == 0 then
logger.log("no output file found. run the problem first", vim.log.levels.WARN)
return
end
state.saved_session = vim.fn.tempname()
vim.cmd(("mksession! %s"):format(state.saved_session))
vim.cmd("silent only")
vim.cmd(("edit %s"):format(ctx.expected_file))
vim.cmd.diffthis()
vim.cmd(("vertical diffsplit %s"):format(ctx.output_file))
state.diff_mode = true
end
---@param delta number 1 for next, -1 for prev
@ -411,6 +438,14 @@ function M.setup(opts)
end
end
function M.get_current_context()
return {
platform = state.platform,
contest_id = state.contest_id,
problem_id = state.problem_id,
}
end
function M.is_initialized()
return true
end

View file

@ -27,8 +27,8 @@ function M.create_context(contest, contest_id, problem_id, config)
problem_id = problem_id,
source_file = source_file,
binary_file = ("build/%s.run"):format(base_name),
input_file = ("io/%s.in"):format(base_name),
output_file = ("io/%s.out"):format(base_name),
input_file = ("io/%s.cpin"):format(base_name),
output_file = ("io/%s.cpout"):format(base_name),
expected_file = ("io/%s.expected"):format(base_name),
problem_name = base_name,
}

View file

@ -74,9 +74,9 @@ function M.scrape_contest_metadata(platform, contest_id)
local args
if platform == "cses" then
args = { "uv", "run", scraper_path, "metadata" }
args = { "uv", "run", "--directory", plugin_path, scraper_path, "metadata" }
else
args = { "uv", "run", scraper_path, "metadata", contest_id }
args = { "uv", "run", "--directory", plugin_path, scraper_path, "metadata", contest_id }
end
local result = vim.system(args, {
@ -119,7 +119,7 @@ function M.scrape_contest_metadata(platform, contest_id)
end
---@param ctx ProblemContext
---@return {success: boolean, problem_id: string, test_count?: number, url?: string, error?: string}
---@return {success: boolean, problem_id: string, test_count?: number, test_cases?: table[], url?: string, error?: string}
function M.scrape_problem(ctx)
ensure_io_directory()
@ -152,9 +152,9 @@ function M.scrape_problem(ctx)
local args
if ctx.contest == "cses" then
args = { "uv", "run", scraper_path, "tests", ctx.contest_id }
args = { "uv", "run", "--directory", plugin_path, scraper_path, "tests", ctx.contest_id }
else
args = { "uv", "run", scraper_path, "tests", ctx.contest_id, ctx.problem_id }
args = { "uv", "run", "--directory", plugin_path, scraper_path, "tests", ctx.contest_id, ctx.problem_id }
end
local result = vim.system(args, {
@ -185,30 +185,18 @@ function M.scrape_problem(ctx)
end
if data.test_cases and #data.test_cases > 0 then
local all_inputs = {}
local all_outputs = {}
local combined_input = data.test_cases[1].input:gsub("\r", "")
local combined_output = data.test_cases[1].output:gsub("\r", "")
for _, test_case in ipairs(data.test_cases) do
local input_lines = vim.split(test_case.input:gsub("\r", ""):gsub("\n+$", ""), "\n")
local output_lines = vim.split(test_case.output:gsub("\r", ""):gsub("\n+$", ""), "\n")
for _, line in ipairs(input_lines) do
table.insert(all_inputs, line)
end
for _, line in ipairs(output_lines) do
table.insert(all_outputs, line)
end
end
vim.fn.writefile(all_inputs, ctx.input_file)
vim.fn.writefile(all_outputs, ctx.expected_file)
vim.fn.writefile(vim.split(combined_input, "\n", true), ctx.input_file)
vim.fn.writefile(vim.split(combined_output, "\n", true), ctx.expected_file)
end
return {
success = true,
problem_id = ctx.problem_name,
test_count = data.test_cases and #data.test_cases or 0,
test_cases = data.test_cases,
url = data.url,
}
end

View file

@ -22,10 +22,12 @@ end, {
local candidates = {}
vim.list_extend(candidates, platforms)
vim.list_extend(candidates, actions)
if vim.g.cp and vim.g.cp.platform and vim.g.cp.contest_id then
local cp = require("cp")
local context = cp.get_current_context()
if context.platform and context.contest_id then
local cache = require("cp.cache")
cache.load()
local contest_data = cache.get_contest_data(vim.g.cp.platform, vim.g.cp.contest_id)
local contest_data = cache.get_contest_data(context.platform, context.contest_id)
if contest_data and contest_data.problems then
for _, problem in ipairs(contest_data.problems) do
table.insert(candidates, problem.id)

View file

@ -9,9 +9,10 @@ https://private-user-images.githubusercontent.com/62671086/489116291-391976d1-c2
## Features
- Support for multiple online judges ([AtCoder](https://atcoder.jp/), [Codeforces](https://codeforces.com/), [CSES](https://cses.fi))
- Multi-language support (C++, Python)
- Automatic problem scraping and test case management
- Integrated build, run, and debug commands
- Diff mode for comparing output with expected results
- Enhanced test viewer with individual test case management
- LuaSnip integration for contest-specific snippets
## Requirements
@ -56,9 +57,14 @@ follows:
4. Submit the problem (on the remote!)
## Similar Projects
- [competitest.nvim](https://github.com/xeluxee/competitest.nvim)
## TODO
- finer-tuned problem limits (i.e. per-problem codeforces time, memory)
- better highlighting
- test case management
- USACO support
- new video with functionality, notify discord members

View file

@ -19,28 +19,35 @@ def scrape(url: str) -> list[tuple[str, str]]:
input_sections = soup.find_all("div", class_="input")
output_sections = soup.find_all("div", class_="output")
for inp_section, out_section in zip(input_sections, output_sections):
all_inputs = []
all_outputs = []
for inp_section in input_sections:
inp_pre = inp_section.find("pre")
if inp_pre:
divs = inp_pre.find_all("div")
if divs:
lines = [div.get_text().strip() for div in divs]
text = "\n".join(lines)
else:
text = inp_pre.get_text().replace("\r", "")
all_inputs.append(text)
for out_section in output_sections:
out_pre = out_section.find("pre")
if out_pre:
divs = out_pre.find_all("div")
if divs:
lines = [div.get_text().strip() for div in divs]
text = "\n".join(lines)
else:
text = out_pre.get_text().replace("\r", "")
all_outputs.append(text)
if inp_pre and out_pre:
input_lines: list[str] = []
output_lines: list[str] = []
input_text_raw = inp_pre.get_text().strip().replace("\r", "")
input_lines = [
line.strip() for line in input_text_raw.split("\n") if line.strip()
]
output_text_raw = out_pre.get_text().strip().replace("\r", "")
output_lines = [
line.strip() for line in output_text_raw.split("\n") if line.strip()
]
if input_lines and output_lines:
input_text = "\n".join(input_lines)
output_text = "\n".join(output_lines)
tests.append((input_text, output_text))
if all_inputs and all_outputs:
combined_input = "\n".join(all_inputs)
combined_output = "\n".join(all_outputs)
tests.append((combined_input, combined_output))
return tests
@ -112,7 +119,7 @@ def main() -> None:
if mode == "metadata":
if len(sys.argv) != 3:
result = {
result: dict[str, str | bool] = {
"success": False,
"error": "Usage: codeforces.py metadata <contest_id>",
}
@ -123,14 +130,14 @@ def main() -> None:
problems: list[dict[str, str]] = scrape_contest_problems(contest_id)
if not problems:
result = {
result: dict[str, str | bool] = {
"success": False,
"error": f"No problems found for contest {contest_id}",
}
print(json.dumps(result))
sys.exit(1)
result = {
result: dict[str, str | bool | list] = {
"success": True,
"contest_id": contest_id,
"problems": problems,
@ -139,7 +146,7 @@ def main() -> None:
elif mode == "tests":
if len(sys.argv) != 4:
result = {
result: dict[str, str | bool] = {
"success": False,
"error": "Usage: codeforces.py tests <contest_id> <problem_letter>",
}
@ -154,7 +161,7 @@ def main() -> None:
tests: list[tuple[str, str]] = scrape_sample_tests(url)
if not tests:
result = {
result: dict[str, str | bool] = {
"success": False,
"error": f"No tests found for {contest_id} {problem_letter}",
"problem_id": problem_id,
@ -167,7 +174,7 @@ def main() -> None:
for input_data, output_data in tests:
test_cases.append({"input": input_data, "output": output_data})
result = {
result: dict[str, str | bool | list] = {
"success": True,
"problem_id": problem_id,
"url": url,
@ -176,7 +183,7 @@ def main() -> None:
print(json.dumps(result))
else:
result = {
result: dict[str, str | bool] = {
"success": False,
"error": f"Unknown mode: {mode}. Use 'metadata' or 'tests'",
}