diff --git a/after/ftplugin/cptest.lua b/after/ftplugin/cptest.lua index 166db14..35831f4 100644 --- a/after/ftplugin/cptest.lua +++ b/after/ftplugin/cptest.lua @@ -2,6 +2,6 @@ vim.opt_local.number = false vim.opt_local.relativenumber = false vim.opt_local.statuscolumn = "" vim.opt_local.signcolumn = "no" -vim.opt_local.foldcolumn = "0" vim.opt_local.wrap = true vim.opt_local.linebreak = true +vim.opt_local.foldcolumn = "0" diff --git a/doc/cp.txt b/doc/cp.txt index 3cbdce5..47a64ce 100644 --- a/doc/cp.txt +++ b/doc/cp.txt @@ -49,17 +49,10 @@ Setup Commands ~ Action Commands ~ -:CP run Compile and run current problem with test input. - Shows execution time and output comparison. - Requires contest setup first. - -:CP debug Compile with debug flags and run current problem. - Includes sanitizers and debug symbols. - Requires contest setup first. - -:CP test Toggle test panel for individual test case +:CP test [--debug] Toggle test panel for individual test case debugging. Shows per-test results with three-pane layout for easy Expected/Actual comparison. + Use --debug flag to compile with debug flags Requires contest setup first. Navigation Commands ~ @@ -123,7 +116,6 @@ Optional configuration with lazy.nvim: > end, }, snippets = { ... }, -- LuaSnip snippets - tile = function(source_buf, input_buf, output_buf) ... end, filename = function(contest, contest_id, problem_id, config, language) ... end, } } @@ -139,8 +131,6 @@ Optional configuration with lazy.nvim: > during operation. • {scrapers} (`table`) Per-platform scraper control. Default enables all platforms. - • {tile}? (`function`) Custom window arrangement function. - `function(source_buf, input_buf, output_buf)` • {filename}? (`function`) Custom filename generation function. `function(contest, contest_id, problem_id, config, language)` Should return full filename with extension. @@ -170,9 +160,7 @@ Optional configuration with lazy.nvim: > *cp.Hooks* Fields: ~ - • {before_run}? 
(`function`) Called before `:CP run`. - `function(ctx: ProblemContext)` - • {before_debug}? (`function`) Called before `:CP debug`. + • {before_debug}? (`function`) Called before debug compilation. `function(ctx: ProblemContext)` • {setup_code}? (`function`) Called after source file is opened. Used to configure buffer settings. @@ -247,44 +235,38 @@ Example: Setting up and solving AtCoder contest ABC324 < This creates a.cc and scrapes test cases 4. Code your solution, then test: > - :CP run -< -5. If test fails, debug individual test cases: > :CP test < Navigate with j/k, run specific tests with Exit test panel with q or :CP test when done -6. If needed, compile with debug flags: > - :CP debug +5. If needed, debug with sanitizers: > + :CP test --debug < -7. Move to next problem: > +6. Move to next problem: > :CP next < This automatically sets up problem B -7. Continue solving problems with :CP next/:CP prev navigation -8. Submit solutions on AtCoder website +7. Continue solving problems with :CP next/:CP prev navigation +8. Submit solutions on AtCoder website Example: Quick setup for single Codeforces problem > :CP codeforces 1933 a " One command setup - :CP run " Test immediately + :CP test " Test immediately < TEST PANEL *cp-test* The test panel provides individual test case debugging with a three-pane layout showing test list, expected output, and actual output side-by-side. -Currently supported for AtCoder and CSES. - -Note: Codeforces is not supported due to the ambiguity of identifying -individual test case output. See https://codeforces.com/blog/entry/138406 -for ongoing efforts to resolve this. Activation ~ *:CP-test* -:CP test Toggle test panel on/off. When activated, +:CP test [--debug] Toggle test panel on/off. When activated, replaces current layout with test interface. Automatically compiles and runs all tests. - Toggle again to restore original layout. + Use --debug flag to compile with debug symbols + and sanitizers. 
Toggle again to restore original + layout. Interface ~ @@ -293,8 +275,6 @@ The test panel uses a three-pane layout for easy comparison: > ┌─ Test List ─────────────────────────────────────────────────┐ │ 1. PASS 12ms │ │> 2. FAIL 45ms │ - │ 3. 8ms │ - │ 4. │ │ │ │ ── Input ── │ │ 5 3 │ @@ -317,7 +297,7 @@ q Exit test panel (restore layout) Execution Details ~ Test cases are executed individually using the same compilation and -execution pipeline as |:CP-run|, but with isolated input/output for +execution pipeline, but with isolated input/output for precise failure analysis. All tests are automatically run when the panel opens. diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index a379415..3db7845 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -15,7 +15,8 @@ ---@class CachedTestCase ---@field index? number ---@field input string ----@field output string +---@field expected? string +---@field output? string local M = {} diff --git a/lua/cp/config.lua b/lua/cp/config.lua index d220207..12a322e 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -37,7 +37,6 @@ ---@field hooks Hooks ---@field debug boolean ---@field scrapers table ----@field tile? fun(source_buf: number, input_buf: number, output_buf: number) ---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string ---@class cp.UserConfig @@ -46,7 +45,6 @@ ---@field hooks? Hooks ---@field debug? boolean ---@field scrapers? table ----@field tile? fun(source_buf: number, input_buf: number, output_buf: number) ---@field filename? 
fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string local M = {} @@ -62,12 +60,7 @@ M.defaults = { setup_code = nil, }, debug = false, - scrapers = vim.iter(constants.PLATFORMS) - :map(function(platform) - return platform, true - end) - :totable(), - tile = nil, + scrapers = constants.PLATFORMS, filename = nil, } @@ -85,7 +78,6 @@ function M.setup(user_config) hooks = { user_config.hooks, { "table", "nil" }, true }, debug = { user_config.debug, { "boolean", "nil" }, true }, scrapers = { user_config.scrapers, { "table", "nil" }, true }, - tile = { user_config.tile, { "function", "nil" }, true }, filename = { user_config.filename, { "function", "nil" }, true }, }) diff --git a/lua/cp/constants.lua b/lua/cp/constants.lua index 59693a8..7d5b155 100644 --- a/lua/cp/constants.lua +++ b/lua/cp/constants.lua @@ -1,7 +1,7 @@ local M = {} M.PLATFORMS = { "atcoder", "codeforces", "cses" } -M.ACTIONS = { "run", "debug", "test", "next", "prev" } +M.ACTIONS = { "test", "next", "prev" } M.CPP = "cpp" M.PYTHON = "python" diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index 9398d2c..58e17e5 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -193,8 +193,9 @@ end ---@param ctx ProblemContext ---@param contest_config ContestConfig +---@param is_debug? 
boolean ---@return boolean success -function M.compile_problem(ctx, contest_config) +function M.compile_problem(ctx, contest_config, is_debug) vim.validate({ ctx = { ctx, "table" }, contest_config = { contest_config, "table" }, @@ -214,13 +215,15 @@ function M.compile_problem(ctx, contest_config) version = tostring(language_config.version), } - if language_config.compile then + local compile_cmd = (is_debug and language_config.debug) and language_config.debug or language_config.compile + if compile_cmd then + language_config = vim.tbl_extend("force", language_config, { compile = compile_cmd }) local compile_result = M.compile_generic(language_config, substitutions) if compile_result.code ~= 0 then logger.log("compilation failed: " .. (compile_result.stderr or "unknown error"), vim.log.levels.ERROR) return false end - logger.log("compilation successful") + logger.log(("compilation successful (%s)"):format(is_debug and "debug mode" or "test mode")) end return true diff --git a/lua/cp/init.lua b/lua/cp/init.lua index 07f8d3c..68c9e8a 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -2,9 +2,7 @@ local M = {} local config_module = require("cp.config") local snippets = require("cp.snippets") -local execute = require("cp.execute") local scrape = require("cp.scrape") -local window = require("cp.window") local logger = require("cp.log") local problem = require("cp.problem") local cache = require("cp.cache") @@ -60,7 +58,7 @@ local function setup_problem(contest_id, problem_id, language) local ctx = problem.create_context(state.platform, contest_id, problem_id, config, language) - if config.scrapers[state.platform] then + if vim.tbl_contains(config.scrapers, state.platform) then local metadata_result = scrape.scrape_contest_metadata(state.platform, contest_id) if not metadata_result.success then logger.log( @@ -75,7 +73,7 @@ local function setup_problem(contest_id, problem_id, language) state.test_cases = cached_test_cases end - if config.scrapers[state.platform] then + if vim.tbl_contains(config.scrapers, 
state.platform) then local scrape_result = scrape.scrape_problem(ctx) if not scrape_result.success then @@ -133,13 +131,6 @@ local function setup_problem(contest_id, problem_id, language) config.hooks.setup_code(ctx) end - local src_buf = vim.api.nvim_get_current_buf() - local input_buf = vim.fn.bufnr(ctx.input_file, true) - local output_buf = vim.fn.bufnr(ctx.output_file, true) - - local tile_fn = config.tile or window.default_tile - tile_fn(src_buf, input_buf, output_buf) - logger.log(("switched to problem %s"):format(ctx.problem_name)) end @@ -152,60 +143,7 @@ local function get_current_problem() return filename end -local function run_problem() - local problem_id = get_current_problem() - if not problem_id then - return - end - - logger.log(("running problem: %s"):format(problem_id)) - - if not state.platform then - logger.log( - "No contest configured. Use :CP to set up first.", - vim.log.levels.ERROR - ) - return - end - - local contest_config = config.contests[state.platform] - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) - - if config.hooks and config.hooks.before_run then - config.hooks.before_run(ctx) - end - - vim.schedule(function() - execute.run_problem(ctx, contest_config, false) - vim.cmd.checktime() - end) -end - -local function debug_problem() - local problem_id = get_current_problem() - if not problem_id then - return - end - - if not state.platform then - logger.log("no platform set", vim.log.levels.ERROR) - return - end - - local contest_config = config.contests[state.platform] - local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) - - if config.hooks and config.hooks.before_debug then - config.hooks.before_debug(ctx) - end - - vim.schedule(function() - execute.run_problem(ctx, contest_config, true) - vim.cmd.checktime() - end) -end - -local function toggle_test_panel() +local function toggle_test_panel(is_debug) if state.test_panel_active then if 
state.saved_session then vim.cmd(("source %s"):format(state.saved_session)) @@ -225,11 +163,6 @@ local function toggle_test_panel() return end - if state.platform == "codeforces" then - logger.log("test panel not yet supported for codeforces", vim.log.levels.ERROR) - return - end - local problem_id = get_current_problem() if not problem_id then return @@ -260,19 +193,20 @@ local function toggle_test_panel() vim.api.nvim_win_set_buf(main_win, tab_buf) vim.api.nvim_set_option_value("filetype", "cptest", { buf = tab_buf }) - vim.cmd("split") - local content_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(content_win, actual_buf) + vim.cmd.split() + vim.api.nvim_win_set_buf(0, actual_buf) vim.api.nvim_set_option_value("filetype", "cptest", { buf = actual_buf }) - vim.cmd("vsplit") - local expected_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(expected_win, expected_buf) + vim.cmd.vsplit() + vim.api.nvim_win_set_buf(0, expected_buf) vim.api.nvim_set_option_value("filetype", "cptest", { buf = expected_buf }) + local expected_win = vim.fn.bufwinid(expected_buf) + local actual_win = vim.fn.bufwinid(actual_buf) + local test_windows = { tab_win = main_win, - actual_win = content_win, + actual_win = actual_win, expected_win = expected_win, } local test_buffers = { @@ -347,14 +281,14 @@ local function toggle_test_panel() return end - local expected_lines = {} - local expected_text = current_test.expected - for _, line in ipairs(vim.split(expected_text, "\n", { plain = true, trimempty = true })) do - table.insert(expected_lines, line) - end + local expected_lines = vim.split(current_test.expected, "\n", { plain = true, trimempty = true }) vim.api.nvim_buf_set_lines(test_buffers.expected_buf, 0, -1, false, expected_lines) + + if vim.fn.has("nvim-0.8.0") == 1 then + vim.api.nvim_set_option_value("winbar", "Expected", { win = test_windows.expected_win }) + end end local function update_actual_pane() @@ -366,27 +300,32 @@ local function toggle_test_panel() 
end local actual_lines = {} + local enable_diff = false if current_test.actual then - for _, line in ipairs(vim.split(current_test.actual, "\n", { plain = true, trimempty = true })) do - table.insert(actual_lines, line) - end - - if current_test.status == "fail" then - vim.api.nvim_set_option_value("diff", true, { win = test_windows.expected_win }) - vim.api.nvim_set_option_value("diff", true, { win = test_windows.actual_win }) - else - vim.api.nvim_set_option_value("diff", false, { win = test_windows.expected_win }) - vim.api.nvim_set_option_value("diff", false, { win = test_windows.actual_win }) - end + actual_lines = vim.split(current_test.actual, "\n", { plain = true, trimempty = true }) + enable_diff = current_test.status == "fail" else - table.insert(actual_lines, "(not run yet)") - - vim.api.nvim_set_option_value("diff", false, { win = test_windows.expected_win }) - vim.api.nvim_set_option_value("diff", false, { win = test_windows.actual_win }) + actual_lines = { "(not run yet)" } end vim.api.nvim_buf_set_lines(test_buffers.actual_buf, 0, -1, false, actual_lines) + + if vim.fn.has("nvim-0.8.0") == 1 then + vim.api.nvim_set_option_value("winbar", "Actual", { win = test_windows.actual_win }) + end + + vim.api.nvim_set_option_value("diff", enable_diff, { win = test_windows.expected_win }) + vim.api.nvim_set_option_value("diff", enable_diff, { win = test_windows.actual_win }) + + if enable_diff then + vim.api.nvim_win_call(test_windows.expected_win, function() + vim.cmd.diffthis() + end) + vim.api.nvim_win_call(test_windows.actual_win, function() + vim.cmd.diffthis() + end) + end end local function refresh_test_panel() @@ -417,10 +356,10 @@ local function toggle_test_panel() refresh_test_panel() end - vim.keymap.set("n", "j", function() + vim.keymap.set("n", "", function() navigate_test_case(1) end, { buffer = test_buffers.tab_buf, silent = true }) - vim.keymap.set("n", "k", function() + vim.keymap.set("n", "", function() navigate_test_case(-1) end, { buffer = 
test_buffers.tab_buf, silent = true }) @@ -430,9 +369,13 @@ local function toggle_test_panel() end, { buffer = buf, silent = true }) end + if is_debug and config.hooks and config.hooks.before_debug then + config.hooks.before_debug(ctx) + end + local execute_module = require("cp.execute") local contest_config = config.contests[state.platform] - if execute_module.compile_problem(ctx, contest_config) then + if execute_module.compile_problem(ctx, contest_config, is_debug) then test_module.run_all_test_cases(ctx, contest_config) end @@ -515,6 +458,7 @@ local function parse_command(args) end local language = nil + local debug = false for i, arg in ipairs(args) do local lang_match = arg:match("^--lang=(.+)$") @@ -526,17 +470,19 @@ local function parse_command(args) else return { type = "error", message = "--lang requires a value" } end + elseif arg == "--debug" then + debug = true end end local filtered_args = vim.tbl_filter(function(arg) - return not (arg:match("^--lang") or arg == language) + return not (arg:match("^--lang") or arg == language or arg == "--debug") end, args) local first = filtered_args[1] if vim.tbl_contains(actions, first) then - return { type = "action", action = first, language = language } + return { type = "action", action = first, language = language, debug = debug } end if vim.tbl_contains(platforms, first) then @@ -604,12 +550,8 @@ function M.handle_command(opts) end if cmd.type == "action" then - if cmd.action == "run" then - run_problem() - elseif cmd.action == "debug" then - debug_problem() - elseif cmd.action == "test" then - toggle_test_panel() + if cmd.action == "test" then + toggle_test_panel(cmd.debug) elseif cmd.action == "next" then navigate_problem(1, cmd.language) elseif cmd.action == "prev" then @@ -626,7 +568,7 @@ function M.handle_command(opts) if cmd.type == "contest_setup" then if set_platform(cmd.platform) then state.contest_id = cmd.contest - if config.scrapers[cmd.platform] then + if vim.tbl_contains(config.scrapers, 
cmd.platform) then local metadata_result = scrape.scrape_contest_metadata(cmd.platform, cmd.contest) if not metadata_result.success then logger.log( @@ -649,7 +591,7 @@ function M.handle_command(opts) local problem_ids = {} local has_metadata = false - if config.scrapers[cmd.platform] then + if vim.tbl_contains(config.scrapers, cmd.platform) then local metadata_result = scrape.scrape_contest_metadata(cmd.platform, cmd.contest) if not metadata_result.success then logger.log( @@ -692,7 +634,7 @@ function M.handle_command(opts) if cmd.type == "cses_problem" then if set_platform(cmd.platform) then - if config.scrapers[cmd.platform] then + if vim.tbl_contains(config.scrapers, cmd.platform) then local metadata_result = scrape.scrape_contest_metadata(cmd.platform, "") if not metadata_result.success then logger.log( diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index 4dc7855..8eb6183 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -1,3 +1,14 @@ +---@class ScraperTestCase +---@field input string +---@field expected string + +---@class ScraperResult +---@field success boolean +---@field problem_id string +---@field url? string +---@field tests? ScraperTestCase[] +---@field error? 
string + local M = {} local logger = require("cp.log") local cache = require("cp.cache") @@ -139,7 +150,7 @@ function M.scrape_contest_metadata(platform, contest_id) end ---@param ctx ProblemContext ----@return {success: boolean, problem_id: string, test_count?: number, test_cases?: table[], url?: string, error?: string} +---@return {success: boolean, problem_id: string, test_count?: number, test_cases?: ScraperTestCase[], url?: string, error?: string} function M.scrape_problem(ctx) vim.validate({ ctx = { ctx, "table" }, @@ -249,50 +260,26 @@ function M.scrape_problem(ctx) return data end - if data.test_cases and #data.test_cases > 0 then + if data.tests and #data.tests > 0 then local base_name = vim.fn.fnamemodify(ctx.input_file, ":r") - for i, test_case in ipairs(data.test_cases) do + for i, test_case in ipairs(data.tests) do local input_file = base_name .. "." .. i .. ".cpin" local expected_file = base_name .. "." .. i .. ".cpout" local input_content = test_case.input:gsub("\r", "") - local expected_content = test_case.output:gsub("\r", "") + local expected_content = test_case.expected:gsub("\r", "") vim.fn.writefile(vim.split(input_content, "\n", true), input_file) vim.fn.writefile(vim.split(expected_content, "\n", true), expected_file) end - - local combined_input = data.combined and data.combined.input:gsub("\r", "") - or table.concat( - vim.tbl_map(function(tc) - return tc.input - end, data.test_cases), - "\n" - ) - local combined_output = data.combined and data.combined.output:gsub("\r", "") - or table.concat( - vim.tbl_map(function(tc) - return tc.output - end, data.test_cases), - "\n" - ) - - -- with atcoder, we combine together multiple test cases - -- TODO: per-platform settings to do this (i.e. do we stitch?) - if ctx.contest == "atcoder" then - combined_input = tostring(#data.test_cases) .. "\n" .. 
combined_input - end - - vim.fn.writefile(vim.split(combined_input, "\n", true), ctx.input_file) - vim.fn.writefile(vim.split(combined_output, "\n", true), ctx.expected_file) end return { success = true, problem_id = ctx.problem_name, - test_count = data.test_cases and #data.test_cases or 0, - test_cases = data.test_cases, + test_count = data.tests and #data.tests or 0, + test_cases = data.tests, url = data.url, } end diff --git a/lua/cp/test.lua b/lua/cp/test.lua index c8144ff..852baa9 100644 --- a/lua/cp/test.lua +++ b/lua/cp/test.lua @@ -68,7 +68,8 @@ local function parse_test_cases_from_cache(platform, contest_id, problem_id) for i, test_case in ipairs(cached_test_cases) do local index = test_case.index or i - table.insert(test_cases, create_test_case(index, test_case.input, test_case.output)) + local expected = test_case.expected or test_case.output or "" + table.insert(test_cases, create_test_case(index, test_case.input, expected)) end return test_cases @@ -171,9 +172,6 @@ local function run_single_test_case(ctx, contest_config, test_case) local run_cmd = build_command(language_config.run, language_config.executable, substitutions) local stdin_content = test_case.input .. "\n" - if ctx.contest == "atcoder" then - stdin_content = "1\n" .. stdin_content - end local start_time = vim.uv.hrtime() local result = vim.system(run_cmd, { diff --git a/plugin/cp.lua b/plugin/cp.lua index 735ff89..f112ae4 100644 --- a/plugin/cp.lua +++ b/plugin/cp.lua @@ -14,37 +14,14 @@ end, { nargs = "*", desc = "Competitive programming helper", complete = function(ArgLead, CmdLine, _) - local languages = vim.tbl_keys(constants.canonical_filetypes) - - if ArgLead:match("^--lang=") then - local lang_completions = {} - for _, lang in ipairs(languages) do - table.insert(lang_completions, "--lang=" .. 
lang) - end - return vim.tbl_filter(function(completion) - return completion:find(ArgLead, 1, true) == 1 - end, lang_completions) - end - - if ArgLead:match("^%-") and not ArgLead:match("^--lang") then - return vim.tbl_filter(function(completion) - return completion:find(ArgLead, 1, true) == 1 - end, { "--lang" }) - end - local args = vim.split(vim.trim(CmdLine), "%s+") local num_args = #args if CmdLine:sub(-1) == " " then num_args = num_args + 1 end - local lang_flag_present = vim.tbl_contains(args, "--lang") - or vim.iter(args):any(function(arg) - return arg:match("^--lang=") - end) - if num_args == 2 then - local candidates = { "--lang" } + local candidates = {} local cp = require("cp") local context = cp.get_current_context() if context.platform and context.contest_id then @@ -63,17 +40,13 @@ end, { return vim.tbl_filter(function(cmd) return cmd:find(ArgLead, 1, true) == 1 end, candidates) - elseif args[#args - 1] == "--lang" then - return vim.tbl_filter(function(lang) - return lang:find(ArgLead, 1, true) == 1 - end, languages) - elseif num_args == 4 and not lang_flag_present then + elseif num_args == 4 then if vim.tbl_contains(platforms, args[2]) then local cache = require("cp.cache") cache.load() local contest_data = cache.get_contest_data(args[2], args[3]) if contest_data and contest_data.problems then - local candidates = { "--lang" } + local candidates = {} for _, problem in ipairs(contest_data.problems) do table.insert(candidates, problem.id) end diff --git a/readme.md b/readme.md index 2c4cf19..19fffc3 100644 --- a/readme.md +++ b/readme.md @@ -66,6 +66,7 @@ follows: ## TODO +- fzf/telescope integration (whichever available) +- autocomplete with --lang and --debug - finer-tuned problem limits (i.e. 
per-problem codeforces time, memory) -- better highlighting - notify discord members diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 7983cc7..3c86565 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -198,21 +198,15 @@ def main() -> None: print(json.dumps(result)) sys.exit(1) - individual_test_cases: list[dict[str, str]] = [] - for index, (input_data, output_data) in enumerate(tests, 1): - individual_test_cases.append( - {"index": index, "input": input_data, "output": output_data} - ) - - combined_input = "\n".join(tc["input"] for tc in individual_test_cases) - combined_output = "\n".join(tc["output"] for tc in individual_test_cases) + test_list: list[dict[str, str]] = [] + for input_data, output_data in tests: + test_list.append({"input": input_data, "expected": output_data}) result = { "success": True, "problem_id": problem_id, "url": url, - "test_cases": individual_test_cases, - "combined": {"input": combined_input, "output": combined_output}, + "tests": test_list, } print(json.dumps(result)) diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index bf252bd..f8e2290 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -7,80 +7,6 @@ import cloudscraper from bs4 import BeautifulSoup -def extract_combined_text(sections) -> list[str]: - texts = [] - - for section in sections: - pre = section.find("pre") - if not pre: - continue - - divs = pre.find_all("div") - if divs: - lines = [div.get_text().strip() for div in divs] - text = "\n".join(lines) - else: - text = pre.get_text().replace("\r", "").strip() - texts.append(text) - - return texts - - -def extract_lines_by_test_number(sections) -> dict[int, list[str]]: - lines_by_test = {} - - for section in sections: - pre = section.find("pre") - if not pre: - continue - - divs = pre.find_all("div") - for div in divs: - classes = div.get("class", []) - for class_name in classes: - if not class_name.startswith("test-example-line-"): - continue - - try: - test_num = 
int(class_name.split("-")[-1]) - if test_num not in lines_by_test: - lines_by_test[test_num] = [] - lines_by_test[test_num].append(div.get_text().strip()) - except (ValueError, IndexError): - continue - - return lines_by_test - - -def extract_individual_test_cases( - input_sections, output_sections -) -> list[tuple[str, str]]: - if not input_sections or not output_sections: - return [] - - input_by_test = extract_lines_by_test_number(input_sections) - output_by_test = extract_lines_by_test_number(output_sections) - - if not input_by_test or not output_by_test: - return [] - - tests = [] - test_numbers = sorted(set(input_by_test.keys()) & set(output_by_test.keys())) - - for test_num in test_numbers: - input_lines = input_by_test.get(test_num, []) - output_lines = output_by_test.get(test_num, []) - - if not input_lines or not output_lines: - continue - - input_text = "\n".join(input_lines) - output_text = "\n".join(output_lines) - tests.append((input_text, output_text)) - - return tests - - def scrape(url: str) -> list[tuple[str, str]]: try: scraper = cloudscraper.create_scraper() @@ -88,27 +14,117 @@ def scrape(url: str) -> list[tuple[str, str]]: response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") - tests: list[tuple[str, str]] = [] - input_sections = soup.find_all("div", class_="input") output_sections = soup.find_all("div", class_="output") - individual_tests = extract_individual_test_cases( - input_sections, output_sections - ) + individual_inputs = {} + individual_outputs = {} - if individual_tests: - return individual_tests + for inp_section in input_sections: + inp_pre = inp_section.find("pre") + if not inp_pre: + continue - all_inputs = extract_combined_text(input_sections) - all_outputs = extract_combined_text(output_sections) + test_line_divs = inp_pre.find_all( + "div", class_=lambda x: x and "test-example-line-" in x + ) + if not test_line_divs: + continue - if all_inputs and all_outputs: - combined_input = "\n".join(all_inputs) 
- combined_output = "\n".join(all_outputs) - tests.append((combined_input, combined_output)) + for div in test_line_divs: + classes = div.get("class", []) + class_name = next( + ( + cls + for cls in classes + if "test-example-line-" in cls and cls.split("-")[-1].isdigit() + ), + None, + ) + if not class_name: + continue - return tests + test_num = class_name.replace("test-example-line-", "") + if test_num not in individual_inputs: + individual_inputs[test_num] = [] + individual_inputs[test_num].append(div.get_text().strip()) + + for out_section in output_sections: + out_pre = out_section.find("pre") + if not out_pre: + continue + + test_line_divs = out_pre.find_all( + "div", class_=lambda x: x and "test-example-line-" in x + ) + if not test_line_divs: + continue + + for div in test_line_divs: + classes = div.get("class", []) + class_name = next( + ( + cls + for cls in classes + if "test-example-line-" in cls and cls.split("-")[-1].isdigit() + ), + None, + ) + if not class_name: + continue + + test_num = class_name.replace("test-example-line-", "") + if test_num not in individual_outputs: + individual_outputs[test_num] = [] + individual_outputs[test_num].append(div.get_text().strip()) + + if individual_inputs and individual_outputs: + common_tests = set(individual_inputs.keys()) & set( + individual_outputs.keys() + ) + if common_tests: + tests = [] + for test_num in sorted(common_tests, key=int): + input_text = "\n".join(individual_inputs[test_num]) + output_text = "\n".join(individual_outputs[test_num]) + prefixed_input = "1\n" + input_text + tests.append((prefixed_input, output_text)) + return tests + all_inputs = [] + all_outputs = [] + + for inp_section in input_sections: + inp_pre = inp_section.find("pre") + if not inp_pre: + continue + + divs = inp_pre.find_all("div") + if divs: + lines = [div.get_text().strip() for div in divs] + text = "\n".join(lines) + else: + text = inp_pre.get_text().replace("\r", "").strip() + all_inputs.append(text) + + for out_section in 
output_sections: + out_pre = out_section.find("pre") + if not out_pre: + continue + + divs = out_pre.find_all("div") + if divs: + lines = [div.get_text().strip() for div in divs] + text = "\n".join(lines) + else: + text = out_pre.get_text().replace("\r", "").strip() + all_outputs.append(text) + + if not all_inputs or not all_outputs: + return [] + + combined_input = "\n".join(all_inputs) + combined_output = "\n".join(all_outputs) + return [(combined_input, combined_output)] except Exception as e: print(f"CloudScraper failed: {e}", file=sys.stderr) @@ -165,39 +181,6 @@ def scrape_sample_tests(url: str) -> list[tuple[str, str]]: return scrape(url) -def scrape_with_both_formats( - url: str, -) -> tuple[list[tuple[str, str]], tuple[str, str] | None]: - try: - scraper = cloudscraper.create_scraper() - response = scraper.get(url, timeout=10) - response.raise_for_status() - - soup = BeautifulSoup(response.text, "html.parser") - - input_sections = soup.find_all("div", class_="input") - output_sections = soup.find_all("div", class_="output") - - individual_tests = extract_individual_test_cases( - input_sections, output_sections - ) - - all_inputs = extract_combined_text(input_sections) - all_outputs = extract_combined_text(output_sections) - - combined = None - if all_inputs and all_outputs: - combined_input = "\n".join(all_inputs) - combined_output = "\n".join(all_outputs) - combined = (combined_input, combined_output) - - return individual_tests, combined - - except Exception as e: - print(f"CloudScraper failed: {e}", file=sys.stderr) - return [], None - - def main() -> None: if len(sys.argv) < 2: result: dict[str, str | bool] = { @@ -250,10 +233,9 @@ def main() -> None: problem_id: str = contest_id + problem_letter.lower() url: str = parse_problem_url(contest_id, problem_letter) - print(f"Scraping: {url}", file=sys.stderr) - individual_tests, combined = scrape_with_both_formats(url) + tests: list[tuple[str, str]] = scrape_sample_tests(url) - if not individual_tests and 
not combined: + if not tests: result: dict[str, str | bool] = { "success": False, "error": f"No tests found for {contest_id} {problem_letter}", @@ -263,26 +245,16 @@ def main() -> None: print(json.dumps(result)) sys.exit(1) - test_cases: list[dict[str, str]] = [] - has_individual = len(individual_tests) > 0 - - if has_individual: - for input_data, output_data in individual_tests: - test_cases.append({"input": input_data, "output": output_data}) - elif combined: - test_cases.append({"input": combined[0], "output": combined[1]}) + test_list: list[dict[str, str]] = [] + for input_data, output_data in tests: + test_list.append({"input": input_data, "expected": output_data}) result: dict[str, str | bool | list] = { "success": True, "problem_id": problem_id, "url": url, - "test_cases": test_cases, - "has_individual_tests": has_individual, + "tests": test_list, } - - if combined: - result["combined"] = {"input": combined[0], "output": combined[1]} - print(json.dumps(result)) else: diff --git a/scrapers/cses.py b/scrapers/cses.py index 17ecc85..c91f671 100755 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -198,15 +198,15 @@ def main() -> None: print(json.dumps(result)) sys.exit(1) - test_cases: list[dict[str, str]] = [] + test_list: list[dict[str, str]] = [] for input_data, output_data in tests: - test_cases.append({"input": input_data, "output": output_data}) + test_list.append({"input": input_data, "expected": output_data}) result = { "success": True, "problem_id": problem_id, "url": url, - "test_cases": test_cases, + "tests": test_list, } print(json.dumps(result))