diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 33342c2..da0f667 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -27,6 +27,7 @@ ---@field multi_test? boolean ---@field memory_mb? number ---@field timeout_ms? number +---@field epsilon? number ---@field combined_test? CombinedTest ---@field test_cases TestCase[] @@ -273,6 +274,34 @@ function M.get_constraints(platform, contest_id, problem_id) return problem_data.timeout_ms, problem_data.memory_mb end +---@param platform string +---@param contest_id string +---@param problem_id? string +---@return number? +function M.get_epsilon(platform, contest_id, problem_id) + vim.validate({ + platform = { platform, 'string' }, + contest_id = { contest_id, 'string' }, + problem_id = { problem_id, { 'string', 'nil' }, true }, + }) + + if + not cache_data[platform] + or not cache_data[platform][contest_id] + or not cache_data[platform][contest_id].index_map + then + return nil + end + + local index = cache_data[platform][contest_id].index_map[problem_id] + if not index then + return nil + end + + local problem_data = cache_data[platform][contest_id].problems[index] + return problem_data and problem_data.epsilon or nil +end + ---@param file_path string ---@return FileState|nil function M.get_file_state(file_path) diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 6cf43d9..4304a31 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -21,6 +21,7 @@ ---@class PanelConfig ---@field diff_modes string[] ---@field max_output_lines integer +---@field epsilon number? ---@class DiffGitConfig ---@field args string[] @@ -174,7 +175,7 @@ M.defaults = { add_test_key = 'ga', save_and_exit_key = 'q', }, - panel = { diff_modes = { 'side-by-side', 'git', 'vim' }, max_output_lines = 50 }, + panel = { diff_modes = { 'side-by-side', 'git', 'vim' }, max_output_lines = 50, epsilon = nil }, diff = { git = { args = { 'diff', '--no-index', '--word-diff=plain', '--word-diff-regex=.', '--no-prefix' }, @@ -368,6 +369,13 @@ function M.setup(user_config) end, 'positive integer', }, + epsilon = { + cfg.ui.panel.epsilon, + function(v) + return v == nil or (type(v) == 'number' and v >= 0) + end, + 'nil or non-negative number', + }, git = { cfg.ui.diff.git, { 'table' } }, git_args = { cfg.ui.diff.git.args, is_string_list, 'string[]' }, width = { diff --git a/lua/cp/runner/run.lua b/lua/cp/runner/run.lua index 4e4a8f6..97ae7ad 100644 --- a/lua/cp/runner/run.lua +++ b/lua/cp/runner/run.lua @@ -19,6 +19,7 @@ ---@class ProblemConstraints ---@field timeout_ms number ---@field memory_mb number +---@field epsilon number? ---@class PanelState ---@field test_cases RanTestCase[] @@ -56,7 +57,8 @@ local function load_constraints_from_cache(platform, contest_id, problem_id) cache.load() local timeout_ms, memory_mb = cache.get_constraints(platform, contest_id, problem_id) if timeout_ms and memory_mb then - return { timeout_ms = timeout_ms, memory_mb = memory_mb } + local epsilon = cache.get_epsilon(platform, contest_id, problem_id) + return { timeout_ms = timeout_ms, memory_mb = memory_mb, epsilon = epsilon } end return nil end @@ -99,6 +101,49 @@ local function build_command(cmd, substitutions) return execute.build_command(cmd, substitutions) end +local function compare_outputs(actual, expected, epsilon) + local norm_actual = normalize_lines(actual) + local norm_expected = normalize_lines(expected) + + if epsilon == nil or epsilon == 0 then + return norm_actual == norm_expected + end + + local actual_lines = vim.split(norm_actual, '\n', { plain = true }) + local expected_lines = vim.split(norm_expected, '\n', { plain = true }) + + if #actual_lines ~= #expected_lines then + return false + end + + for i = 1, #actual_lines do + local a_tokens = vim.split(actual_lines[i], '%s+', { plain = false, trimempty = true }) + local e_tokens = vim.split(expected_lines[i], '%s+', { plain = false, trimempty = true }) + + if #a_tokens ~= #e_tokens then + return false + end + + for j = 1, #a_tokens do + local a_tok, e_tok = a_tokens[j], e_tokens[j] + local a_num = tonumber(a_tok) + local e_num = tonumber(e_tok) + + if a_num ~= nil and e_num ~= nil then + if math.abs(a_num - e_num) > epsilon then + return false + end + else + if a_tok ~= e_tok then + return false + end + end + end + end + + return true +end + ---@param test_case RanTestCase ---@param debug boolean? ---@param on_complete fun(result: { status: "pass"|"fail"|"tle"|"mle", actual: string, actual_highlights: Highlight[], error: string, stderr: string, time_ms: number, code: integer, ok: boolean, signal: string?, tled: boolean, mled: boolean, rss_mb: number }) @@ -143,7 +188,9 @@ local function run_single_test_case(test_case, debug, on_complete) end local expected = test_case.expected or '' - local ok = normalize_lines(out) == normalize_lines(expected) + local epsilon = (panel_state.constraints and panel_state.constraints.epsilon) + or config.ui.panel.epsilon + local ok = compare_outputs(out, expected, epsilon) local signal = r.signal if not signal and r.code and r.code >= 128 then