feat(io): multi-test case view

This commit is contained in:
Barrett Ruth 2025-11-04 08:15:08 -05:00
parent 3654748632
commit fef73887e4
6 changed files with 66 additions and 35 deletions

View file

@ -20,6 +20,7 @@
---@field id string
---@field name? string
---@field interactive? boolean
---@field multi_test? boolean
---@field memory_mb? number
---@field timeout_ms? number
---@field test_cases TestCase[]
@ -187,6 +188,7 @@ end
---@param timeout_ms number
---@param memory_mb number
---@param interactive boolean
---@param multi_test boolean
function M.set_test_cases(
platform,
contest_id,
@ -194,7 +196,8 @@ function M.set_test_cases(
test_cases,
timeout_ms,
memory_mb,
interactive
interactive,
multi_test
)
vim.validate({
platform = { platform, 'string' },
@ -204,6 +207,7 @@ function M.set_test_cases(
timeout_ms = { timeout_ms, { 'number', 'nil' }, true },
memory_mb = { memory_mb, { 'number', 'nil' }, true },
interactive = { interactive, { 'boolean', 'nil' }, true },
multi_test = { multi_test, { 'boolean', 'nil' }, true },
})
local index = cache_data[platform][contest_id].index_map[problem_id]
@ -212,6 +216,7 @@ function M.set_test_cases(
cache_data[platform][contest_id].problems[index].timeout_ms = timeout_ms
cache_data[platform][contest_id].problems[index].memory_mb = memory_mb
cache_data[platform][contest_id].problems[index].interactive = interactive
cache_data[platform][contest_id].problems[index].multi_test = multi_test
M.save()
end

View file

@ -198,6 +198,7 @@ function M.scrape_all_tests(platform, contest_id, callback)
timeout_ms = ev.timeout_ms or 0,
memory_mb = ev.memory_mb or 0,
interactive = ev.interactive or false,
multi_test = ev.multi_test or false,
problem_id = ev.problem_id,
})
end

View file

@ -98,18 +98,35 @@ local function start_tests(platform, contest_id, problems)
cached_tests,
ev.timeout_ms or 0,
ev.memory_mb or 0,
ev.interactive
ev.interactive,
ev.multi_test
)
local io_state = state.get_io_view_state()
if io_state then
local test_cases = cache.get_test_cases(platform, contest_id, state.get_problem_id())
local problem_id = state.get_problem_id()
local test_cases = cache.get_test_cases(platform, contest_id, problem_id)
local input_lines = {}
for _, tc in ipairs(test_cases) do
for _, line in ipairs(vim.split(tc.input, '\n')) do
table.insert(input_lines, line)
local contest_data = cache.get_contest_data(platform, contest_id)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
if is_multi_test and #test_cases > 1 then
table.insert(input_lines, tostring(#test_cases))
for _, tc in ipairs(test_cases) do
local stripped = tc.input:gsub('^1\n', '')
for _, line in ipairs(vim.split(stripped, '\n')) do
table.insert(input_lines, line)
end
end
else
for _, tc in ipairs(test_cases) do
for _, line in ipairs(vim.split(tc.input, '\n')) do
table.insert(input_lines, line)
end
end
end
require('cp.utils').update_buffer_content(io_state.input_buf, input_lines, nil, nil)
end
end)

View file

@ -315,9 +315,21 @@ function M.ensure_io_view()
local test_cases = cache.get_test_cases(platform, contest_id, problem_id)
if test_cases and #test_cases > 0 then
local input_lines = {}
for _, tc in ipairs(test_cases) do
for _, line in ipairs(vim.split(tc.input, '\n')) do
table.insert(input_lines, line)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
if is_multi_test and #test_cases > 1 then
table.insert(input_lines, tostring(#test_cases))
for _, tc in ipairs(test_cases) do
local stripped = tc.input:gsub('^1\n', '')
for _, line in ipairs(vim.split(stripped, '\n')) do
table.insert(input_lines, line)
end
end
else
for _, tc in ipairs(test_cases) do
for _, line in ipairs(vim.split(tc.input, '\n')) do
table.insert(input_lines, line)
end
end
end
utils.update_buffer_content(input_buf, input_lines, nil, nil)
@ -437,6 +449,12 @@ function M.run_io_view(test_index, debug)
)
end
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
if is_multi_test and #test_indices > 1 then
table.insert(input_lines, tostring(#test_indices))
end
for _, idx in ipairs(test_indices) do
local tc = test_state.test_cases[idx]
@ -479,7 +497,11 @@ function M.run_io_view(test_index, debug)
end
end
for _, line in ipairs(vim.split(tc.input, '\n')) do
local test_input = tc.input
if is_multi_test and #test_indices > 1 then
test_input = test_input:gsub('^1\n', '')
end
for _, line in ipairs(vim.split(test_input, '\n')) do
table.insert(input_lines, line)
end
end

View file

@ -83,10 +83,10 @@ def _extract_title(block: Tag) -> tuple[str, str]:
return parts[0].strip().upper(), parts[1].strip()
def _extract_samples(block: Tag, multi_test: bool = False) -> list[TestCase]:
def _extract_samples(block: Tag) -> tuple[list[TestCase], bool]:
st = block.find("div", class_="sample-test")
if not st:
return []
return [], False
input_pres: list[Tag] = [ # type: ignore[misc]
inp.find("pre") # type: ignore[misc]
@ -126,19 +126,16 @@ def _extract_samples(block: Tag, multi_test: bool = False) -> list[TestCase]:
)
for k in keys
]
if multi_test:
return [TestCase(input=f"1\n{tc.input}", expected=tc.expected) for tc in samples]
return samples
samples_with_prefix = [
TestCase(input=f"1\n{tc.input}", expected=tc.expected) for tc in samples
]
return samples_with_prefix, True
inputs = [_text_from_pre(p) for p in input_pres]
outputs = [_text_from_pre(p) for p in output_pres]
n = min(len(inputs), len(outputs))
samples = [TestCase(input=inputs[i], expected=outputs[i]) for i in range(n)]
if multi_test and samples:
return [TestCase(input=f"1\n{tc.input}", expected=tc.expected) for tc in samples]
return samples
return samples, False
def _is_interactive(block: Tag) -> bool:
@ -147,19 +144,6 @@ def _is_interactive(block: Tag) -> bool:
return "This is an interactive problem" in txt
def _is_multi_test_case(block: Tag) -> bool:
input_spec = block.find("div", class_="input-specification")
if not input_spec:
return False
txt = input_spec.get_text(" ", strip=True).lower()
patterns = [
r"first line.*contains.*integer.*number of test case",
r"first line.*integer.*denoting.*number of test case",
r"first line.*number of test case",
]
return any(re.search(pattern, txt) for pattern in patterns)
def _fetch_problems_html(contest_id: str) -> str:
url = f"{BASE_URL}/contest/{contest_id}/problems"
page = StealthyFetcher.fetch(
@ -180,8 +164,7 @@ def _parse_all_blocks(html: str) -> list[dict[str, Any]]:
name = _extract_title(b)[1]
if not letter:
continue
multi_test = _is_multi_test_case(b)
tests = _extract_samples(b, multi_test)
tests, multi_test = _extract_samples(b)
timeout_ms, memory_mb = _extract_limits(b)
interactive = _is_interactive(b)
out.append(
@ -192,6 +175,7 @@ def _parse_all_blocks(html: str) -> list[dict[str, Any]]:
"timeout_ms": timeout_ms,
"memory_mb": memory_mb,
"interactive": interactive,
"multi_test": multi_test,
}
)
return out
@ -274,6 +258,7 @@ class CodeforcesScraper(BaseScraper):
"timeout_ms": b.get("timeout_ms", 0),
"memory_mb": b.get("memory_mb", 0),
"interactive": bool(b.get("interactive")),
"multi_test": bool(b.get("multi_test", False)),
}
),
flush=True,

View file

@ -50,6 +50,7 @@ class TestsResult(ScrapingResult):
timeout_ms: int
memory_mb: float
interactive: bool = False
multi_test: bool = False
model_config = ConfigDict(extra="forbid")