From 9e84d57b8ae92657929bded67a96067e5402ff08 Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Wed, 24 Sep 2025 18:21:34 -0400 Subject: [PATCH] feat: context, not config --- doc/cp.txt | 27 ++---- lua/cp/config.lua | 6 +- lua/cp/log.lua | 6 -- lua/cp/pickers/init.lua | 31 +++---- lua/cp/problem.lua | 68 --------------- lua/cp/runner/execute.lua | 66 +++++++++------ lua/cp/runner/run.lua | 41 +++++---- lua/cp/setup.lua | 26 +++--- lua/cp/state.lua | 79 ++++++++++++++++++ lua/cp/ui/panel.lua | 15 ++-- scrapers/atcoder.py | 5 +- spec/error_boundaries_spec.lua | 3 - spec/panel_spec.lua | 10 ++- spec/problem_spec.lua | 146 --------------------------------- spec/spec_helper.lua | 8 +- 15 files changed, 209 insertions(+), 328 deletions(-) delete mode 100644 lua/cp/problem.lua delete mode 100644 spec/problem_spec.lua diff --git a/doc/cp.txt b/doc/cp.txt index e4fc5a7..be761d6 100644 --- a/doc/cp.txt +++ b/doc/cp.txt @@ -236,32 +236,21 @@ Here's an example configuration with lazy.nvim: >lua *cp.Hooks* Fields: ~ {before_run} (function, optional) Called before test panel opens. - function(ctx: ProblemContext) + function(state: cp.State) {before_debug} (function, optional) Called before debug compilation. - function(ctx: ProblemContext) + function(state: cp.State) {setup_code} (function, optional) Called after source file is opened. Good for configuring buffer settings. - function(ctx: ProblemContext) + function(state: cp.State) - *ProblemContext* - Context object passed to hook functions containing problem information. - - Fields: ~ - {contest} (string) Platform name (e.g. "atcoder", "codeforces") - {contest_id} (string) Contest ID (e.g. "abc123", "1933") - {problem_id} (string, optional) Problem ID (e.g. "a", "b") - nil for CSES - {source_file} (string) Source filename (e.g. "abc123a.cpp") - {binary_file} (string) Binary output path (e.g. "build/abc123a.run") - {input_file} (string) Test input path (e.g. "io/abc123a.cpin") - {output_file} (string) Program output path (e.g. "io/abc123a.cpout") - {expected_file} (string) Expected output path (e.g. "io/abc123a.expected") - {problem_name} (string) Display name (e.g. "abc123a") + Hook functions receive the cp.nvim state object (cp.State). See the state + module documentation for available methods and fields. Example usage in hook: >lua hooks = { - setup_code = function(ctx) - print("Setting up " .. ctx.problem_name) - print("Source file: " .. ctx.source_file) + setup_code = function(state) + print("Setting up " .. state.get_base_name()) + print("Source file: " .. state.get_source_file()) end } < diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 411d002..9935413 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -25,9 +25,9 @@ ---@field default_language? string ---@class Hooks ----@field before_run? fun(ctx: ProblemContext) ----@field before_debug? fun(ctx: ProblemContext) ----@field setup_code? fun(ctx: ProblemContext) +---@field before_run? fun(state: cp.State) +---@field before_debug? fun(state: cp.State) +---@field setup_code? fun(state: cp.State) ---@class RunPanelConfig ---@field ansi boolean Enable ANSI color parsing and highlighting diff --git a/lua/cp/log.lua b/lua/cp/log.lua index 6a05316..9c702b4 100644 --- a/lua/cp/log.lua +++ b/lua/cp/log.lua @@ -9,10 +9,4 @@ function M.log(msg, level, override) end end -function M.progress(msg) - vim.schedule(function() - vim.notify(('[cp.nvim]: %s'):format(msg), vim.log.levels.INFO) - end) -end - return M diff --git a/lua/cp/pickers/init.lua b/lua/cp/pickers/init.lua index f8cac85..77c0685 100644 --- a/lua/cp/pickers/init.lua +++ b/lua/cp/pickers/init.lua @@ -34,16 +34,17 @@ end ---@param platform string Platform identifier (e.g. "codeforces", "atcoder") ---@return cp.ContestItem[] local function get_contests_for_platform(platform) + local constants = require('cp.constants') + local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform + + logger.log(('loading %s contests...'):format(platform_display_name), vim.log.levels.INFO, true) + cache.load() local cached_contests = cache.get_contest_list(platform) if cached_contests then return cached_contests end - local constants = require('cp.constants') - local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform - logger.progress(('loading %s contests...'):format(platform_display_name)) - if not utils.setup_python_env() then return {} end @@ -59,8 +60,6 @@ local function get_contests_for_platform(platform) 'contests', } - logger.progress(('running: %s'):format(table.concat(cmd, ' '))) - local result = vim .system(cmd, { cwd = plugin_path, @@ -69,9 +68,9 @@ local function get_contests_for_platform(platform) }) :wait() - logger.progress(('exit code: %d, stdout length: %d'):format(result.code, #(result.stdout or ''))) + logger.log(('exit code: %d, stdout length: %d'):format(result.code, #(result.stdout or ''))) if result.stderr and #result.stderr > 0 then - logger.progress(('stderr: %s'):format(result.stderr:sub(1, 200))) + logger.log(('stderr: %s'):format(result.stderr:sub(1, 200))) end if result.code ~= 0 then @@ -82,7 +81,7 @@ local function get_contests_for_platform(platform) return {} end - logger.progress(('stdout preview: %s'):format(result.stdout:sub(1, 100))) + logger.log(('stdout preview: %s'):format(result.stdout:sub(1, 100))) local ok, data = pcall(vim.json.decode, result.stdout) if not ok then @@ -107,7 +106,7 @@ local function get_contests_for_platform(platform) end cache.set_contest_list(platform, contests) - logger.progress(('loaded %d contests'):format(#contests)) + logger.log(('loaded %d contests'):format(#contests)) return contests end @@ -115,6 +114,8 @@ end ---@param contest_id string Contest identifier ---@return cp.ProblemItem[] local function get_problems_for_contest(platform, contest_id) + local constants = require('cp.constants') + local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform local problems = {} cache.load() @@ -130,14 +131,16 @@ local function get_problems_for_contest(platform, contest_id) return problems end + logger.log( + ('loading %s %s problems...'):format(platform_display_name, contest_id), + vim.log.levels.INFO, + true + ) + if not utils.setup_python_env() then return problems end - local constants = require('cp.constants') - local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform - logger.progress(('loading %s %s problems...'):format(platform_display_name, contest_id)) - local plugin_path = utils.get_plugin_path() local cmd = { 'uv', diff --git a/lua/cp/problem.lua b/lua/cp/problem.lua deleted file mode 100644 index bf5a56d..0000000 --- a/lua/cp/problem.lua +++ /dev/null @@ -1,68 +0,0 @@ ----@class ProblemContext ----@field contest string Contest name (e.g. "atcoder", "codeforces") ----@field contest_id string Contest ID (e.g. "abc123", "1933") ----@field problem_id? string Problem ID for AtCoder/Codeforces (e.g. "a", "b") ----@field source_file string Source filename (e.g. "abc123a.cpp") ----@field binary_file string Binary output path (e.g. "build/abc123a.run") ----@field input_file string Input test file path (e.g. "io/abc123a.in") ----@field output_file string Output file path (e.g. "io/abc123a.out") ----@field expected_file string Expected output path (e.g. "io/abc123a.expected") ----@field problem_name string Canonical problem identifier (e.g. "abc123a") - -local M = {} - ----@param contest string ----@param contest_id string ----@param problem_id? string ----@param config cp.Config ----@param language? string ----@return ProblemContext -function M.create_context(contest, contest_id, problem_id, config, language) - vim.validate({ - contest = { contest, 'string' }, - contest_id = { contest_id, 'string' }, - problem_id = { problem_id, { 'string', 'nil' }, true }, - config = { config, 'table' }, - language = { language, { 'string', 'nil' }, true }, - }) - - local contest_config = config.contests[contest] - if not contest_config then - error(("No contest config found for '%s'"):format(contest)) - end - - local target_language = language or contest_config.default_language - local language_config = contest_config[target_language] - if not language_config then - error(("No language config found for '%s' in contest '%s'"):format(target_language, contest)) - end - if not language_config.extension then - error( - ("No extension configured for language '%s' in contest '%s'"):format(target_language, contest) - ) - end - - local base_name - if config.filename then - base_name = config.filename(contest, contest_id, problem_id, config, language) - else - local default_filename = require('cp.config').default_filename - base_name = default_filename(contest_id, problem_id) - end - - local source_file = base_name .. '.' .. language_config.extension - - return { - contest = contest, - contest_id = contest_id, - problem_id = problem_id, - source_file = source_file, - binary_file = ('build/%s.run'):format(base_name), - input_file = ('io/%s.cpin'):format(base_name), - output_file = ('io/%s.cpout'):format(base_name), - expected_file = ('io/%s.expected'):format(base_name), - problem_name = base_name, - } -end - -return M diff --git a/lua/cp/runner/execute.lua b/lua/cp/runner/execute.lua index 62c0f99..e4bb416 100644 --- a/lua/cp/runner/execute.lua +++ b/lua/cp/runner/execute.lua @@ -203,17 +203,22 @@ local function format_output(exec_result, expected_file, is_debug) return table.concat(output_lines, '') .. '\n' .. table.concat(metadata_lines, '\n') end ----@param ctx ProblemContext ---@param contest_config ContestConfig ---@param is_debug? boolean ---@return {success: boolean, output: string?} -function M.compile_problem(ctx, contest_config, is_debug) +function M.compile_problem(contest_config, is_debug) vim.validate({ - ctx = { ctx, 'table' }, contest_config = { contest_config, 'table' }, }) - local language = get_language_from_file(ctx.source_file, contest_config) + local state = require('cp.state') + local source_file = state.get_source_file() + if not source_file then + logger.log('No source file found', vim.log.levels.ERROR) + return { success = false, output = 'No source file found' } + end + + local language = get_language_from_file(source_file, contest_config) local language_config = contest_config[language] if not language_config then @@ -221,9 +226,10 @@ function M.compile_problem(ctx, contest_config, is_debug) return { success = false, output = 'No configuration for language: ' .. language } end + local binary_file = state.get_binary_file() local substitutions = { - source = ctx.source_file, - binary = ctx.binary_file, + source = source_file, + binary = binary_file, version = tostring(language_config.version), } @@ -244,26 +250,35 @@ function M.compile_problem(ctx, contest_config, is_debug) return { success = true, output = nil } end -function M.run_problem(ctx, contest_config, is_debug) +function M.run_problem(contest_config, is_debug) vim.validate({ - ctx = { ctx, 'table' }, contest_config = { contest_config, 'table' }, is_debug = { is_debug, 'boolean' }, }) - vim.system({ 'mkdir', '-p', 'build', 'io' }):wait() + local state = require('cp.state') + local source_file = state.get_source_file() + local output_file = state.get_output_file() - local language = get_language_from_file(ctx.source_file, contest_config) - local language_config = contest_config[language] - - if not language_config then - vim.fn.writefile({ 'Error: No configuration for language: ' .. language }, ctx.output_file) + if not source_file or not output_file then + logger.log('Missing required file paths', vim.log.levels.ERROR) return end + vim.system({ 'mkdir', '-p', 'build', 'io' }):wait() + + local language = get_language_from_file(source_file, contest_config) + local language_config = contest_config[language] + + if not language_config then + vim.fn.writefile({ 'Error: No configuration for language: ' .. language }, output_file) + return + end + + local binary_file = state.get_binary_file() local substitutions = { - source = ctx.source_file, - binary = ctx.binary_file, + source = source_file, + binary = binary_file, version = tostring(language_config.version), } @@ -271,26 +286,31 @@ function M.run_problem(ctx, contest_config, is_debug) if compile_cmd then local compile_result = M.compile_generic(language_config, substitutions) if compile_result.code ~= 0 then - vim.fn.writefile({ compile_result.stderr }, ctx.output_file) + vim.fn.writefile({ compile_result.stderr }, output_file) return end end + local input_file = state.get_input_file() local input_data = '' - if vim.fn.filereadable(ctx.input_file) == 1 then - input_data = table.concat(vim.fn.readfile(ctx.input_file), '\n') .. '\n' + if input_file and vim.fn.filereadable(input_file) == 1 then + input_data = table.concat(vim.fn.readfile(input_file), '\n') .. '\n' end local cache = require('cp.cache') cache.load() - local timeout_ms, _ = cache.get_constraints(ctx.contest, ctx.contest_id, ctx.problem_id) + local platform = state.get_platform() + local contest_id = state.get_contest_id() + local problem_id = state.get_problem_id() + local timeout_ms, _ = cache.get_constraints(platform, contest_id, problem_id) timeout_ms = timeout_ms or 2000 local run_cmd = build_command(language_config.test, language_config.executable, substitutions) local exec_result = execute_command(run_cmd, input_data, timeout_ms) - local formatted_output = format_output(exec_result, ctx.expected_file, is_debug) + local expected_file = state.get_expected_file() + local formatted_output = format_output(exec_result, expected_file, is_debug) - local output_buf = vim.fn.bufnr(ctx.output_file) + local output_buf = vim.fn.bufnr(output_file) if output_buf ~= -1 then local was_modifiable = vim.api.nvim_get_option_value('modifiable', { buf = output_buf }) local was_readonly = vim.api.nvim_get_option_value('readonly', { buf = output_buf }) @@ -303,7 +323,7 @@ function M.run_problem(ctx, contest_config, is_debug) vim.cmd.write() end) else - vim.fn.writefile(vim.split(formatted_output, '\n'), ctx.output_file) + vim.fn.writefile(vim.split(formatted_output, '\n'), output_file) end end diff --git a/lua/cp/runner/run.lua b/lua/cp/runner/run.lua index bff8a0f..cb1bce2 100644 --- a/lua/cp/runner/run.lua +++ b/lua/cp/runner/run.lua @@ -130,12 +130,22 @@ local function load_constraints_from_cache(platform, contest_id, problem_id) return nil end ----@param ctx ProblemContext ---@param contest_config ContestConfig ---@param test_case TestCase ---@return table -local function run_single_test_case(ctx, contest_config, cp_config, test_case) - local language = vim.fn.fnamemodify(ctx.source_file, ':e') +local function run_single_test_case(contest_config, cp_config, test_case) + local state = require('cp.state') + local source_file = state.get_source_file() + if not source_file then + return { + status = 'fail', + actual = '', + error = 'No source file found', + time_ms = 0, + } + end + + local language = vim.fn.fnamemodify(source_file, ':e') local language_name = constants.filetype_to_language[language] or contest_config.default_language local language_config = contest_config[language_name] @@ -168,13 +178,14 @@ local function run_single_test_case(ctx, contest_config, cp_config, test_case) return cmd end + local binary_file = state.get_binary_file() local substitutions = { - source = ctx.source_file, - binary = ctx.binary_file, + source = source_file, + binary = binary_file, version = tostring(language_config.version or ''), } - if language_config.compile and vim.fn.filereadable(ctx.binary_file) == 0 then + if language_config.compile and vim.fn.filereadable(binary_file) == 0 then logger.log('binary not found, compiling first...') local compile_cmd = substitute_template(language_config.compile, substitutions) local redirected_cmd = vim.deepcopy(compile_cmd) @@ -282,10 +293,9 @@ local function run_single_test_case(ctx, contest_config, cp_config, test_case) } end ----@param ctx ProblemContext ---@param state table ---@return boolean -function M.load_test_cases(ctx, state) +function M.load_test_cases(state) local test_cases = parse_test_cases_from_cache( state.get_platform() or '', state.get_contest_id() or '', @@ -293,7 +303,9 @@ function M.load_test_cases(ctx, state) ) if #test_cases == 0 then - test_cases = parse_test_cases_from_files(ctx.input_file, ctx.expected_file) + local input_file = state.get_input_file() + local expected_file = state.get_expected_file() + test_cases = parse_test_cases_from_files(input_file, expected_file) end run_panel_state.test_cases = test_cases @@ -315,11 +327,10 @@ function M.load_test_cases(ctx, state) return #test_cases > 0 end ----@param ctx ProblemContext ---@param contest_config ContestConfig ---@param index number ---@return boolean -function M.run_test_case(ctx, contest_config, cp_config, index) +function M.run_test_case(contest_config, cp_config, index) local test_case = run_panel_state.test_cases[index] if not test_case then return false @@ -327,7 +338,7 @@ function M.run_test_case(ctx, contest_config, cp_config, index) test_case.status = 'running' - local result = run_single_test_case(ctx, contest_config, cp_config, test_case) + local result = run_single_test_case(contest_config, cp_config, test_case) test_case.status = result.status test_case.actual = result.actual @@ -343,13 +354,13 @@ function M.run_test_case(ctx, contest_config, cp_config, index) return true end ----@param ctx ProblemContext ---@param contest_config ContestConfig +---@param cp_config cp.Config ---@return TestCase[] -function M.run_all_test_cases(ctx, contest_config, cp_config) +function M.run_all_test_cases(contest_config, cp_config) local results = {} for i, _ in ipairs(run_panel_state.test_cases) do - M.run_test_case(ctx, contest_config, cp_config, i) + M.run_test_case(contest_config, cp_config, i) table.insert(results, run_panel_state.test_cases[i]) end return results diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index 4992543..48ebed8 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -3,7 +3,6 @@ local M = {} local cache = require('cp.cache') local config_module = require('cp.config') local logger = require('cp.log') -local problem = require('cp.problem') local scraper = require('cp.scraper') local state = require('cp.state') @@ -42,7 +41,7 @@ function M.setup_contest(platform, contest_id, problem_id, language) return end - logger.progress(('fetching contest %s %s...'):format(platform, contest_id)) + logger.log(('fetching contest %s %s...'):format(platform, contest_id)) scraper.scrape_contest_metadata(platform, contest_id, function(result) if not result.success then @@ -59,7 +58,7 @@ function M.setup_contest(platform, contest_id, problem_id, language) return end - logger.progress(('found %d problems'):format(#problems)) + logger.log(('found %d problems'):format(#problems)) state.set_contest_id(contest_id) local target_problem = problem_id or problems[1].id @@ -96,16 +95,17 @@ function M.setup_problem(contest_id, problem_id, language) local config = config_module.get_config() local platform = state.get_platform() or '' - logger.progress(('setting up problem %s%s...'):format(contest_id, problem_id or '')) + logger.log(('setting up problem %s%s...'):format(contest_id, problem_id or '')) - local ctx = problem.create_context(platform, contest_id, problem_id, config, language) + state.set_contest_id(contest_id) + state.set_problem_id(problem_id) local cached_tests = cache.get_test_cases(platform, contest_id, problem_id) if cached_tests then state.set_test_cases(cached_tests) logger.log(('using cached test cases (%d)'):format(#cached_tests)) elseif vim.tbl_contains(config.scrapers, platform) then - logger.progress('loading test cases...') + logger.log('loading test cases...') scraper.scrape_problem_tests(platform, contest_id, problem_id, function(result) if result.success then @@ -128,15 +128,17 @@ function M.setup_problem(contest_id, problem_id, language) state.set_test_cases({}) end - state.set_contest_id(contest_id) - state.set_problem_id(problem_id) state.set_run_panel_active(false) vim.schedule(function() local ok, err = pcall(function() vim.cmd.only({ mods = { silent = true } }) - vim.cmd.e(ctx.source_file) + local source_file = state.get_source_file(language) + if not source_file then + error('Failed to generate source file path') + end + vim.cmd.e(source_file) local source_buf = vim.api.nvim_get_current_buf() if vim.api.nvim_buf_get_lines(source_buf, 0, -1, true)[1] == '' then @@ -166,12 +168,12 @@ function M.setup_problem(contest_id, problem_id, language) end if config.hooks and config.hooks.setup_code then - config.hooks.setup_code(ctx) + config.hooks.setup_code(state) end cache.set_file_state(vim.fn.expand('%:p'), platform, contest_id, problem_id, language) - logger.progress(('ready - problem %s'):format(ctx.problem_name)) + logger.log(('ready - problem %s'):format(state.get_base_name())) end) if not ok then @@ -196,7 +198,7 @@ function M.scrape_remaining_problems(platform, contest_id, problems) return end - logger.progress(('caching %d remaining problems...'):format(#missing_problems)) + logger.log(('caching %d remaining problems...'):format(#missing_problems)) for _, prob in ipairs(missing_problems) do scraper.scrape_problem_tests(platform, contest_id, prob.id, function(result) diff --git a/lua/cp/state.lua b/lua/cp/state.lua index ae21fc5..d61a40a 100644 --- a/lua/cp/state.lua +++ b/lua/cp/state.lua @@ -1,3 +1,26 @@ +---@class cp.State +---@field get_platform fun(): string? +---@field set_platform fun(platform: string) +---@field get_contest_id fun(): string? +---@field set_contest_id fun(contest_id: string) +---@field get_problem_id fun(): string? +---@field set_problem_id fun(problem_id: string) +---@field get_test_cases fun(): table[]? +---@field set_test_cases fun(test_cases: table[]) +---@field is_run_panel_active fun(): boolean +---@field set_run_panel_active fun(active: boolean) +---@field get_saved_session fun(): table? +---@field set_saved_session fun(session: table) +---@field get_context fun(): {platform: string?, contest_id: string?, problem_id: string?} +---@field has_context fun(): boolean +---@field reset fun() +---@field get_base_name fun(): string? +---@field get_source_file fun(language?: string): string? +---@field get_binary_file fun(): string? +---@field get_input_file fun(): string? +---@field get_output_file fun(): string? +---@field get_expected_file fun(): string? + local M = {} local state = { @@ -65,6 +88,62 @@ function M.get_context() } end +function M.get_base_name() + if not state.contest_id then + return nil + end + + local config_module = require('cp.config') + local config = config_module.get_config() + + if config.filename then + return config.filename(state.platform or '', state.contest_id, state.problem_id, config) + else + return config_module.default_filename(state.contest_id, state.problem_id) + end +end + +function M.get_source_file(language) + local base_name = M.get_base_name() + if not base_name or not state.platform then + return nil + end + + local config = require('cp.config').get_config() + local contest_config = config.contests[state.platform] + if not contest_config then + return nil + end + + local target_language = language or contest_config.default_language + local language_config = contest_config[target_language] + if not language_config or not language_config.extension then + return nil + end + + return base_name .. '.' .. language_config.extension +end + +function M.get_binary_file() + local base_name = M.get_base_name() + return base_name and ('build/%s.run'):format(base_name) or nil +end + +function M.get_input_file() + local base_name = M.get_base_name() + return base_name and ('io/%s.cpin'):format(base_name) or nil +end + +function M.get_output_file() + local base_name = M.get_base_name() + return base_name and ('io/%s.cpout'):format(base_name) or nil +end + +function M.get_expected_file() + local base_name = M.get_base_name() + return base_name and ('io/%s.expected'):format(base_name) or nil +end + function M.has_context() return state.platform and state.contest_id end diff --git a/lua/cp/ui/panel.lua b/lua/cp/ui/panel.lua index 1ca5551..9fe2a9f 100644 --- a/lua/cp/ui/panel.lua +++ b/lua/cp/ui/panel.lua @@ -4,7 +4,6 @@ local buffer_utils = require('cp.utils.buffer') local config_module = require('cp.config') local layouts = require('cp.ui.layouts') local logger = require('cp.log') -local problem = require('cp.problem') local state = require('cp.state') local current_diff_layout = nil @@ -57,12 +56,12 @@ function M.toggle_run_panel(is_debug) ) local config = config_module.get_config() - local ctx = problem.create_context(platform or '', contest_id or '', problem_id, config) local run = require('cp.runner.run') - logger.log(('run panel: checking test cases for %s'):format(ctx.input_file)) + local input_file = state.get_input_file() + logger.log(('run panel: checking test cases for %s'):format(input_file or 'none')) - if not run.load_test_cases(ctx, state) then + if not run.load_test_cases(state) then logger.log('no test cases found', vim.log.levels.WARN) return end @@ -170,18 +169,18 @@ function M.toggle_run_panel(is_debug) setup_keybindings_for_buffer(test_buffers.tab_buf) if config.hooks and config.hooks.before_run then - config.hooks.before_run(ctx) + config.hooks.before_run(state) end if is_debug and config.hooks and config.hooks.before_debug then - config.hooks.before_debug(ctx) + config.hooks.before_debug(state) end local execute = require('cp.runner.execute') local contest_config = config.contests[state.get_platform() or ''] - local compile_result = execute.compile_problem(ctx, contest_config, is_debug) + local compile_result = execute.compile_problem(contest_config, is_debug) if compile_result.success then - run.run_all_test_cases(ctx, contest_config, config) + run.run_all_test_cases(contest_config, config) else run.handle_compilation_failure(compile_result.output) end diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index cd72613..2eba02b 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -272,8 +272,9 @@ def scrape_contests() -> list[ContestSummary]: r"[\uff01-\uff5e]", lambda m: chr(ord(m.group()) - 0xFEE0), name ) - # Skip AtCoder Heuristic Contests (AHC) as they don't have standard sample tests - if not contest_id.startswith("ahc"): + if not ( + contest_id.startswith("ahc") or name.lower().find("heuristic") != -1 + ): contests.append( ContestSummary(id=contest_id, name=name, display_name=name) ) diff --git a/spec/error_boundaries_spec.lua b/spec/error_boundaries_spec.lua index f17fa83..0af5f2a 100644 --- a/spec/error_boundaries_spec.lua +++ b/spec/error_boundaries_spec.lua @@ -9,9 +9,6 @@ describe('Error boundary handling', function() log = function(msg, level) table.insert(logged_messages, { msg = msg, level = level }) end, - progress = function(msg) - table.insert(logged_messages, { msg = msg, level = vim.log.levels.INFO }) - end, set_config = function() end, } package.loaded['cp.log'] = mock_logger diff --git a/spec/panel_spec.lua b/spec/panel_spec.lua index b15ea84..335e1fa 100644 --- a/spec/panel_spec.lua +++ b/spec/panel_spec.lua @@ -88,20 +88,22 @@ describe('Panel integration', function() state.set_contest_id('2146') state.set_problem_id('b') - local problem = require('cp.problem') local config_module = require('cp.config') local processed_config = config_module.setup({ contests = { codeforces = { cpp = { extension = 'cpp' } } }, }) - local ctx = problem.create_context('codeforces', '2146', 'b', processed_config) + local cp_state = require('cp.state') + cp_state.set_platform('codeforces') + cp_state.set_contest_id('2146') + cp_state.set_problem_id('b') assert.has_no_errors(function() - run.load_test_cases(ctx, state) + run.load_test_cases(state) end) local fake_state_data = { platform = 'codeforces', contest_id = '2146', problem_id = 'b' } assert.has_errors(function() - run.load_test_cases(ctx, fake_state_data) + run.load_test_cases(fake_state_data) end) end) end) diff --git a/spec/problem_spec.lua b/spec/problem_spec.lua deleted file mode 100644 index d76f2d7..0000000 --- a/spec/problem_spec.lua +++ /dev/null @@ -1,146 +0,0 @@ -describe('cp.problem', function() - local problem - local spec_helper = require('spec.spec_helper') - - before_each(function() - spec_helper.setup() - problem = require('cp.problem') - end) - - after_each(function() - spec_helper.teardown() - end) - - describe('create_context', function() - local base_config = { - contests = { - atcoder = { - default_language = 'cpp', - cpp = { extension = 'cpp' }, - python = { extension = 'py' }, - }, - codeforces = { - default_language = 'cpp', - cpp = { extension = 'cpp' }, - }, - }, - } - - it('creates basic context with required fields', function() - local context = problem.create_context('atcoder', 'abc123', 'a', base_config) - - assert.equals('atcoder', context.contest) - assert.equals('abc123', context.contest_id) - assert.equals('a', context.problem_id) - assert.equals('abc123a', context.problem_name) - assert.equals('abc123a.cpp', context.source_file) - assert.equals('build/abc123a.run', context.binary_file) - assert.equals('io/abc123a.cpin', context.input_file) - assert.equals('io/abc123a.cpout', context.output_file) - assert.equals('io/abc123a.expected', context.expected_file) - end) - - it('handles context without problem_id', function() - local context = problem.create_context('codeforces', '1933', nil, base_config) - - assert.equals('codeforces', context.contest) - assert.equals('1933', context.contest_id) - assert.is_nil(context.problem_id) - assert.equals('1933', context.problem_name) - assert.equals('1933.cpp', context.source_file) - assert.equals('build/1933.run', context.binary_file) - end) - - it('uses default language from contest config', function() - local context = problem.create_context('atcoder', 'abc123', 'a', base_config) - assert.equals('abc123a.cpp', context.source_file) - end) - - it('respects explicit language parameter', function() - local context = problem.create_context('atcoder', 'abc123', 'a', base_config, 'python') - assert.equals('abc123a.py', context.source_file) - end) - - it('uses custom filename function when provided', function() - local config_with_custom = vim.tbl_deep_extend('force', base_config, { - filename = function(contest, contest_id, problem_id) - return contest .. '_' .. contest_id .. (problem_id and ('_' .. problem_id) or '') - end, - }) - - local context = problem.create_context('atcoder', 'abc123', 'a', config_with_custom) - assert.equals('atcoder_abc123_a.cpp', context.source_file) - assert.equals('atcoder_abc123_a', context.problem_name) - end) - - it('validates required parameters', function() - assert.has_error(function() - problem.create_context(nil, 'abc123', 'a', base_config) - end) - - assert.has_error(function() - problem.create_context('atcoder', nil, 'a', base_config) - end) - - assert.has_error(function() - problem.create_context('atcoder', 'abc123', 'a', nil) - end) - end) - - it('validates contest exists in config', function() - assert.has_error(function() - problem.create_context('invalid_contest', 'abc123', 'a', base_config) - end) - end) - - it('validates language exists in contest config', function() - assert.has_error(function() - problem.create_context('atcoder', 'abc123', 'a', base_config, 'invalid_language') - end) - end) - - it('validates default language exists', function() - local bad_config = { - contests = { - test_contest = { - default_language = 'nonexistent', - }, - }, - } - - assert.has_error(function() - problem.create_context('test_contest', 'abc123', 'a', bad_config) - end) - end) - - it('validates language extension is configured', function() - local bad_config = { - contests = { - test_contest = { - default_language = 'cpp', - cpp = {}, - }, - }, - } - - assert.has_error(function() - problem.create_context('test_contest', 'abc123', 'a', bad_config) - end) - end) - - it('handles complex contest and problem ids', function() - local context = problem.create_context('atcoder', 'arc123', 'f', base_config) - assert.equals('arc123f', context.problem_name) - assert.equals('arc123f.cpp', context.source_file) - assert.equals('build/arc123f.run', context.binary_file) - end) - - it('generates correct io file paths', function() - local context = problem.create_context('atcoder', 'abc123', 'a', base_config) - - assert.equals('io/abc123a.cpin', context.input_file) - assert.equals('io/abc123a.cpout', context.output_file) - assert.equals('io/abc123a.expected', context.expected_file) - end) - end) -end) diff --git a/spec/spec_helper.lua b/spec/spec_helper.lua index 0e02f87..acbdf62 100644 --- a/spec/spec_helper.lua +++ b/spec/spec_helper.lua @@ -6,9 +6,6 @@ local mock_logger = { log = function(msg, level) table.insert(M.logged_messages, { msg = msg, level = level }) end, - progress = function(msg) - table.insert(M.logged_messages, { msg = msg, level = vim.log.levels.INFO }) - end, set_config = function() end, } @@ -83,10 +80,11 @@ end function M.mock_scraper_success() package.loaded['cp.scrape'] = { - scrape_problem = function(ctx) + scrape_problem = function() + local state = require('cp.state') return { success = true, - problem_id = ctx.problem_id, + problem_id = state.get_problem_id(), test_cases = { { input = '1 2', expected = '3' }, { input = '3 4', expected = '7' },