diff --git a/doc/cp.txt b/doc/cp.txt index 6f1a680..59bd23a 100644 --- a/doc/cp.txt +++ b/doc/cp.txt @@ -92,6 +92,13 @@ example, with lazy.nvim (https://github.com/folke/lazy.nvim): vim.api.nvim_set_current_buf(input_buf) vim.cmd('wincmd h | wincmd h') end, + filename = function(contest, problem_id, problem_letter) + if contest == "atcoder" then + return problem_id:lower() .. (problem_letter or "") .. ".cpp" + else + return problem_id:lower() .. (problem_letter or "") .. ".cc" + end + end, }) end } @@ -122,6 +129,10 @@ tile Custom function to arrange windows function(source_buf, input_buf, output_buf) (default: nil, uses built-in layout) +filename Custom function to generate filenames + function(contest, problem_id, problem_letter) + (default: nil, uses problem_id + letter + ".cc") + WORKFLOW *cp-workflow* 1. Set up contest environment: > diff --git a/lua/cp/config.lua b/lua/cp/config.lua index b0e369d..59f409d 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -4,6 +4,7 @@ ---@field hooks table ---@field debug boolean ---@field tile? fun(source_buf: number, input_buf: number, output_buf: number) +---@field filename? fun(contest: string, problem_id: string, problem_letter?: string): string local M = {} @@ -33,6 +34,7 @@ M.defaults = { }, debug = false, tile = nil, + filename = nil, } ---@param base_config table @@ -62,6 +64,7 @@ function M.setup(user_config) hooks = { user_config.hooks, { "table", "nil" }, true }, debug = { user_config.debug, { "boolean", "nil" }, true }, tile = { user_config.tile, { "function", "nil" }, true }, + filename = { user_config.filename, { "function", "nil" }, true }, }) if user_config.hooks then @@ -84,4 +87,16 @@ function M.setup(user_config) return config end +local function default_filename(contest, problem_id, problem_letter) + local full_problem_id = problem_id:lower() + if contest == "atcoder" or contest == "codeforces" then + if problem_letter then + full_problem_id = full_problem_id .. problem_letter:lower() + end + end + return full_problem_id .. ".cc" +end + +M.default_filename = default_filename + return M diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index e37c296..3887a61 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -10,16 +10,6 @@ local signal_codes = { [139] = "SIGTERM", } -local function get_paths(problem_id) - return { - source = ("%s.cc"):format(problem_id), - binary = ("build/%s"):format(problem_id), - input = ("io/%s.in"):format(problem_id), - output = ("io/%s.out"):format(problem_id), - expected = ("io/%s.expected"):format(problem_id), - } -end - local function ensure_directories() vim.system({ "mkdir", "-p", "build", "io" }):wait() end @@ -85,27 +75,29 @@ local function format_output(exec_result, expected_file, is_debug) return table.concat(lines, "") end -function M.run_problem(problem_id, contest_config, is_debug) +---@param ctx ProblemContext +---@param contest_config table +---@param is_debug boolean +function M.run_problem(ctx, contest_config, is_debug) ensure_directories() - local paths = get_paths(problem_id) local flags = is_debug and contest_config.debug_flags or contest_config.compile_flags - local compile_result = compile_cpp(paths.source, paths.binary, flags) + local compile_result = compile_cpp(ctx.source_file, ctx.binary_file, flags) if compile_result.code ~= 0 then - vim.fn.writefile({ compile_result.stderr }, paths.output) + vim.fn.writefile({ compile_result.stderr }, ctx.output_file) return end local input_data = "" - if vim.fn.filereadable(paths.input) == 1 then - input_data = table.concat(vim.fn.readfile(paths.input), "\n") .. "\n" + if vim.fn.filereadable(ctx.input_file) == 1 then + input_data = table.concat(vim.fn.readfile(ctx.input_file), "\n") .. "\n" end - local exec_result = execute_binary(paths.binary, input_data, contest_config.timeout_ms) - local formatted_output = format_output(exec_result, paths.expected, is_debug) + local exec_result = execute_binary(ctx.binary_file, input_data, contest_config.timeout_ms) + local formatted_output = format_output(exec_result, ctx.expected_file, is_debug) - vim.fn.writefile(vim.split(formatted_output, "\n"), paths.output) + vim.fn.writefile(vim.split(formatted_output, "\n"), ctx.output_file) end return M diff --git a/lua/cp/init.lua b/lua/cp/init.lua index 73de220..4d26bb6 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -4,6 +4,7 @@ local execute = require("cp.execute") local scrape = require("cp.scrape") local window = require("cp.window") local logger = require("cp.log") +local problem = require("cp.problem") local M = {} local config = {} @@ -51,6 +52,8 @@ local function setup_problem(problem_id, problem_letter) vim.cmd("silent only") + local ctx = problem.create_context(vim.g.cp_contest, problem_id, problem_letter, config) + local scrape_result = scrape.scrape_problem(vim.g.cp_contest, problem_id, problem_letter) if not scrape_result.success then @@ -60,16 +63,7 @@ local function setup_problem(problem_id, problem_letter) logger.log(("scraped %d test case(s) for %s"):format(scrape_result.test_count, scrape_result.problem_id)) end - local full_problem_id = scrape_result.success and scrape_result.problem_id - or ( - (vim.g.cp_contest == "atcoder" or vim.g.cp_contest == "codeforces") - and problem_letter - and problem_id .. problem_letter:upper() - or problem_id - ) - local filename = full_problem_id .. ".cc" - - vim.cmd.e(filename) + vim.cmd.e(ctx.source_file) if vim.api.nvim_buf_get_lines(0, 0, -1, true)[1] == "" then local has_luasnip, luasnip = pcall(require, "luasnip") @@ -97,18 +91,14 @@ local function setup_problem(problem_id, problem_letter) vim.diagnostic.enable(false) - local base_fp = vim.fn.fnamemodify(filename, ":p:h") - local input_file = ("%s/io/%s.in"):format(base_fp, full_problem_id) - local output_file = ("%s/io/%s.out"):format(base_fp, full_problem_id) - local source_buf = vim.api.nvim_get_current_buf() - local input_buf = vim.fn.bufnr(input_file, true) - local output_buf = vim.fn.bufnr(output_file, true) + local input_buf = vim.fn.bufnr(ctx.input_file, true) + local output_buf = vim.fn.bufnr(ctx.output_file, true) local tile_fn = config.tile or window.default_tile tile_fn(source_buf, input_buf, output_buf) - logger.log(("switched to problem %s"):format(full_problem_id)) + logger.log(("switched to problem %s"):format(ctx.problem_name)) end local function get_current_problem() @@ -138,7 +128,8 @@ local function run_problem() local contest_config = config.contests[vim.g.cp_contest] vim.schedule(function() - execute.run_problem(problem_id, contest_config, false) + local ctx = problem.create_context(vim.g.cp_contest, problem_id, nil, config) + execute.run_problem(ctx, contest_config, false) vim.cmd.checktime() end) end @@ -161,7 +152,8 @@ local function debug_problem() local contest_config = config.contests[vim.g.cp_contest] vim.schedule(function() - execute.run_problem(problem_id, contest_config, true) + local ctx = problem.create_context(vim.g.cp_contest, problem_id, nil, config) + execute.run_problem(ctx, contest_config, true) vim.cmd.checktime() end) end @@ -179,22 +171,19 @@ local function diff_problem() return end - local base_fp = vim.fn.getcwd() - local output = ("%s/io/%s.out"):format(base_fp, problem_id) - local expected = ("%s/io/%s.expected"):format(base_fp, problem_id) - local input = ("%s/io/%s.in"):format(base_fp, problem_id) + local ctx = problem.create_context(vim.g.cp_contest, problem_id, nil, config) - if vim.fn.filereadable(expected) == 0 then - logger.log(("No expected output file found: %s"):format(expected), vim.log.levels.ERROR) + if vim.fn.filereadable(ctx.expected_file) == 0 then + logger.log(("No expected output file found: %s"):format(ctx.expected_file), vim.log.levels.ERROR) return end vim.g.cp_saved_layout = window.save_layout() - local result = vim.system({ "awk", "/^\\[[^]]*\\]:/ {exit} {print}", output }, { text = true }):wait() + local result = vim.system({ "awk", "/^\\[[^]]*\\]:/ {exit} {print}", ctx.output_file }, { text = true }):wait() local actual_output = result.stdout - window.setup_diff_layout(actual_output, expected, input) + window.setup_diff_layout(actual_output, ctx.expected_file, ctx.input_file) vim.g.cp_diff_mode = true logger.log("entered diff mode") diff --git a/lua/cp/problem.lua b/lua/cp/problem.lua new file mode 100644 index 0000000..bc41175 --- /dev/null +++ b/lua/cp/problem.lua @@ -0,0 +1,37 @@ +---@class ProblemContext +---@field contest string Contest name (e.g. "atcoder", "codeforces") +---@field problem_id string Raw problem ID (e.g. "abc123", "1933") +---@field problem_letter? string Problem letter for AtCoder/Codeforces (e.g. "a", "b") +---@field source_file string Source filename (e.g. "abc123a.cpp") +---@field binary_file string Binary output path (e.g. "build/abc123a") +---@field input_file string Input test file path (e.g. "io/abc123a.in") +---@field output_file string Output file path (e.g. "io/abc123a.out") +---@field expected_file string Expected output path (e.g. "io/abc123a.expected") +---@field problem_name string Canonical problem identifier (e.g. "abc123a") + +local M = {} + +---@param contest string +---@param problem_id string +---@param problem_letter? string +---@param config cp.Config +---@return ProblemContext +function M.create_context(contest, problem_id, problem_letter, config) + local filename_fn = config.filename or require("cp.config").default_filename + local source_file = filename_fn(contest, problem_id, problem_letter) + local base_name = vim.fn.fnamemodify(source_file, ":t:r") + + return { + contest = contest, + problem_id = problem_id, + problem_letter = problem_letter, + source_file = source_file, + binary_file = ("build/%s"):format(base_name), + input_file = ("io/%s.in"):format(base_name), + output_file = ("io/%s.out"):format(base_name), + expected_file = ("io/%s.expected"):format(base_name), + problem_name = base_name, + } +end + +return M diff --git a/lua/cp/window.lua b/lua/cp/window.lua index 1b80f45..a62465d 100644 --- a/lua/cp/window.lua +++ b/lua/cp/window.lua @@ -41,7 +41,7 @@ function M.restore_layout(state, tile_fn) for win, win_state in pairs(state.windows) do if vim.api.nvim_win_is_valid(win) and vim.api.nvim_buf_is_valid(win_state.bufnr) then local bufname = vim.api.nvim_buf_get_name(win_state.bufnr) - if bufname:match("%.cc$") then + if not bufname:match("%.in$") and not bufname:match("%.out$") and not bufname:match("%.expected$") then problem_id = vim.fn.fnamemodify(bufname, ":t:r") break end @@ -55,7 +55,12 @@ function M.restore_layout(state, tile_fn) local base_fp = vim.fn.getcwd() local input_file = ("%s/io/%s.in"):format(base_fp, problem_id) local output_file = ("%s/io/%s.out"):format(base_fp, problem_id) - local source_file = problem_id .. ".cc" + local source_files = vim.fn.glob(problem_id .. ".*") + local source_file = source_files ~= "" and vim.split(source_files, "\n")[1] or (problem_id .. ".cc") + + if vim.fn.filereadable(source_file) == 0 then + return + end vim.cmd.edit(source_file) local source_buf = vim.api.nvim_get_current_buf()