diff --git a/doc/cp.txt b/doc/cp.txt index 32bb734..6625e0f 100644 --- a/doc/cp.txt +++ b/doc/cp.txt @@ -100,7 +100,6 @@ Here's an example configuration with lazy.nvim: > extension = "py", }, default_language = "cpp", - timeout_ms = 2000, }, }, hooks = { @@ -156,8 +155,6 @@ Here's an example configuration with lazy.nvim: > • {python} (`LanguageConfig`) Python language configuration. • {default_language} (`string`, default: `"cpp"`) Default language when `--lang` not specified. - • {timeout_ms} (`number`, default: `2000`) Execution timeout in - milliseconds. *cp.LanguageConfig* @@ -315,20 +312,27 @@ Activation ~ Interface ~ The run panel uses a professional table layout with precise column alignment: -(note that the diff is indeed highlighted, not the weird amalgamation of +(observe that the diff is indeed highlighted, not the weird amalgamation of characters below) > - ┌──────┬────────┬────────┬───────────┐ ┌─ Expected vs Actual ──────────────────┐ - │ # │ Status │ Time │ Exit Code │ │ 45.70ms │ Exit: 0 │ - ├──────┼────────┼────────┼───────────┤ ├────────────────────────────────────────┤ - │ 1 │ AC │12.00ms │ 0 │ │ │ - │ >2 │ WA │45.70ms │ 1 │ │ 4[-2-]{+3+} │ - ├──────┴────────┴────────┴───────────┤ │ 100 │ - │5 3 │ │ hello w[-o-]r{+o+}ld │ - ├──────┬────────┬────────┬───────────┤ │ │ - │ 3 │ AC │ 9.00ms │ 0 │ └────────────────────────────────────────┘ - │ 4 │ RTE │ 0.00ms │139 (SIGUSR2)│ - └──────┴────────┴────────┴───────────┘ + ┌─────┬────────┬──────────────┬───────────┬──────────┬─────────────┐ + │ # │ Status │ Runtime (ms) │ Time (ms) │ Mem (MB) │ Exit Code │ + ├─────┼────────┼──────────────┼───────────┼──────────┼─────────────┤ + │ 1 │ AC │ 12.0 │ 2000 │ 256 │ 0 │ + │> 2 │ WA │ 45.70 │ 2000 │ 256 │ 1 │ + ├─────┴────────┴──────────────┴───────────┴──────────┴─────────────┤ + │Input: │ + │5 3 │ + ├─────┬────────┬──────────────┬───────────┬──────────┬─────────────┤ + │ 3 │ AC │ 9.0 │ 2000 │ 256 │ 0 │ + │ 4 │ RTE │ 0.0 │ 2000 │ 256 │139 (SIGUSR2)│ + └─────┴────────┴──────────────┴───────────┴──────────┴─────────────┘ + ┌──────────────────────────────────────────────────────────────────┐ + │Expected vs Actual │ + │4[-2-]{+3+} │ + │ 100 │ + │ hello w[-o-]r{+o+}ld │ + └──────────────────────────────────────────────────────────────────┘ < Status Indicators ~ @@ -344,8 +348,8 @@ Keymaps ~ *cp-test-keys* Navigate to next test case (configurable via run_panel.next_test_key) Navigate to previous test case (configurable via run_panel.prev_test_key) -t Toggle diff mode between vim and git (configurable via run_panel.toggle_diff_key) -q Exit test panel and restore layout + Toggle diff mode between vim and git (configurable via run_panel.toggle_diff_key) + Exit test panel and restore layout Diff Modes ~ diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 86d26d8..4f339a4 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -7,6 +7,8 @@ ---@field expires_at? number ---@field test_cases? CachedTestCase[] ---@field test_cases_cached_at? number +---@field timeout_ms? number +---@field memory_mb? number ---@class Problem ---@field id string @@ -22,6 +24,7 @@ local M = {} local cache_file = vim.fn.stdpath('data') .. '/cp-nvim.json' local cache_data = {} +local loaded = false ---@param platform string ---@return number? @@ -58,14 +61,20 @@ local function is_cache_valid(contest_data, platform) end function M.load() + if loaded then + return + end + if vim.fn.filereadable(cache_file) == 0 then cache_data = {} + loaded = true return end local content = vim.fn.readfile(cache_file) if #content == 0 then cache_data = {} + loaded = true return end @@ -75,6 +84,7 @@ function M.load() else cache_data = {} end + loaded = true end function M.save() @@ -167,12 +177,16 @@ end ---@param contest_id string ---@param problem_id? string ---@param test_cases CachedTestCase[] -function M.set_test_cases(platform, contest_id, problem_id, test_cases) +---@param timeout_ms? number +---@param memory_mb? number +function M.set_test_cases(platform, contest_id, problem_id, test_cases, timeout_ms, memory_mb) vim.validate({ platform = { platform, 'string' }, contest_id = { contest_id, 'string' }, problem_id = { problem_id, { 'string', 'nil' }, true }, test_cases = { test_cases, 'table' }, + timeout_ms = { timeout_ms, { 'number', 'nil' }, true }, + memory_mb = { memory_mb, { 'number', 'nil' }, true }, }) local problem_key = problem_id and (contest_id .. '_' .. problem_id) or contest_id @@ -185,7 +199,33 @@ function M.set_test_cases(platform, contest_id, problem_id, test_cases) cache_data[platform][problem_key].test_cases = test_cases cache_data[platform][problem_key].test_cases_cached_at = os.time() + if timeout_ms then + cache_data[platform][problem_key].timeout_ms = timeout_ms + end + if memory_mb then + cache_data[platform][problem_key].memory_mb = memory_mb + end M.save() end +---@param platform string +---@param contest_id string +---@param problem_id? string +---@return number?, number? +function M.get_constraints(platform, contest_id, problem_id) + vim.validate({ + platform = { platform, 'string' }, + contest_id = { contest_id, 'string' }, + problem_id = { problem_id, { 'string', 'nil' }, true }, + }) + + local problem_key = problem_id and (contest_id .. '_' .. problem_id) or contest_id + if not cache_data[platform] or not cache_data[platform][problem_key] then + return nil, nil + end + + local problem_data = cache_data[platform][problem_key] + return problem_data.timeout_ms, problem_data.memory_mb +end + return M diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 0616e9f..35a523f 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -18,13 +18,11 @@ ---@field cpp LanguageConfig ---@field python LanguageConfig ---@field default_language string ----@field timeout_ms number ---@class PartialContestConfig ---@field cpp? PartialLanguageConfig ---@field python? PartialLanguageConfig ---@field default_language? string ----@field timeout_ms? number ---@class Hooks ---@field before_run? fun(ctx: ProblemContext) @@ -84,7 +82,7 @@ M.defaults = { diff_mode = 'vim', next_test_key = '', prev_test_key = '', - toggle_diff_key = 't', + toggle_diff_key = '', max_output_lines = 50, }, diff = { diff --git a/lua/cp/execute.lua b/lua/cp/execute.lua index b59a954..d336ad8 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/execute.lua @@ -103,7 +103,7 @@ end ---@param cmd string[] ---@param input_data string ----@param timeout_ms integer +---@param timeout_ms number ---@return ExecuteResult local function execute_command(cmd, input_data, timeout_ms) vim.validate({ @@ -278,8 +278,13 @@ function M.run_problem(ctx, contest_config, is_debug) input_data = table.concat(vim.fn.readfile(ctx.input_file), '\n') .. '\n' end + local cache = require('cp.cache') + cache.load() + local timeout_ms, _ = cache.get_constraints(ctx.contest, ctx.contest_id, ctx.problem_id) + timeout_ms = timeout_ms or 2000 + local run_cmd = build_command(language_config.test, language_config.executable, substitutions) - local exec_result = execute_command(run_cmd, input_data, contest_config.timeout_ms) + local exec_result = execute_command(run_cmd, input_data, timeout_ms) local formatted_output = format_output(exec_result, ctx.expected_file, is_debug) local output_buf = vim.fn.bufnr(ctx.output_file) diff --git a/lua/cp/highlight.lua b/lua/cp/highlight.lua index 8cf8bf4..02bf1ae 100644 --- a/lua/cp/highlight.lua +++ b/lua/cp/highlight.lua @@ -10,7 +10,6 @@ local M = {} ----Parse git diff markers and extract highlight information ---@param text string Raw git diff output line ---@return string cleaned_text, DiffHighlight[] local function parse_diff_line(text) @@ -52,7 +51,6 @@ local function parse_diff_line(text) return result_text, highlights end ----Parse complete git diff output ---@param diff_output string ---@return ParsedDiff function M.parse_git_diff(diff_output) @@ -64,10 +62,8 @@ function M.parse_git_diff(diff_output) local content_lines = {} local all_highlights = {} - -- Skip git diff header lines local content_started = false for _, line in ipairs(lines) do - -- Skip header lines (@@, +++, ---, index, etc.) if content_started or ( @@ -80,33 +76,27 @@ function M.parse_git_diff(diff_output) then content_started = true - -- Process content lines if line:match('^%+') then - -- Added line - remove + prefix and parse highlights - local clean_line = line:sub(2) -- Remove + prefix + local clean_line = line:sub(2) local parsed_line, line_highlights = parse_diff_line(clean_line) table.insert(content_lines, parsed_line) - -- Set line numbers for highlights local line_num = #content_lines for _, highlight in ipairs(line_highlights) do - highlight.line = line_num - 1 -- 0-based for extmarks + highlight.line = line_num - 1 table.insert(all_highlights, highlight) end - elseif not line:match('^%-') and not line:match('^\\') then -- Skip removed lines and "\ No newline" messages - -- Word-diff content line or unchanged line + elseif not line:match('^%-') and not line:match('^\\') then local clean_line = line:match('^%s') and line:sub(2) or line local parsed_line, line_highlights = parse_diff_line(clean_line) - -- Only add non-empty lines if parsed_line ~= '' then table.insert(content_lines, parsed_line) - -- Set line numbers for highlights local line_num = #content_lines for _, highlight in ipairs(line_highlights) do - highlight.line = line_num - 1 -- 0-based for extmarks + highlight.line = line_num - 1 table.insert(all_highlights, highlight) end end @@ -120,12 +110,10 @@ function M.parse_git_diff(diff_output) } end ----Apply highlights to a buffer using extmarks ---@param bufnr number ---@param highlights DiffHighlight[] ---@param namespace number function M.apply_highlights(bufnr, highlights, namespace) - -- Clear existing highlights in this namespace vim.api.nvim_buf_clear_namespace(bufnr, namespace, 0, -1) for _, highlight in ipairs(highlights) do @@ -139,13 +127,11 @@ function M.apply_highlights(bufnr, highlights, namespace) end end ----Create namespace for diff highlights ---@return number function M.create_namespace() return vim.api.nvim_create_namespace('cp_diff_highlights') end ----Parse and apply git diff to buffer ---@param bufnr number ---@param diff_output string ---@param namespace number diff --git a/lua/cp/init.lua b/lua/cp/init.lua index dda2b2e..cdf5aba 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -259,6 +259,8 @@ local function toggle_run_panel(is_debug) vim.api.nvim_set_option_value('filetype', 'cptest', { buf = expected_buf }) vim.api.nvim_set_option_value('filetype', 'cptest', { buf = actual_buf }) + vim.api.nvim_set_option_value('winbar', 'Expected', { win = expected_win }) + vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) @@ -296,6 +298,7 @@ local function toggle_run_panel(is_debug) vim.api.nvim_win_set_buf(diff_win, diff_buf) vim.api.nvim_set_option_value('filetype', 'cptest', { buf = diff_buf }) + vim.api.nvim_set_option_value('winbar', 'Expected vs Actual', { win = diff_win }) local diff_backend = require('cp.diff') local backend = diff_backend.get_best_backend('git') @@ -441,7 +444,7 @@ local function toggle_run_panel(is_debug) end setup_keybindings_for_buffer = function(buf) - vim.keymap.set('n', 'q', function() + vim.keymap.set('n', '', function() toggle_run_panel() end, { buffer = buf, silent = true }) vim.keymap.set('n', config.run_panel.toggle_diff_key, function() diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index 56a9227..11a0e73 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -7,6 +7,8 @@ ---@field problem_id string ---@field url? string ---@field tests? ScraperTestCase[] +---@field timeout_ms? number +---@field memory_mb? number ---@field error? string local M = {} @@ -86,7 +88,6 @@ function M.scrape_contest_metadata(platform, contest_id) end local plugin_path = get_plugin_path() - local scraper_path = plugin_path .. '/scrapers/' .. platform .. '.py' local args if platform == 'cses' then @@ -95,7 +96,8 @@ function M.scrape_contest_metadata(platform, contest_id) 'run', '--directory', plugin_path, - scraper_path, + '-m', + 'scrapers.' .. platform, 'metadata', } else @@ -104,7 +106,8 @@ function M.scrape_contest_metadata(platform, contest_id) 'run', '--directory', plugin_path, - scraper_path, + '-m', + 'scrapers.' .. platform, 'metadata', contest_id, } @@ -152,7 +155,7 @@ function M.scrape_contest_metadata(platform, contest_id) end ---@param ctx ProblemContext ----@return {success: boolean, problem_id: string, test_count?: number, test_cases?: ScraperTestCase[], url?: string, error?: string} +---@return {success: boolean, problem_id: string, test_count?: number, test_cases?: ScraperTestCase[], timeout_ms?: number, memory_mb?: number, url?: string, error?: string} function M.scrape_problem(ctx) vim.validate({ ctx = { ctx, 'table' }, @@ -209,7 +212,6 @@ function M.scrape_problem(ctx) end local plugin_path = get_plugin_path() - local scraper_path = plugin_path .. '/scrapers/' .. ctx.contest .. '.py' local args if ctx.contest == 'cses' then @@ -218,7 +220,8 @@ function M.scrape_problem(ctx) 'run', '--directory', plugin_path, - scraper_path, + '-m', + 'scrapers.' .. ctx.contest, 'tests', ctx.contest_id, } @@ -228,7 +231,8 @@ function M.scrape_problem(ctx) 'run', '--directory', plugin_path, - scraper_path, + '-m', + 'scrapers.' .. ctx.contest, 'tests', ctx.contest_id, ctx.problem_id, @@ -277,6 +281,24 @@ function M.scrape_problem(ctx) vim.fn.writefile(vim.split(input_content, '\n', true), input_file) vim.fn.writefile(vim.split(expected_content, '\n', true), expected_file) end + + local cached_test_cases = {} + for i, test_case in ipairs(data.tests) do + table.insert(cached_test_cases, { + index = i, + input = test_case.input, + expected = test_case.expected, + }) + end + + cache.set_test_cases( + ctx.contest, + ctx.contest_id, + ctx.problem_id, + cached_test_cases, + data.timeout_ms, + data.memory_mb + ) end return { @@ -284,6 +306,8 @@ function M.scrape_problem(ctx) problem_id = ctx.problem_name, test_count = data.tests and #data.tests or 0, test_cases = data.tests, + timeout_ms = data.timeout_ms, + memory_mb = data.memory_mb, url = data.url, } end diff --git a/lua/cp/test.lua b/lua/cp/test.lua index 3e0e147..e894d4a 100644 --- a/lua/cp/test.lua +++ b/lua/cp/test.lua @@ -12,6 +12,10 @@ ---@field signal string? ---@field timed_out boolean? +---@class ProblemConstraints +---@field timeout_ms number +---@field memory_mb number + ---@class RunPanelState ---@field test_cases TestCase[] ---@field current_index number @@ -19,6 +23,7 @@ ---@field namespace number? ---@field is_active boolean ---@field saved_layout table? +---@field constraints ProblemConstraints? local M = {} local constants = require('cp.constants') @@ -32,6 +37,7 @@ local run_panel_state = { namespace = nil, is_active = false, saved_layout = nil, + constraints = nil, } ---@param index number @@ -114,6 +120,25 @@ local function parse_test_cases_from_files(input_file, expected_file) return test_cases end +---@param platform string +---@param contest_id string +---@param problem_id string? +---@return ProblemConstraints? +local function load_constraints_from_cache(platform, contest_id, problem_id) + local cache = require('cp.cache') + cache.load() + local timeout_ms, memory_mb = cache.get_constraints(platform, contest_id, problem_id) + + if timeout_ms and memory_mb then + return { + timeout_ms = timeout_ms, + memory_mb = memory_mb, + } + end + + return nil +end + ---@param ctx ProblemContext ---@param contest_config ContestConfig ---@param test_case TestCase @@ -177,10 +202,15 @@ local function run_single_test_case(ctx, contest_config, cp_config, test_case) local stdin_content = test_case.input .. '\n' local start_time = vim.uv.hrtime() + local timeout_ms = run_panel_state.constraints and run_panel_state.constraints.timeout_ms or 2000 + + if not run_panel_state.constraints then + logger.log('no problem constraints available, using default 2000ms timeout') + end local result = vim .system(run_cmd, { stdin = stdin_content, - timeout = contest_config.timeout_ms or 2000, + timeout = timeout_ms, text = true, }) :wait() @@ -241,8 +271,17 @@ function M.load_test_cases(ctx, state) run_panel_state.test_cases = test_cases run_panel_state.current_index = 1 + run_panel_state.constraints = + load_constraints_from_cache(state.platform, state.contest_id, state.problem_id) - logger.log(('loaded %d test case(s)'):format(#test_cases)) + local constraint_info = run_panel_state.constraints + and string.format( + ' with %dms/%dMB limits', + run_panel_state.constraints.timeout_ms, + run_panel_state.constraints.memory_mb + ) + or '' + logger.log(('loaded %d test case(s)%s'):format(#test_cases, constraint_info)) return #test_cases > 0 end diff --git a/lua/cp/test_render.lua b/lua/cp/test_render.lua index 9bbe699..29a435a 100644 --- a/lua/cp/test_render.lua +++ b/lua/cp/test_render.lua @@ -53,27 +53,40 @@ local function format_exit_code(code) return signal_name and string.format('%d (%s)', code, signal_name) or tostring(code) end --- Compute column widths + aggregates local function compute_cols(test_state) - local w = { num = 3, status = 8, time = 6, exit = 11 } + local w = { num = 5, status = 8, time = 6, timeout = 8, memory = 8, exit = 11 } + + local timeout_str = '' + local memory_str = '' + if test_state.constraints then + timeout_str = tostring(test_state.constraints.timeout_ms) + memory_str = string.format('%.0f', test_state.constraints.memory_mb) + else + timeout_str = '—' + memory_str = '—' + end for i, tc in ipairs(test_state.test_cases) do local prefix = (i == test_state.current_index) and '>' or ' ' - w.num = math.max(w.num, #(prefix .. i)) - w.status = math.max(w.status, #(' ' .. M.get_status_info(tc).text)) - local time_str = tc.time_ms and (string.format('%.2f', tc.time_ms) .. 'ms') or '—' - w.time = math.max(w.time, #time_str) - w.exit = math.max(w.exit, #(' ' .. format_exit_code(tc.code))) + w.num = math.max(w.num, #(' ' .. prefix .. i .. ' ')) + w.status = math.max(w.status, #(' ' .. M.get_status_info(tc).text .. ' ')) + local time_str = tc.time_ms and string.format('%.2f', tc.time_ms) or '—' + w.time = math.max(w.time, #(' ' .. time_str .. ' ')) + w.timeout = math.max(w.timeout, #(' ' .. timeout_str .. ' ')) + w.memory = math.max(w.memory, #(' ' .. memory_str .. ' ')) + w.exit = math.max(w.exit, #(' ' .. format_exit_code(tc.code) .. ' ')) end - w.num = math.max(w.num, #' #') - w.status = math.max(w.status, #' Status') - w.time = math.max(w.time, #' Time') - w.exit = math.max(w.exit, #' Exit Code') + w.num = math.max(w.num, #' # ') + w.status = math.max(w.status, #' Status ') + w.time = math.max(w.time, #' Runtime (ms) ') + w.timeout = math.max(w.timeout, #' Time (ms) ') + w.memory = math.max(w.memory, #' Mem (MB) ') + w.exit = math.max(w.exit, #' Exit Code ') - local sum = w.num + w.status + w.time + w.exit - local inner = sum + 3 -- three inner vertical dividers - local total = inner + 2 -- two outer borders + local sum = w.num + w.status + w.time + w.timeout + w.memory + w.exit + local inner = sum + 5 + local total = inner + 2 return { w = w, sum = sum, inner = inner, total = total } end @@ -86,6 +99,32 @@ local function center(text, width) return string.rep(' ', left) .. text .. string.rep(' ', pad - left) end +local function right_align(text, width) + local content = (' %s '):format(text) + local pad = width - #content + if pad <= 0 then + return content + end + return string.rep(' ', pad) .. content +end + +local function format_num_column(prefix, idx, width) + local num_str = tostring(idx) + local content + if #num_str == 1 then + content = ' ' .. prefix .. ' ' .. num_str .. ' ' + else + content = ' ' .. prefix .. num_str .. ' ' + end + local total_pad = width - #content + if total_pad <= 0 then + return content + end + local left_pad = math.floor(total_pad / 2) + local right_pad = total_pad - left_pad + return string.rep(' ', left_pad) .. content .. string.rep(' ', right_pad) +end + local function top_border(c) local w = c.w return '┌' @@ -95,6 +134,10 @@ local function top_border(c) .. '┬' .. string.rep('─', w.time) .. '┬' + .. string.rep('─', w.timeout) + .. '┬' + .. string.rep('─', w.memory) + .. '┬' .. string.rep('─', w.exit) .. '┐' end @@ -108,6 +151,10 @@ local function row_sep(c) .. '┼' .. string.rep('─', w.time) .. '┼' + .. string.rep('─', w.timeout) + .. '┼' + .. string.rep('─', w.memory) + .. '┼' .. string.rep('─', w.exit) .. '┤' end @@ -121,6 +168,10 @@ local function bottom_border(c) .. '┴' .. string.rep('─', w.time) .. '┴' + .. string.rep('─', w.timeout) + .. '┴' + .. string.rep('─', w.memory) + .. '┴' .. string.rep('─', w.exit) .. '┘' end @@ -134,6 +185,10 @@ local function flat_fence_above(c) .. '┴' .. string.rep('─', w.time) .. '┴' + .. string.rep('─', w.timeout) + .. '┴' + .. string.rep('─', w.memory) + .. '┴' .. string.rep('─', w.exit) .. '┤' end @@ -147,6 +202,10 @@ local function flat_fence_below(c) .. '┬' .. string.rep('─', w.time) .. '┬' + .. string.rep('─', w.timeout) + .. '┬' + .. string.rep('─', w.memory) + .. '┬' .. string.rep('─', w.exit) .. '┤' end @@ -162,34 +221,52 @@ local function header_line(c) .. '│' .. center('Status', w.status) .. '│' - .. center('Time', w.time) + .. center('Runtime (ms)', w.time) + .. '│' + .. center('Time (ms)', w.timeout) + .. '│' + .. center('Mem (MB)', w.memory) .. '│' .. center('Exit Code', w.exit) .. '│' end -local function data_row(c, idx, tc, is_current) +local function data_row(c, idx, tc, is_current, test_state) local w = c.w local prefix = is_current and '>' or ' ' local status = M.get_status_info(tc) - local time = tc.time_ms and (string.format('%.2f', tc.time_ms) .. 'ms') or '—' + local time = tc.time_ms and string.format('%.2f', tc.time_ms) or '—' local exit = format_exit_code(tc.code) + local timeout = '' + local memory = '' + if test_state.constraints then + timeout = tostring(test_state.constraints.timeout_ms) + memory = string.format('%.0f', test_state.constraints.memory_mb) + else + timeout = '—' + memory = '—' + end + local line = '│' - .. center(prefix .. idx, w.num) + .. format_num_column(prefix, idx, w.num) .. '│' - .. center(status.text, w.status) + .. right_align(status.text, w.status) .. '│' - .. center(time, w.time) + .. right_align(time, w.time) .. '│' - .. center(exit, w.exit) + .. right_align(timeout, w.timeout) + .. '│' + .. right_align(memory, w.memory) + .. '│' + .. right_align(exit, w.exit) .. '│' local hi if status.text ~= '' then - local pad = w.status - #status.text - local left = math.floor(pad / 2) - local status_start_col = 1 + w.num + 1 + left + local content = ' ' .. status.text .. ' ' + local pad = w.status - #content + local status_start_col = 1 + w.num + 1 + pad + 1 local status_end_col = status_start_col + #status.text hi = { col_start = status_start_col, @@ -213,7 +290,7 @@ function M.render_test_list(test_state) for i, tc in ipairs(test_state.test_cases) do local is_current = (i == test_state.current_index) - local row, hi = data_row(c, i, tc, is_current) + local row, hi = data_row(c, i, tc, is_current, test_state) table.insert(lines, row) if hi then hi.line = #lines - 1 @@ -226,6 +303,10 @@ function M.render_test_list(test_state) if has_input then table.insert(lines, flat_fence_above(c)) + local input_header = 'Input:' + local header_pad = c.inner - #input_header + table.insert(lines, '│' .. input_header .. string.rep(' ', header_pad) .. '│') + for _, input_line in ipairs(vim.split(tc.input, '\n', { plain = true, trimempty = false })) do local s = input_line or '' if #s > c.inner then diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index b9b39ea..e251c44 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -1,18 +1,50 @@ #!/usr/bin/env python3 import json +import re import sys +from dataclasses import asdict import requests from bs4 import BeautifulSoup, Tag +from .models import MetadataResult, ProblemSummary, TestCase, TestsResult + + +def extract_problem_limits(soup: BeautifulSoup) -> tuple[int, float]: + timeout_ms = None + memory_mb = None + + paragraphs = soup.find_all("p") + for p in paragraphs: + text = p.get_text() + if "Time Limit:" in text and "Memory Limit:" in text: + time_match = re.search(r"Time Limit:\s*(\d+)\s*sec", text) + if time_match: + seconds = int(time_match.group(1)) + timeout_ms = seconds * 1000 + + memory_match = re.search(r"Memory Limit:\s*(\d+)\s*MiB", text) + if memory_match: + memory_mib = int(memory_match.group(1)) + memory_mb = round(memory_mib * 1.048576, 2) + break + + if timeout_ms is None: + raise ValueError("Could not find valid timeout in problem constraints") + + if memory_mb is None: + raise ValueError("Could not find valid memory limit in problem constraints") + + return timeout_ms, memory_mb + def parse_problem_url(contest_id: str, problem_letter: str) -> str: task_id: str = f"{contest_id}_{problem_letter}" return f"https://atcoder.jp/contests/{contest_id}/tasks/{task_id}" -def extract_problem_from_row(row, contest_id: str) -> dict[str, str] | None: +def extract_problem_from_row(row, contest_id: str) -> ProblemSummary | None: cells = row.find_all("td") if len(cells) < 2: return None @@ -34,10 +66,10 @@ def extract_problem_from_row(row, contest_id: str) -> dict[str, str] | None: if not problem_letter or not task_name: return None - return {"id": problem_letter.lower(), "name": task_name} + return ProblemSummary(id=problem_letter.lower(), name=task_name) -def scrape_contest_problems(contest_id: str) -> list[dict[str, str]]: +def scrape_contest_problems(contest_id: str) -> list[ProblemSummary]: try: contest_url = f"https://atcoder.jp/contests/{contest_id}/tasks" headers = { @@ -53,13 +85,13 @@ def scrape_contest_problems(contest_id: str) -> list[dict[str, str]]: return [] rows = task_table.find_all("tr")[1:] - problems: list[dict[str, str]] = [] + problems: list[ProblemSummary] = [] for row in rows: problem = extract_problem_from_row(row, contest_id) if problem: problems.append(problem) - problems.sort(key=lambda x: x["id"]) + problems.sort(key=lambda x: x.id) return problems except Exception as e: @@ -95,7 +127,7 @@ def extract_test_case_from_headers(sample_headers, i: int) -> tuple[str, str] | return (input_text, output_text) -def scrape(url: str) -> list[tuple[str, str]]: +def scrape(url: str) -> list[TestCase]: try: headers = { "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" @@ -109,12 +141,13 @@ def scrape(url: str) -> list[tuple[str, str]]: "h3", string=lambda x: x and "sample" in x.lower() if x else False ) - tests: list[tuple[str, str]] = [] + tests: list[TestCase] = [] i = 0 while i < len(sample_headers): test_case = extract_test_case_from_headers(sample_headers, i) if test_case: - tests.append(test_case) + input_text, output_text = test_case + tests.append(TestCase(input=input_text, expected=output_text)) i += 2 else: i += 1 @@ -128,64 +161,55 @@ def scrape(url: str) -> list[tuple[str, str]]: def main() -> None: if len(sys.argv) < 2: - print( - json.dumps( - { - "success": False, - "error": "Usage: atcoder.py metadata OR atcoder.py tests ", - } - ) + result = MetadataResult( + success=False, + error="Usage: atcoder.py metadata OR atcoder.py tests ", ) + print(json.dumps(asdict(result))) sys.exit(1) mode: str = sys.argv[1] if mode == "metadata": if len(sys.argv) != 3: - print( - json.dumps( - { - "success": False, - "error": "Usage: atcoder.py metadata ", - } - ) + result = MetadataResult( + success=False, + error="Usage: atcoder.py metadata ", ) + print(json.dumps(asdict(result))) sys.exit(1) contest_id: str = sys.argv[2] - problems: list[dict[str, str]] = scrape_contest_problems(contest_id) + problems: list[ProblemSummary] = scrape_contest_problems(contest_id) if not problems: - print( - json.dumps( - { - "success": False, - "error": f"No problems found for contest {contest_id}", - } - ) + result = MetadataResult( + success=False, + error=f"No problems found for contest {contest_id}", ) + print(json.dumps(asdict(result))) sys.exit(1) - print( - json.dumps( - { - "success": True, - "contest_id": contest_id, - "problems": problems, - } - ) + result = MetadataResult( + success=True, + error="", + contest_id=contest_id, + problems=problems, ) + print(json.dumps(asdict(result))) elif mode == "tests": if len(sys.argv) != 4: - print( - json.dumps( - { - "success": False, - "error": "Usage: atcoder.py tests ", - } - ) + tests_result = TestsResult( + success=False, + error="Usage: atcoder.py tests ", + problem_id="", + url="", + tests=[], + timeout_ms=0, + memory_mb=0, ) + print(json.dumps(asdict(tests_result))) sys.exit(1) test_contest_id: str = sys.argv[2] @@ -193,46 +217,59 @@ def main() -> None: problem_id: str = f"{test_contest_id}_{problem_letter.lower()}" url: str = parse_problem_url(test_contest_id, problem_letter) - print(f"Scraping: {url}", file=sys.stderr) + tests: list[TestCase] = scrape(url) - tests: list[tuple[str, str]] = scrape(url) - if not tests: - print( - json.dumps( - { - "success": False, - "error": f"No tests found for {test_contest_id} {problem_letter}", - "problem_id": problem_id, - "url": url, - } - ) + try: + headers = { + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" + } + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() + soup = BeautifulSoup(response.text, "html.parser") + timeout_ms, memory_mb = extract_problem_limits(soup) + except Exception as e: + tests_result = TestsResult( + success=False, + error=f"Failed to extract constraints: {e}", + problem_id=problem_id, + url=url, + tests=[], + timeout_ms=0, + memory_mb=0, ) + print(json.dumps(asdict(tests_result))) sys.exit(1) - test_list: list[dict[str, str]] = [ - {"input": i, "expected": o} for i, o in tests - ] - - print( - json.dumps( - { - "success": True, - "problem_id": problem_id, - "url": url, - "tests": test_list, - } + if not tests: + tests_result = TestsResult( + success=False, + error=f"No tests found for {test_contest_id} {problem_letter}", + problem_id=problem_id, + url=url, + tests=[], + timeout_ms=timeout_ms, + memory_mb=memory_mb, ) + print(json.dumps(asdict(tests_result))) + sys.exit(1) + + tests_result = TestsResult( + success=True, + error="", + problem_id=problem_id, + url=url, + tests=tests, + timeout_ms=timeout_ms, + memory_mb=memory_mb, ) + print(json.dumps(asdict(tests_result))) else: - print( - json.dumps( - { - "success": False, - "error": f"Unknown mode: {mode}. Use 'metadata' or 'tests'", - } - ) + result = MetadataResult( + success=False, + error=f"Unknown mode: {mode}. Use 'metadata' or 'tests'", ) + print(json.dumps(asdict(result))) sys.exit(1) diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index c193a21..a66acbd 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -7,7 +7,7 @@ from dataclasses import asdict import cloudscraper from bs4 import BeautifulSoup, Tag -from .models import MetadataResult, Problem, TestCase, TestsResult +from .models import MetadataResult, ProblemSummary, TestCase, TestsResult def scrape(url: str) -> list[TestCase]: @@ -140,7 +140,37 @@ def parse_problem_url(contest_id: str, problem_letter: str) -> str: ) -def scrape_contest_problems(contest_id: str) -> list[Problem]: +def extract_problem_limits(soup: BeautifulSoup) -> tuple[int, float]: + import re + + timeout_ms = None + memory_mb = None + + time_limit_div = soup.find("div", class_="time-limit") + if time_limit_div: + text = time_limit_div.get_text().strip() + match = re.search(r"(\d+) seconds?", text) + if match: + seconds = int(match.group(1)) + timeout_ms = seconds * 1000 + + if timeout_ms is None: + raise ValueError("Could not find valid timeout in time-limit section") + + memory_limit_div = soup.find("div", class_="memory-limit") + if memory_limit_div: + text = memory_limit_div.get_text().strip() + match = re.search(r"(\d+) megabytes", text) + if match: + memory_mb = float(match.group(1)) + + if memory_mb is None: + raise ValueError("Could not find valid memory limit in memory-limit section") + + return timeout_ms, memory_mb + + +def scrape_contest_problems(contest_id: str) -> list[ProblemSummary]: try: contest_url: str = f"https://codeforces.com/contest/{contest_id}" scraper = cloudscraper.create_scraper() @@ -148,7 +178,7 @@ def scrape_contest_problems(contest_id: str) -> list[Problem]: response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") - problems: list[Problem] = [] + problems: list[ProblemSummary] = [] problem_links = soup.find_all( "a", href=lambda x: x and f"/contest/{contest_id}/problem/" in x @@ -163,12 +193,14 @@ def scrape_contest_problems(contest_id: str) -> list[Problem]: problem_name: str = link.get_text(strip=True) if problem_letter and problem_name: - problems.append(Problem(id=problem_letter, name=problem_name)) + problems.append( + ProblemSummary(id=problem_letter, name=problem_name) + ) problems.sort(key=lambda x: x.id) seen: set[str] = set() - unique_problems: list[Problem] = [] + unique_problems: list[ProblemSummary] = [] for p in problems: if p.id not in seen: seen.add(p.id) @@ -206,7 +238,7 @@ def main() -> None: sys.exit(1) contest_id: str = sys.argv[2] - problems: list[Problem] = scrape_contest_problems(contest_id) + problems: list[ProblemSummary] = scrape_contest_problems(contest_id) if not problems: result = MetadataResult( @@ -215,7 +247,9 @@ def main() -> None: print(json.dumps(asdict(result))) sys.exit(1) - result = MetadataResult(success=True, contest_id=contest_id, problems=problems) + result = MetadataResult( + success=True, error="", contest_id=contest_id, problems=problems + ) print(json.dumps(asdict(result))) elif mode == "tests": @@ -223,6 +257,11 @@ def main() -> None: tests_result = TestsResult( success=False, error="Usage: codeforces.py tests ", + problem_id="", + url="", + tests=[], + timeout_ms=0, + memory_mb=0, ) print(json.dumps(asdict(tests_result))) sys.exit(1) @@ -234,18 +273,46 @@ def main() -> None: url: str = parse_problem_url(tests_contest_id, problem_letter) tests: list[TestCase] = scrape_sample_tests(url) + try: + scraper = cloudscraper.create_scraper() + response = scraper.get(url, timeout=10) + response.raise_for_status() + soup = BeautifulSoup(response.text, "html.parser") + timeout_ms, memory_mb = extract_problem_limits(soup) + except Exception as e: + tests_result = TestsResult( + success=False, + error=f"Failed to extract constraints: {e}", + problem_id=problem_id, + url=url, + tests=[], + timeout_ms=0, + memory_mb=0, + ) + print(json.dumps(asdict(tests_result))) + sys.exit(1) + if not tests: tests_result = TestsResult( success=False, error=f"No tests found for {tests_contest_id} {problem_letter}", problem_id=problem_id, url=url, + tests=[], + timeout_ms=timeout_ms, + memory_mb=memory_mb, ) print(json.dumps(asdict(tests_result))) sys.exit(1) tests_result = TestsResult( - success=True, problem_id=problem_id, url=url, tests=tests + success=True, + error="", + problem_id=problem_id, + url=url, + tests=tests, + timeout_ms=timeout_ms, + memory_mb=memory_mb, ) print(json.dumps(asdict(tests_result))) diff --git a/scrapers/cses.py b/scrapers/cses.py index 16a9c18..edf3224 100755 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -1,10 +1,14 @@ #!/usr/bin/env python3 import json +import re import sys +from dataclasses import asdict import requests -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup, Tag + +from .models import MetadataResult, ProblemSummary, TestCase, TestsResult def parse_problem_url(problem_input: str) -> str | None: @@ -15,10 +19,43 @@ def parse_problem_url(problem_input: str) -> str | None: return None +def extract_problem_limits(soup: BeautifulSoup) -> tuple[int, float]: + timeout_ms = None + memory_mb = None + + constraints_ul = soup.find("ul", class_="task-constraints") + if not constraints_ul or not isinstance(constraints_ul, Tag): + raise ValueError("Could not find task-constraints section") + + for li in constraints_ul.find_all("li"): + text = li.get_text() + + if "Time limit:" in text: + match = re.search(r"Time limit:\s*(\d+(?:\.\d+)?)\s*s", text) + if match: + seconds = float(match.group(1)) + timeout_ms = int(seconds * 1000) + + if "Memory limit:" in text: + match = re.search(r"Memory limit:\s*(\d+)\s*MB", text) + if match: + memory_mb = float(match.group(1)) + + if timeout_ms is None: + raise ValueError("Could not find valid timeout in task-constraints section") + + if memory_mb is None: + raise ValueError( + "Could not find valid memory limit in task-constraints section" + ) + + return timeout_ms, memory_mb + + def process_problem_element( element, current_category: str | None, - all_categories: dict[str, list[dict[str, str]]], + all_categories: dict[str, list[ProblemSummary]], ) -> str | None: if element.name == "h1": category_name = element.get_text().strip() @@ -39,11 +76,12 @@ def process_problem_element( if not (problem_id.isdigit() and problem_name and current_category): return current_category - all_categories[current_category].append({"id": problem_id, "name": problem_name}) + problem = ProblemSummary(id=problem_id, name=problem_name) + all_categories[current_category].append(problem) return current_category -def scrape_all_problems() -> dict[str, list[dict[str, str]]]: +def scrape_all_problems() -> dict[str, list[ProblemSummary]]: try: problemset_url = "https://cses.fi/problemset/" headers = { @@ -54,7 +92,7 @@ def scrape_all_problems() -> dict[str, list[dict[str, str]]]: response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") - all_categories: dict[str, list[dict[str, str]]] = {} + all_categories: dict[str, list[ProblemSummary]] = {} problem_links = soup.find_all( "a", href=lambda x: x and "/problemset/task/" in x @@ -68,7 +106,7 @@ def scrape_all_problems() -> dict[str, list[dict[str, str]]]: ) for category in all_categories: - all_categories[category].sort(key=lambda x: int(x["id"])) + all_categories[category].sort(key=lambda x: int(x.id)) print(f"Found {len(all_categories)} categories", file=sys.stderr) return all_categories @@ -105,7 +143,7 @@ def extract_example_test_case(soup) -> tuple[str, str] | None: return (input_text, output_text) -def scrape(url: str) -> list[tuple[str, str]]: +def scrape(url: str) -> list[TestCase]: try: headers = { "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" @@ -120,7 +158,8 @@ def scrape(url: str) -> list[tuple[str, str]]: if not test_case: return [] - return [test_case] + input_text, output_text = test_case + return [TestCase(input=input_text, expected=output_text)] except Exception as e: print(f"Error scraping CSES: {e}", file=sys.stderr) @@ -129,124 +168,125 @@ def scrape(url: str) -> list[tuple[str, str]]: def main() -> None: if len(sys.argv) < 2: - print( - json.dumps( - { - "success": False, - "error": "Usage: cses.py metadata OR cses.py tests ", - } - ) + result = MetadataResult( + success=False, + error="Usage: cses.py metadata OR cses.py tests ", ) + print(json.dumps(asdict(result))) sys.exit(1) mode: str = sys.argv[1] if mode == "metadata": if len(sys.argv) != 2: - print( - json.dumps( - { - "success": False, - "error": "Usage: cses.py metadata", - } - ) + result = MetadataResult( + success=False, + error="Usage: cses.py metadata", ) + print(json.dumps(asdict(result))) sys.exit(1) - all_categories: dict[str, list[dict[str, str]]] = scrape_all_problems() + all_categories: dict[str, list[ProblemSummary]] = scrape_all_problems() if not all_categories: - print( - json.dumps( - { - "success": False, - "error": "Failed to scrape CSES problem categories", - } - ) + result = MetadataResult( + success=False, + error="Failed to scrape CSES problem categories", ) + print(json.dumps(asdict(result))) sys.exit(1) - print( - json.dumps( - { - "success": True, - "categories": all_categories, - } - ) - ) + result = MetadataResult(success=True, error="", categories=all_categories) + print(json.dumps(asdict(result))) elif mode == "tests": if len(sys.argv) != 3: - print( - json.dumps( - { - "success": False, - "error": "Usage: cses.py tests ", - } - ) + tests_result = TestsResult( + success=False, + error="Usage: cses.py tests ", + problem_id="", + url="", + tests=[], + timeout_ms=0, + memory_mb=0, ) + print(json.dumps(asdict(tests_result))) sys.exit(1) problem_input: str = sys.argv[2] url: str | None = parse_problem_url(problem_input) if not url: - print( - json.dumps( - { - "success": False, - "error": f"Invalid problem input: {problem_input}. Use either problem ID (e.g., 1068) or full URL", - "problem_id": problem_input - if problem_input.isdigit() - else None, - } - ) + tests_result = TestsResult( + success=False, + error=f"Invalid problem input: {problem_input}. Use either problem ID (e.g., 1068) or full URL", + problem_id=problem_input if problem_input.isdigit() else "", + url="", + tests=[], + timeout_ms=0, + memory_mb=0, ) + print(json.dumps(asdict(tests_result))) sys.exit(1) - tests: list[tuple[str, str]] = scrape(url) + tests: list[TestCase] = scrape(url) problem_id: str = ( problem_input if problem_input.isdigit() else problem_input.split("/")[-1] ) - if not tests: - print( - json.dumps( - { - "success": False, - "error": f"No tests found for {problem_input}", - "problem_id": problem_id, - "url": url, - } - ) + try: + headers = { + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" + } + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() + soup = BeautifulSoup(response.text, "html.parser") + timeout_ms, memory_mb = extract_problem_limits(soup) + except Exception as e: + tests_result = TestsResult( + success=False, + error=f"Failed to extract constraints: {e}", + problem_id=problem_id, + url=url, + tests=[], + timeout_ms=0, + memory_mb=0, ) + print(json.dumps(asdict(tests_result))) sys.exit(1) - test_list: list[dict[str, str]] = [ - {"input": i, "expected": o} for i, o in tests - ] - - print( - json.dumps( - { - "success": True, - "problem_id": problem_id, - "url": url, - "tests": test_list, - } + if not tests: + tests_result = TestsResult( + success=False, + error=f"No tests found for {problem_input}", + problem_id=problem_id, + url=url, + tests=[], + timeout_ms=timeout_ms, + memory_mb=memory_mb, ) + print(json.dumps(asdict(tests_result))) + sys.exit(1) + + test_cases = tests + tests_result = TestsResult( + success=True, + error="", + problem_id=problem_id, + url=url, + tests=test_cases, + timeout_ms=timeout_ms, + memory_mb=memory_mb, ) + print(json.dumps(asdict(tests_result))) else: - print( - json.dumps( - { - "success": False, - "error": f"Unknown mode: {mode}. Use 'metadata' or 'tests'", - } - ) + result = MetadataResult( + success=False, + error=f"Unknown mode: {mode}. Use 'metadata' or 'tests'", ) + print(json.dumps(asdict(result))) sys.exit(1) diff --git a/scrapers/models.py b/scrapers/models.py index ea0e03e..728e9bb 100644 --- a/scrapers/models.py +++ b/scrapers/models.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass @@ -8,7 +8,7 @@ class TestCase: @dataclass -class Problem: +class ProblemSummary: id: str name: str @@ -16,26 +16,20 @@ class Problem: @dataclass class ScrapingResult: success: bool - error: str | None = None + error: str @dataclass class MetadataResult(ScrapingResult): - contest_id: str | None = None - problems: list[Problem] | None = None - categories: dict[str, list[Problem]] | None = None - - def __post_init__(self): - if self.problems is None: - self.problems = [] + contest_id: str = "" + problems: list[ProblemSummary] = field(default_factory=list) + categories: dict[str, list[ProblemSummary]] = field(default_factory=dict) @dataclass class TestsResult(ScrapingResult): - problem_id: str = "" - url: str = "" - tests: list[TestCase] | None = None - - def __post_init__(self): - if self.tests is None: - self.tests = [] + problem_id: str + url: str + tests: list[TestCase] + timeout_ms: int + memory_mb: float diff --git a/spec/scraper_spec.lua b/spec/scraper_spec.lua index aacf29a..a06d97e 100644 --- a/spec/scraper_spec.lua +++ b/spec/scraper_spec.lua @@ -13,6 +13,7 @@ describe('cp.scrape', function() return nil end, set_contest_data = function() end, + set_test_cases = function() end, } mock_system_calls = {} @@ -31,7 +32,7 @@ describe('cp.scrape', function() result.stdout = '{"success": true, "problems": [{"id": "a", "name": "Test Problem"}]}' elseif vim.tbl_contains(cmd, 'tests') then result.stdout = - '{"success": true, "tests": [{"input": "1 2", "expected": "3"}], "url": "https://example.com"}' + '{"success": true, "tests": [{"input": "1 2", "expected": "3"}], "url": "https://example.com", "timeout_ms": 2000, "memory_mb": 256.0}' end end diff --git a/spec/test_render_spec.lua b/spec/test_render_spec.lua index f7ace52..aff3c3b 100644 --- a/spec/test_render_spec.lua +++ b/spec/test_render_spec.lua @@ -78,7 +78,7 @@ describe('cp.test_render', function() local result = test_render.render_test_list(test_state) local found_current = false for _, line in ipairs(result) do - if line:match('│.*>2.*│') then + if line:match('│.*> 2.*│') then found_current = true break end diff --git a/tests/scrapers/test_atcoder.py b/tests/scrapers/test_atcoder.py index 5086894..95ff09d 100644 --- a/tests/scrapers/test_atcoder.py +++ b/tests/scrapers/test_atcoder.py @@ -1,5 +1,6 @@ from unittest.mock import Mock from scrapers.atcoder import scrape, scrape_contest_problems +from scrapers.models import ProblemSummary def test_scrape_success(mocker, mock_atcoder_html): @@ -11,8 +12,8 @@ def test_scrape_success(mocker, mock_atcoder_html): result = scrape("https://atcoder.jp/contests/abc350/tasks/abc350_a") assert len(result) == 1 - assert result[0][0] == "3\n1 2 3" - assert result[0][1] == "6" + assert result[0].input == "3\n1 2 3" + assert result[0].expected == "6" def test_scrape_contest_problems(mocker): @@ -36,8 +37,8 @@ def test_scrape_contest_problems(mocker): result = scrape_contest_problems("abc350") assert len(result) == 2 - assert result[0] == {"id": "a", "name": "A - Water Tank"} - assert result[1] == {"id": "b", "name": "B - Dentist Aoki"} + assert result[0] == ProblemSummary(id="a", name="A - Water Tank") + assert result[1] == ProblemSummary(id="b", name="B - Dentist Aoki") def test_scrape_network_error(mocker): diff --git a/tests/scrapers/test_codeforces.py b/tests/scrapers/test_codeforces.py index 67277ea..1fbfbd1 100644 --- a/tests/scrapers/test_codeforces.py +++ b/tests/scrapers/test_codeforces.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from scrapers.codeforces import scrape, scrape_contest_problems -from scrapers.models import Problem +from scrapers.models import ProblemSummary def test_scrape_success(mocker, mock_codeforces_html): @@ -36,8 +36,8 @@ def test_scrape_contest_problems(mocker): result = scrape_contest_problems("1900") assert len(result) == 2 - assert result[0] == Problem(id="a", name="A. Problem A") - assert result[1] == Problem(id="b", name="B. Problem B") + assert result[0] == ProblemSummary(id="a", name="A. Problem A") + assert result[1] == ProblemSummary(id="b", name="B. Problem B") def test_scrape_network_error(mocker): diff --git a/tests/scrapers/test_cses.py b/tests/scrapers/test_cses.py index 1dd0096..c91b0f8 100644 --- a/tests/scrapers/test_cses.py +++ b/tests/scrapers/test_cses.py @@ -1,5 +1,6 @@ from unittest.mock import Mock from scrapers.cses import scrape, scrape_all_problems +from scrapers.models import ProblemSummary def test_scrape_success(mocker, mock_cses_html): @@ -11,8 +12,8 @@ def test_scrape_success(mocker, mock_cses_html): result = scrape("https://cses.fi/problemset/task/1068") assert len(result) == 1 - assert result[0][0] == "3\n1 2 3" - assert result[0][1] == "6" + assert result[0].input == "3\n1 2 3" + assert result[0].expected == "6" def test_scrape_all_problems(mocker): @@ -32,10 +33,10 @@ def test_scrape_all_problems(mocker): assert "Introductory Problems" in result assert "Sorting and Searching" in result assert len(result["Introductory Problems"]) == 2 - assert result["Introductory Problems"][0] == { - "id": "1068", - "name": "Weird Algorithm", - } + assert result["Introductory Problems"][0] == ProblemSummary( + id="1068", + name="Weird Algorithm", + ) def test_scrape_network_error(mocker):