feat: multi-test case view

This commit is contained in:
Barrett Ruth 2025-11-04 21:32:40 -05:00
parent 6477fdc20c
commit aab211902e
12 changed files with 315 additions and 124 deletions

View file

@ -16,6 +16,10 @@
---@field name string
---@field id string
---@class CombinedTest
---@field input string
---@field expected string
---@class Problem
---@field id string
---@field name? string
@ -23,6 +27,7 @@
---@field multi_test? boolean
---@field memory_mb? number
---@field timeout_ms? number
---@field combined_test? CombinedTest
---@field test_cases TestCase[]
---@class TestCase
@ -181,9 +186,34 @@ function M.get_test_cases(platform, contest_id, problem_id)
return cache_data[platform][contest_id].problems[index].test_cases or {}
end
---@param platform string
---@param contest_id string
---@param problem_id? string
---@return CombinedTest?
function M.get_combined_test(platform, contest_id, problem_id)
vim.validate({
platform = { platform, 'string' },
contest_id = { contest_id, 'string' },
problem_id = { problem_id, { 'string', 'nil' }, true },
})
if
not cache_data[platform]
or not cache_data[platform][contest_id]
or not cache_data[platform][contest_id].problems
or not cache_data[platform][contest_id].index_map
then
return nil
end
local index = cache_data[platform][contest_id].index_map[problem_id]
return cache_data[platform][contest_id].problems[index].combined_test
end
---@param platform string
---@param contest_id string
---@param problem_id string
---@param combined_test? CombinedTest
---@param test_cases TestCase[]
---@param timeout_ms number
---@param memory_mb number
@ -193,6 +223,7 @@ function M.set_test_cases(
platform,
contest_id,
problem_id,
combined_test,
test_cases,
timeout_ms,
memory_mb,
@ -203,6 +234,7 @@ function M.set_test_cases(
platform = { platform, 'string' },
contest_id = { contest_id, 'string' },
problem_id = { problem_id, { 'string', 'nil' }, true },
combined_test = { combined_test, { 'table', 'nil' }, true },
test_cases = { test_cases, 'table' },
timeout_ms = { timeout_ms, { 'number', 'nil' }, true },
memory_mb = { memory_mb, { 'number', 'nil' }, true },
@ -212,6 +244,7 @@ function M.set_test_cases(
local index = cache_data[platform][contest_id].index_map[problem_id]
cache_data[platform][contest_id].problems[index].combined_test = combined_test
cache_data[platform][contest_id].problems[index].test_cases = test_cases
cache_data[platform][contest_id].problems[index].timeout_ms = timeout_ms
cache_data[platform][contest_id].problems[index].memory_mb = memory_mb

View file

@ -76,50 +76,74 @@ local function parse_command(args)
elseif first == 'run' or first == 'panel' then
local debug = false
local test_index = nil
local mode = 'combined'
if #args == 2 then
if args[2] == '--debug' then
debug = true
elseif args[2] == 'all' then
mode = 'individual'
else
local idx = tonumber(args[2])
if not idx then
return {
type = 'error',
message = ("Invalid argument '%s': expected test number or --debug"):format(args[2]),
message = ("Invalid argument '%s': expected test number, 'all', or --debug"):format(
args[2]
),
}
end
if idx < 1 or idx ~= math.floor(idx) then
return { type = 'error', message = ("'%s' is not a valid test index"):format(idx) }
end
test_index = idx
mode = 'individual'
end
elseif #args == 3 then
local idx = tonumber(args[2])
if not idx then
return {
type = 'error',
message = ("Invalid argument '%s': expected test number"):format(args[2]),
}
if args[2] == 'all' then
mode = 'individual'
if args[3] ~= '--debug' then
return {
type = 'error',
message = ("Invalid argument '%s': expected --debug"):format(args[3]),
}
end
debug = true
else
local idx = tonumber(args[2])
if not idx then
return {
type = 'error',
message = ("Invalid argument '%s': expected test number"):format(args[2]),
}
end
if idx < 1 or idx ~= math.floor(idx) then
return { type = 'error', message = ("'%s' is not a valid test index"):format(idx) }
end
if args[3] ~= '--debug' then
return {
type = 'error',
message = ("Invalid argument '%s': expected --debug"):format(args[3]),
}
end
test_index = idx
mode = 'individual'
debug = true
end
if idx < 1 or idx ~= math.floor(idx) then
return { type = 'error', message = ("'%s' is not a valid test index"):format(idx) }
end
if args[3] ~= '--debug' then
return {
type = 'error',
message = ("Invalid argument '%s': expected --debug"):format(args[3]),
}
end
test_index = idx
debug = true
elseif #args > 3 then
return {
type = 'error',
message = 'Too many arguments. Usage: :CP ' .. first .. ' [test_num] [--debug]',
message = 'Too many arguments. Usage: :CP ' .. first .. ' [all|test_num] [--debug]',
}
end
return { type = 'action', action = first, test_index = test_index, debug = debug }
return {
type = 'action',
action = first,
test_index = test_index,
debug = debug,
mode = mode,
}
else
local language = nil
if #args >= 3 and args[2] == '--lang' then
@ -197,7 +221,7 @@ function M.handle_command(opts)
if cmd.action == 'interact' then
ui.toggle_interactive(cmd.interactor_cmd)
elseif cmd.action == 'run' then
ui.run_io_view(cmd.test_index, cmd.debug)
ui.run_io_view(cmd.test_index, cmd.debug, cmd.mode)
elseif cmd.action == 'panel' then
ui.toggle_panel({ debug = cmd.debug, test_index = cmd.test_index })
elseif cmd.action == 'next' then

View file

@ -198,6 +198,40 @@ function M.load_test_cases()
return #tcs > 0
end
---@param debug boolean?
---@return RanTestCase?
function M.run_combined_test(debug)
local combined = cache.get_combined_test(
state.get_platform() or '',
state.get_contest_id() or '',
state.get_problem_id()
)
if not combined then
logger.log('No combined test found', vim.log.levels.ERROR)
return nil
end
local ran_test = {
index = 1,
input = combined.input,
expected = combined.expected,
status = 'running',
actual = nil,
time_ms = nil,
code = nil,
ok = nil,
signal = nil,
tled = false,
mled = false,
rss_mb = 0,
selected = true,
}
local result = run_single_test_case(ran_test, debug)
return result
end
---@param index number
---@param debug boolean?
---@return boolean

View file

@ -194,6 +194,7 @@ function M.scrape_all_tests(platform, contest_id, callback)
end
if type(callback) == 'function' then
callback({
combined = ev.combined,
tests = ev.tests,
timeout_ms = ev.timeout_ms or 0,
memory_mb = ev.memory_mb or 0,

View file

@ -95,6 +95,7 @@ local function start_tests(platform, contest_id, problems)
platform,
contest_id,
ev.problem_id,
ev.combined,
cached_tests,
ev.timeout_ms or 0,
ev.memory_mb or 0,
@ -104,30 +105,11 @@ local function start_tests(platform, contest_id, problems)
local io_state = state.get_io_view_state()
if io_state then
local problem_id = state.get_problem_id()
local test_cases = cache.get_test_cases(platform, contest_id, problem_id)
local input_lines = {}
local contest_data = cache.get_contest_data(platform, contest_id)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
if is_multi_test and #test_cases > 1 then
table.insert(input_lines, tostring(#test_cases))
for _, tc in ipairs(test_cases) do
local stripped = tc.input:gsub('^1\n', '')
for _, line in ipairs(vim.split(stripped, '\n')) do
table.insert(input_lines, line)
end
end
else
for _, tc in ipairs(test_cases) do
for _, line in ipairs(vim.split(tc.input, '\n')) do
table.insert(input_lines, line)
end
end
local combined_test = cache.get_combined_test(platform, contest_id, state.get_problem_id())
if combined_test then
local input_lines = vim.split(combined_test.input, '\n')
require('cp.utils').update_buffer_content(io_state.input_buf, input_lines, nil, nil)
end
require('cp.utils').update_buffer_content(io_state.input_buf, input_lines, nil, nil)
end
end)
end

View file

@ -338,7 +338,9 @@ function M.ensure_io_view()
vim.api.nvim_set_current_win(solution_win)
end
function M.run_io_view(test_index, debug)
function M.run_io_view(test_index, debug, mode)
mode = mode or 'combined'
local platform, contest_id, problem_id =
state.get_platform(), state.get_contest_id(), state.get_problem_id()
if not platform or not contest_id or not problem_id then
@ -359,30 +361,42 @@ function M.run_io_view(test_index, debug)
M.ensure_io_view()
local run = require('cp.runner.run')
if not run.load_test_cases() then
logger.log('No test cases available', vim.log.levels.ERROR)
return
end
local test_state = run.get_panel_state()
local test_indices = {}
if test_index then
if test_index < 1 or test_index > #test_state.test_cases then
logger.log(
string.format(
'Test %d does not exist (only %d tests available)',
test_index,
#test_state.test_cases
),
vim.log.levels.WARN
)
if mode == 'combined' then
local combined = cache.get_combined_test(platform, contest_id, problem_id)
if not combined then
logger.log('No combined test available', vim.log.levels.ERROR)
return
end
test_indices = { test_index }
else
for i = 1, #test_state.test_cases do
test_indices[i] = i
if not run.load_test_cases() then
logger.log('No test cases available', vim.log.levels.ERROR)
return
end
end
local test_indices = {}
if mode == 'individual' then
local test_state = run.get_panel_state()
if test_index then
if test_index < 1 or test_index > #test_state.test_cases then
logger.log(
string.format(
'Test %d does not exist (only %d tests available)',
test_index,
#test_state.test_cases
),
vim.log.levels.WARN
)
return
end
test_indices = { test_index }
else
for i = 1, #test_state.test_cases do
test_indices[i] = i
end
end
end
@ -418,8 +432,6 @@ function M.run_io_view(test_index, debug)
return
end
run.run_all_test_cases(test_indices, debug)
local run_render = require('cp.runner.run_render')
run_render.setup_highlights()
@ -430,64 +442,55 @@ function M.run_io_view(test_index, debug)
local formatter = config.ui.run.format_verdict
local max_time_actual = 0
local max_time_limit = 0
local max_mem_actual = 0
local max_mem_limit = 0
if mode == 'combined' then
local combined = cache.get_combined_test(platform, contest_id, problem_id)
for _, idx in ipairs(test_indices) do
local tc = test_state.test_cases[idx]
max_time_actual = math.max(max_time_actual, #string.format('%.2f', tc.time_ms or 0))
max_time_limit = math.max(
max_time_limit,
#tostring(test_state.constraints and test_state.constraints.timeout_ms or 0)
)
max_mem_actual = math.max(max_mem_actual, #string.format('%.0f', tc.rss_mb or 0))
max_mem_limit = math.max(
max_mem_limit,
#string.format('%.0f', test_state.constraints and test_state.constraints.memory_mb or 0)
)
end
run.load_test_cases()
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
local result = run.run_combined_test(debug)
if is_multi_test and #test_indices > 1 then
table.insert(input_lines, tostring(#test_indices))
end
for _, idx in ipairs(test_indices) do
local tc = test_state.test_cases[idx]
if tc.actual then
for _, line in ipairs(vim.split(tc.actual, '\n', { plain = true, trimempty = false })) do
table.insert(output_lines, line)
end
if not result then
logger.log('Failed to run combined test', vim.log.levels.ERROR)
return
end
local status = run_render.get_status_info(tc)
input_lines = vim.split(combined.input, '\n')
if result.actual then
output_lines = vim.split(result.actual, '\n')
end
local status = run_render.get_status_info(result)
local test_state = run.get_panel_state()
---@type VerdictFormatData
local format_data = {
index = idx,
index = 1,
status = status,
time_ms = tc.time_ms or 0,
time_ms = result.time_ms or 0,
time_limit_ms = test_state.constraints and test_state.constraints.timeout_ms or 0,
memory_mb = tc.rss_mb or 0,
memory_mb = result.rss_mb or 0,
memory_limit_mb = test_state.constraints and test_state.constraints.memory_mb or 0,
exit_code = tc.code or 0,
signal = (tc.code and tc.code >= 128) and require('cp.constants').signal_codes[tc.code]
exit_code = result.code or 0,
signal = (result.code and result.code >= 128)
and require('cp.constants').signal_codes[result.code]
or nil,
time_actual_width = max_time_actual,
time_limit_width = max_time_limit,
mem_actual_width = max_mem_actual,
mem_limit_width = max_mem_limit,
time_actual_width = #string.format('%.2f', result.time_ms or 0),
time_limit_width = #tostring(
test_state.constraints and test_state.constraints.timeout_ms or 0
),
mem_actual_width = #string.format('%.0f', result.rss_mb or 0),
mem_limit_width = #string.format(
'%.0f',
test_state.constraints and test_state.constraints.memory_mb or 0
),
}
local result = formatter(format_data)
table.insert(verdict_lines, result.line)
local verdict_result = formatter(format_data)
table.insert(verdict_lines, verdict_result.line)
if result.highlights then
for _, hl in ipairs(result.highlights) do
if verdict_result.highlights then
for _, hl in ipairs(verdict_result.highlights) do
table.insert(verdict_highlights, {
line_offset = #verdict_lines - 1,
col_start = hl.col_start,
@ -496,13 +499,83 @@ function M.run_io_view(test_index, debug)
})
end
end
else
run.run_all_test_cases(test_indices, debug)
local test_state = run.get_panel_state()
local test_input = tc.input
if is_multi_test and #test_indices > 1 then
test_input = test_input:gsub('^1\n', '')
local max_time_actual = 0
local max_time_limit = 0
local max_mem_actual = 0
local max_mem_limit = 0
for _, idx in ipairs(test_indices) do
local tc = test_state.test_cases[idx]
max_time_actual = math.max(max_time_actual, #string.format('%.2f', tc.time_ms or 0))
max_time_limit = math.max(
max_time_limit,
#tostring(test_state.constraints and test_state.constraints.timeout_ms or 0)
)
max_mem_actual = math.max(max_mem_actual, #string.format('%.0f', tc.rss_mb or 0))
max_mem_limit = math.max(
max_mem_limit,
#string.format('%.0f', test_state.constraints and test_state.constraints.memory_mb or 0)
)
end
for _, line in ipairs(vim.split(test_input, '\n')) do
table.insert(input_lines, line)
local all_outputs = {}
for _, idx in ipairs(test_indices) do
local tc = test_state.test_cases[idx]
for _, line in ipairs(vim.split(tc.input, '\n')) do
table.insert(input_lines, line)
end
if tc.actual then
table.insert(all_outputs, tc.actual)
end
end
local combined_output = table.concat(all_outputs, '')
if combined_output ~= '' then
for _, line in ipairs(vim.split(combined_output, '\n')) do
table.insert(output_lines, line)
end
end
for _, idx in ipairs(test_indices) do
local tc = test_state.test_cases[idx]
local status = run_render.get_status_info(tc)
---@type VerdictFormatData
local format_data = {
index = idx,
status = status,
time_ms = tc.time_ms or 0,
time_limit_ms = test_state.constraints and test_state.constraints.timeout_ms or 0,
memory_mb = tc.rss_mb or 0,
memory_limit_mb = test_state.constraints and test_state.constraints.memory_mb or 0,
exit_code = tc.code or 0,
signal = (tc.code and tc.code >= 128) and require('cp.constants').signal_codes[tc.code]
or nil,
time_actual_width = max_time_actual,
time_limit_width = max_time_limit,
mem_actual_width = max_mem_actual,
mem_limit_width = max_mem_limit,
}
local result = formatter(format_data)
table.insert(verdict_lines, result.line)
if result.highlights then
for _, hl in ipairs(result.highlights) do
table.insert(verdict_highlights, {
line_offset = #verdict_lines - 1,
col_start = hl.col_start,
col_end = hl.col_end,
group = hl.group,
})
end
end
end
end

View file

@ -16,6 +16,7 @@ from urllib3.util.retry import Retry
from .base import BaseScraper
from .models import (
CombinedTest,
ContestListResult,
ContestSummary,
MetadataResult,
@ -364,6 +365,7 @@ async def main_async() -> int:
success=False,
error="Usage: atcoder.py tests <contest_id>",
problem_id="",
combined=CombinedTest(input="", expected=""),
tests=[],
timeout_ms=0,
memory_mb=0,

View file

@ -34,10 +34,13 @@ class BaseScraper(ABC):
def _create_tests_error(
self, error_msg: str, problem_id: str = "", url: str = ""
) -> TestsResult:
from .models import CombinedTest
return TestsResult(
success=False,
error=f"{self.platform_name}: {error_msg}",
problem_id=problem_id,
combined=CombinedTest(input="", expected=""),
tests=[],
timeout_ms=0,
memory_mb=0,

View file

@ -11,6 +11,7 @@ from scrapling.fetchers import StealthyFetcher
from .base import BaseScraper
from .models import (
CombinedTest,
ContestListResult,
ContestSummary,
MetadataResult,
@ -279,6 +280,7 @@ async def main_async() -> int:
success=False,
error="Usage: codechef.py tests <contest_id>",
problem_id="",
combined=CombinedTest(input="", expected=""),
tests=[],
timeout_ms=0,
memory_mb=0,

View file

@ -13,6 +13,7 @@ from scrapling.fetchers import StealthyFetcher
from .base import BaseScraper
from .models import (
CombinedTest,
ContestListResult,
ContestSummary,
MetadataResult,
@ -126,16 +127,12 @@ def _extract_samples(block: Tag) -> tuple[list[TestCase], bool]:
)
for k in keys
]
samples_with_prefix = [
TestCase(input=f"1\n{tc.input}", expected=tc.expected) for tc in samples
]
return samples_with_prefix, True
return samples, True
inputs = [_text_from_pre(p) for p in input_pres]
outputs = [_text_from_pre(p) for p in output_pres]
n = min(len(inputs), len(outputs))
samples = [TestCase(input=inputs[i], expected=outputs[i]) for i in range(n)]
return samples, False
return [TestCase(input=inputs[i], expected=outputs[i]) for i in range(n)], False
def _is_interactive(block: Tag) -> bool:
@ -164,18 +161,35 @@ def _parse_all_blocks(html: str) -> list[dict[str, Any]]:
name = _extract_title(b)[1]
if not letter:
continue
tests, multi_test = _extract_samples(b)
raw_samples, is_grouped = _extract_samples(b)
timeout_ms, memory_mb = _extract_limits(b)
interactive = _is_interactive(b)
if is_grouped and raw_samples:
combined_input = f"{len(raw_samples)}\n" + "\n".join(
tc.input for tc in raw_samples
)
combined_expected = "\n".join(tc.expected for tc in raw_samples)
individual_tests = [
TestCase(input=f"1\n{tc.input}", expected=tc.expected)
for tc in raw_samples
]
else:
combined_input = "\n".join(tc.input for tc in raw_samples)
combined_expected = "\n".join(tc.expected for tc in raw_samples)
individual_tests = raw_samples
out.append(
{
"letter": letter,
"name": name,
"tests": tests,
"combined_input": combined_input,
"combined_expected": combined_expected,
"tests": individual_tests,
"timeout_ms": timeout_ms,
"memory_mb": memory_mb,
"interactive": interactive,
"multi_test": multi_test,
"multi_test": is_grouped,
}
)
return out
@ -252,6 +266,10 @@ class CodeforcesScraper(BaseScraper):
json.dumps(
{
"problem_id": pid,
"combined": {
"input": b.get("combined_input", ""),
"expected": b.get("combined_expected", ""),
},
"tests": [
{"input": t.input, "expected": t.expected} for t in tests
],
@ -298,6 +316,7 @@ async def main_async() -> int:
success=False,
error="Usage: codeforces.py tests <contest_id>",
problem_id="",
combined=CombinedTest(input="", expected=""),
tests=[],
timeout_ms=0,
memory_mb=0,

View file

@ -10,6 +10,7 @@ import httpx
from .base import BaseScraper
from .models import (
CombinedTest,
ContestListResult,
ContestSummary,
MetadataResult,
@ -233,8 +234,16 @@ class CSESScraper(BaseScraper):
except Exception:
tests = []
timeout_ms, memory_mb, interactive = 0, 0, False
combined_input = "\n".join(t.input for t in tests)
combined_expected = "\n".join(t.expected for t in tests)
return {
"problem_id": pid,
"combined": {
"input": combined_input,
"expected": combined_expected,
},
"tests": [
{"input": t.input, "expected": t.expected} for t in tests
],
@ -282,6 +291,7 @@ async def main_async() -> int:
success=False,
error="Usage: cses.py tests <category>",
problem_id="",
combined=CombinedTest(input="", expected=""),
tests=[],
timeout_ms=0,
memory_mb=0,

View file

@ -8,6 +8,13 @@ class TestCase(BaseModel):
model_config = ConfigDict(extra="forbid")
class CombinedTest(BaseModel):
input: str
expected: str
model_config = ConfigDict(extra="forbid")
class ProblemSummary(BaseModel):
id: str
name: str
@ -46,6 +53,7 @@ class ContestListResult(ScrapingResult):
class TestsResult(ScrapingResult):
problem_id: str
combined: CombinedTest
tests: list[TestCase] = Field(default_factory=list)
timeout_ms: int
memory_mb: float