feat: add epsilon tolerance for floating-point output comparison
Problem: output comparison used exact string equality after whitespace normalisation, causing correct solutions to fail on problems where floating-point answers are accepted within a tolerance (e.g. 1e-6). Solution: add an optional ui.panel.epsilon config value. When set, actual and expected output are compared token-by-token: numeric tokens are compared with math.abs(a - b) <= epsilon, non-numeric tokens fall back to exact string equality. Per-problem epsilon can also be stored in the cache and takes precedence over the global default.
This commit is contained in:
parent
84d12758c2
commit
e685a8089f
3 changed files with 87 additions and 3 deletions
|
|
@ -27,6 +27,7 @@
|
||||||
---@field multi_test? boolean
|
---@field multi_test? boolean
|
||||||
---@field memory_mb? number
|
---@field memory_mb? number
|
||||||
---@field timeout_ms? number
|
---@field timeout_ms? number
|
||||||
|
---@field epsilon? number
|
||||||
---@field combined_test? CombinedTest
|
---@field combined_test? CombinedTest
|
||||||
---@field test_cases TestCase[]
|
---@field test_cases TestCase[]
|
||||||
|
|
||||||
|
|
@ -273,6 +274,34 @@ function M.get_constraints(platform, contest_id, problem_id)
|
||||||
return problem_data.timeout_ms, problem_data.memory_mb
|
return problem_data.timeout_ms, problem_data.memory_mb
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@param platform string
|
||||||
|
---@param contest_id string
|
||||||
|
---@param problem_id? string
|
||||||
|
---@return number?
|
||||||
|
function M.get_epsilon(platform, contest_id, problem_id)
|
||||||
|
vim.validate({
|
||||||
|
platform = { platform, 'string' },
|
||||||
|
contest_id = { contest_id, 'string' },
|
||||||
|
problem_id = { problem_id, { 'string', 'nil' }, true },
|
||||||
|
})
|
||||||
|
|
||||||
|
if
|
||||||
|
not cache_data[platform]
|
||||||
|
or not cache_data[platform][contest_id]
|
||||||
|
or not cache_data[platform][contest_id].index_map
|
||||||
|
then
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
|
local index = cache_data[platform][contest_id].index_map[problem_id]
|
||||||
|
if not index then
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
|
local problem_data = cache_data[platform][contest_id].problems[index]
|
||||||
|
return problem_data and problem_data.epsilon or nil
|
||||||
|
end
|
||||||
|
|
||||||
---@param file_path string
|
---@param file_path string
|
||||||
---@return FileState|nil
|
---@return FileState|nil
|
||||||
function M.get_file_state(file_path)
|
function M.get_file_state(file_path)
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@
|
||||||
---@class PanelConfig
|
---@class PanelConfig
|
||||||
---@field diff_modes string[]
|
---@field diff_modes string[]
|
||||||
---@field max_output_lines integer
|
---@field max_output_lines integer
|
||||||
|
---@field epsilon number?
|
||||||
|
|
||||||
---@class DiffGitConfig
|
---@class DiffGitConfig
|
||||||
---@field args string[]
|
---@field args string[]
|
||||||
|
|
@ -174,7 +175,7 @@ M.defaults = {
|
||||||
add_test_key = 'ga',
|
add_test_key = 'ga',
|
||||||
save_and_exit_key = 'q',
|
save_and_exit_key = 'q',
|
||||||
},
|
},
|
||||||
panel = { diff_modes = { 'side-by-side', 'git', 'vim' }, max_output_lines = 50 },
|
panel = { diff_modes = { 'side-by-side', 'git', 'vim' }, max_output_lines = 50, epsilon = nil },
|
||||||
diff = {
|
diff = {
|
||||||
git = {
|
git = {
|
||||||
args = { 'diff', '--no-index', '--word-diff=plain', '--word-diff-regex=.', '--no-prefix' },
|
args = { 'diff', '--no-index', '--word-diff=plain', '--word-diff-regex=.', '--no-prefix' },
|
||||||
|
|
@ -368,6 +369,13 @@ function M.setup(user_config)
|
||||||
end,
|
end,
|
||||||
'positive integer',
|
'positive integer',
|
||||||
},
|
},
|
||||||
|
epsilon = {
|
||||||
|
cfg.ui.panel.epsilon,
|
||||||
|
function(v)
|
||||||
|
return v == nil or (type(v) == 'number' and v >= 0)
|
||||||
|
end,
|
||||||
|
'nil or non-negative number',
|
||||||
|
},
|
||||||
git = { cfg.ui.diff.git, { 'table' } },
|
git = { cfg.ui.diff.git, { 'table' } },
|
||||||
git_args = { cfg.ui.diff.git.args, is_string_list, 'string[]' },
|
git_args = { cfg.ui.diff.git.args, is_string_list, 'string[]' },
|
||||||
width = {
|
width = {
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
---@class ProblemConstraints
|
---@class ProblemConstraints
|
||||||
---@field timeout_ms number
|
---@field timeout_ms number
|
||||||
---@field memory_mb number
|
---@field memory_mb number
|
||||||
|
---@field epsilon number?
|
||||||
|
|
||||||
---@class PanelState
|
---@class PanelState
|
||||||
---@field test_cases RanTestCase[]
|
---@field test_cases RanTestCase[]
|
||||||
|
|
@ -56,7 +57,8 @@ local function load_constraints_from_cache(platform, contest_id, problem_id)
|
||||||
cache.load()
|
cache.load()
|
||||||
local timeout_ms, memory_mb = cache.get_constraints(platform, contest_id, problem_id)
|
local timeout_ms, memory_mb = cache.get_constraints(platform, contest_id, problem_id)
|
||||||
if timeout_ms and memory_mb then
|
if timeout_ms and memory_mb then
|
||||||
return { timeout_ms = timeout_ms, memory_mb = memory_mb }
|
local epsilon = cache.get_epsilon(platform, contest_id, problem_id)
|
||||||
|
return { timeout_ms = timeout_ms, memory_mb = memory_mb, epsilon = epsilon }
|
||||||
end
|
end
|
||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
@ -99,6 +101,49 @@ local function build_command(cmd, substitutions)
|
||||||
return execute.build_command(cmd, substitutions)
|
return execute.build_command(cmd, substitutions)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function compare_outputs(actual, expected, epsilon)
|
||||||
|
local norm_actual = normalize_lines(actual)
|
||||||
|
local norm_expected = normalize_lines(expected)
|
||||||
|
|
||||||
|
if epsilon == nil or epsilon == 0 then
|
||||||
|
return norm_actual == norm_expected
|
||||||
|
end
|
||||||
|
|
||||||
|
local actual_lines = vim.split(norm_actual, '\n', { plain = true })
|
||||||
|
local expected_lines = vim.split(norm_expected, '\n', { plain = true })
|
||||||
|
|
||||||
|
if #actual_lines ~= #expected_lines then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
for i = 1, #actual_lines do
|
||||||
|
local a_tokens = vim.split(actual_lines[i], '%s+', { plain = false, trimempty = true })
|
||||||
|
local e_tokens = vim.split(expected_lines[i], '%s+', { plain = false, trimempty = true })
|
||||||
|
|
||||||
|
if #a_tokens ~= #e_tokens then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
for j = 1, #a_tokens do
|
||||||
|
local a_tok, e_tok = a_tokens[j], e_tokens[j]
|
||||||
|
local a_num = tonumber(a_tok)
|
||||||
|
local e_num = tonumber(e_tok)
|
||||||
|
|
||||||
|
if a_num ~= nil and e_num ~= nil then
|
||||||
|
if math.abs(a_num - e_num) > epsilon then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
else
|
||||||
|
if a_tok ~= e_tok then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
|
||||||
---@param test_case RanTestCase
|
---@param test_case RanTestCase
|
||||||
---@param debug boolean?
|
---@param debug boolean?
|
||||||
---@param on_complete fun(result: { status: "pass"|"fail"|"tle"|"mle", actual: string, actual_highlights: Highlight[], error: string, stderr: string, time_ms: number, code: integer, ok: boolean, signal: string?, tled: boolean, mled: boolean, rss_mb: number })
|
---@param on_complete fun(result: { status: "pass"|"fail"|"tle"|"mle", actual: string, actual_highlights: Highlight[], error: string, stderr: string, time_ms: number, code: integer, ok: boolean, signal: string?, tled: boolean, mled: boolean, rss_mb: number })
|
||||||
|
|
@ -143,7 +188,9 @@ local function run_single_test_case(test_case, debug, on_complete)
|
||||||
end
|
end
|
||||||
|
|
||||||
local expected = test_case.expected or ''
|
local expected = test_case.expected or ''
|
||||||
local ok = normalize_lines(out) == normalize_lines(expected)
|
local epsilon = (panel_state.constraints and panel_state.constraints.epsilon)
|
||||||
|
or config.ui.panel.epsilon
|
||||||
|
local ok = compare_outputs(out, expected, epsilon)
|
||||||
|
|
||||||
local signal = r.signal
|
local signal = r.signal
|
||||||
if not signal and r.code and r.code >= 128 then
|
if not signal and r.code and r.code >= 128 then
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue