From fef73887e466bde52040f7b7754bf644de059054 Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Tue, 4 Nov 2025 08:15:08 -0500 Subject: [PATCH] feat(io): multi-test case view --- lua/cp/cache.lua | 7 ++++++- lua/cp/scraper.lua | 1 + lua/cp/setup.lua | 27 ++++++++++++++++++++++----- lua/cp/ui/views.lua | 30 ++++++++++++++++++++++++++---- scrapers/codeforces.py | 35 ++++++++++------------------------- scrapers/models.py | 1 + 6 files changed, 66 insertions(+), 35 deletions(-) diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 86d806f..20cbd98 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -20,6 +20,7 @@ ---@field id string ---@field name? string ---@field interactive? boolean +---@field multi_test? boolean ---@field memory_mb? number ---@field timeout_ms? number ---@field test_cases TestCase[] @@ -187,6 +188,7 @@ end ---@param timeout_ms number ---@param memory_mb number ---@param interactive boolean +---@param multi_test boolean function M.set_test_cases( platform, contest_id, @@ -194,7 +196,8 @@ function M.set_test_cases( test_cases, timeout_ms, memory_mb, - interactive + interactive, + multi_test ) vim.validate({ platform = { platform, 'string' }, @@ -204,6 +207,7 @@ function M.set_test_cases( timeout_ms = { timeout_ms, { 'number', 'nil' }, true }, memory_mb = { memory_mb, { 'number', 'nil' }, true }, interactive = { interactive, { 'boolean', 'nil' }, true }, + multi_test = { multi_test, { 'boolean', 'nil' }, true }, }) local index = cache_data[platform][contest_id].index_map[problem_id] @@ -212,6 +216,7 @@ function M.set_test_cases( cache_data[platform][contest_id].problems[index].timeout_ms = timeout_ms cache_data[platform][contest_id].problems[index].memory_mb = memory_mb cache_data[platform][contest_id].problems[index].interactive = interactive + cache_data[platform][contest_id].problems[index].multi_test = multi_test M.save() end diff --git a/lua/cp/scraper.lua b/lua/cp/scraper.lua index f8cb817..7745ffb 100644 --- a/lua/cp/scraper.lua +++ b/lua/cp/scraper.lua @@ -198,6 +198,7 @@ function M.scrape_all_tests(platform, contest_id, callback) timeout_ms = ev.timeout_ms or 0, memory_mb = ev.memory_mb or 0, interactive = ev.interactive or false, + multi_test = ev.multi_test or false, problem_id = ev.problem_id, }) end diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index a079697..0028cfc 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -98,18 +98,35 @@ local function start_tests(platform, contest_id, problems) cached_tests, ev.timeout_ms or 0, ev.memory_mb or 0, - ev.interactive + ev.interactive, + ev.multi_test ) local io_state = state.get_io_view_state() if io_state then - local test_cases = cache.get_test_cases(platform, contest_id, state.get_problem_id()) + local problem_id = state.get_problem_id() + local test_cases = cache.get_test_cases(platform, contest_id, problem_id) local input_lines = {} - for _, tc in ipairs(test_cases) do - for _, line in ipairs(vim.split(tc.input, '\n')) do - table.insert(input_lines, line) + + local contest_data = cache.get_contest_data(platform, contest_id) + local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test + + if is_multi_test and #test_cases > 1 then + table.insert(input_lines, tostring(#test_cases)) + for _, tc in ipairs(test_cases) do + local stripped = tc.input:gsub('^1\n', '') + for _, line in ipairs(vim.split(stripped, '\n')) do + table.insert(input_lines, line) + end + end + else + for _, tc in ipairs(test_cases) do + for _, line in ipairs(vim.split(tc.input, '\n')) do + table.insert(input_lines, line) + end end end + require('cp.utils').update_buffer_content(io_state.input_buf, input_lines, nil, nil) end end) diff --git a/lua/cp/ui/views.lua b/lua/cp/ui/views.lua index 4cdc8d6..0cf8cc5 100644 --- a/lua/cp/ui/views.lua +++ b/lua/cp/ui/views.lua @@ -315,9 +315,21 @@ function M.ensure_io_view() local test_cases = cache.get_test_cases(platform, contest_id, problem_id) if test_cases and #test_cases > 0 then local input_lines = {} - for _, tc in ipairs(test_cases) do - for _, line in ipairs(vim.split(tc.input, '\n')) do - table.insert(input_lines, line) + local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test + + if is_multi_test and #test_cases > 1 then + table.insert(input_lines, tostring(#test_cases)) + for _, tc in ipairs(test_cases) do + local stripped = tc.input:gsub('^1\n', '') + for _, line in ipairs(vim.split(stripped, '\n')) do + table.insert(input_lines, line) + end + end + else + for _, tc in ipairs(test_cases) do + for _, line in ipairs(vim.split(tc.input, '\n')) do + table.insert(input_lines, line) + end end end utils.update_buffer_content(input_buf, input_lines, nil, nil) @@ -437,6 +449,12 @@ function M.run_io_view(test_index, debug) ) end + local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test + + if is_multi_test and #test_indices > 1 then + table.insert(input_lines, tostring(#test_indices)) + end + for _, idx in ipairs(test_indices) do local tc = test_state.test_cases[idx] @@ -479,7 +497,11 @@ function M.run_io_view(test_index, debug) end end - for _, line in ipairs(vim.split(tc.input, '\n')) do + local test_input = tc.input + if is_multi_test and #test_indices > 1 then + test_input = test_input:gsub('^1\n', '') + end + for _, line in ipairs(vim.split(test_input, '\n')) do table.insert(input_lines, line) end end diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index ab2f555..33a5e11 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -83,10 +83,10 @@ def _extract_title(block: Tag) -> tuple[str, str]: return parts[0].strip().upper(), parts[1].strip() -def _extract_samples(block: Tag, multi_test: bool = False) -> list[TestCase]: +def _extract_samples(block: Tag) -> tuple[list[TestCase], bool]: st = block.find("div", class_="sample-test") if not st: - return [] + return [], False input_pres: list[Tag] = [ # type: ignore[misc] inp.find("pre") # type: ignore[misc] @@ -126,19 +126,16 @@ def _extract_samples(block: Tag, multi_test: bool = False) -> list[TestCase]: ) for k in keys ] - if multi_test: - return [TestCase(input=f"1\n{tc.input}", expected=tc.expected) for tc in samples] - return samples + samples_with_prefix = [ + TestCase(input=f"1\n{tc.input}", expected=tc.expected) for tc in samples + ] + return samples_with_prefix, True inputs = [_text_from_pre(p) for p in input_pres] outputs = [_text_from_pre(p) for p in output_pres] n = min(len(inputs), len(outputs)) samples = [TestCase(input=inputs[i], expected=outputs[i]) for i in range(n)] - - if multi_test and samples: - return [TestCase(input=f"1\n{tc.input}", expected=tc.expected) for tc in samples] - - return samples + return samples, False def _is_interactive(block: Tag) -> bool: @@ -147,19 +144,6 @@ def _is_interactive(block: Tag) -> bool: return "This is an interactive problem" in txt -def _is_multi_test_case(block: Tag) -> bool: - input_spec = block.find("div", class_="input-specification") - if not input_spec: - return False - txt = input_spec.get_text(" ", strip=True).lower() - patterns = [ - r"first line.*contains.*integer.*number of test case", - r"first line.*integer.*denoting.*number of test case", - r"first line.*number of test case", - ] - return any(re.search(pattern, txt) for pattern in patterns) - - def _fetch_problems_html(contest_id: str) -> str: url = f"{BASE_URL}/contest/{contest_id}/problems" page = StealthyFetcher.fetch( @@ -180,8 +164,7 @@ def _parse_all_blocks(html: str) -> list[dict[str, Any]]: name = _extract_title(b)[1] if not letter: continue - multi_test = _is_multi_test_case(b) - tests = _extract_samples(b, multi_test) + tests, multi_test = _extract_samples(b) timeout_ms, memory_mb = _extract_limits(b) interactive = _is_interactive(b) out.append( @@ -192,6 +175,7 @@ def _parse_all_blocks(html: str) -> list[dict[str, Any]]: "timeout_ms": timeout_ms, "memory_mb": memory_mb, "interactive": interactive, + "multi_test": multi_test, } ) return out @@ -274,6 +258,7 @@ class CodeforcesScraper(BaseScraper): "timeout_ms": b.get("timeout_ms", 0), "memory_mb": b.get("memory_mb", 0), "interactive": bool(b.get("interactive")), + "multi_test": bool(b.get("multi_test", False)), } ), flush=True, diff --git a/scrapers/models.py b/scrapers/models.py index 2a954ef..95b7982 100644 --- a/scrapers/models.py +++ b/scrapers/models.py @@ -50,6 +50,7 @@ class TestsResult(ScrapingResult): timeout_ms: int memory_mb: float interactive: bool = False + multi_test: bool = False model_config = ConfigDict(extra="forbid")