feat: context, not config

This commit is contained in:
Barrett Ruth 2025-09-24 18:21:34 -04:00
parent a0171ee81e
commit 9e84d57b8a
15 changed files with 209 additions and 328 deletions

View file

@ -25,9 +25,9 @@
---@field default_language? string
---@class Hooks
---@field before_run? fun(ctx: ProblemContext)
---@field before_debug? fun(ctx: ProblemContext)
---@field setup_code? fun(ctx: ProblemContext)
---@field before_run? fun(state: cp.State)
---@field before_debug? fun(state: cp.State)
---@field setup_code? fun(state: cp.State)
---@class RunPanelConfig
---@field ansi boolean Enable ANSI color parsing and highlighting

View file

@ -9,10 +9,4 @@ function M.log(msg, level, override)
end
end
function M.progress(msg)
vim.schedule(function()
vim.notify(('[cp.nvim]: %s'):format(msg), vim.log.levels.INFO)
end)
end
return M

View file

@ -34,16 +34,17 @@ end
---@param platform string Platform identifier (e.g. "codeforces", "atcoder")
---@return cp.ContestItem[]
local function get_contests_for_platform(platform)
local constants = require('cp.constants')
local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform
logger.log(('loading %s contests...'):format(platform_display_name), vim.log.levels.INFO, true)
cache.load()
local cached_contests = cache.get_contest_list(platform)
if cached_contests then
return cached_contests
end
local constants = require('cp.constants')
local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform
logger.progress(('loading %s contests...'):format(platform_display_name))
if not utils.setup_python_env() then
return {}
end
@ -59,8 +60,6 @@ local function get_contests_for_platform(platform)
'contests',
}
logger.progress(('running: %s'):format(table.concat(cmd, ' ')))
local result = vim
.system(cmd, {
cwd = plugin_path,
@ -69,9 +68,9 @@ local function get_contests_for_platform(platform)
})
:wait()
logger.progress(('exit code: %d, stdout length: %d'):format(result.code, #(result.stdout or '')))
logger.log(('exit code: %d, stdout length: %d'):format(result.code, #(result.stdout or '')))
if result.stderr and #result.stderr > 0 then
logger.progress(('stderr: %s'):format(result.stderr:sub(1, 200)))
logger.log(('stderr: %s'):format(result.stderr:sub(1, 200)))
end
if result.code ~= 0 then
@ -82,7 +81,7 @@ local function get_contests_for_platform(platform)
return {}
end
logger.progress(('stdout preview: %s'):format(result.stdout:sub(1, 100)))
logger.log(('stdout preview: %s'):format(result.stdout:sub(1, 100)))
local ok, data = pcall(vim.json.decode, result.stdout)
if not ok then
@ -107,7 +106,7 @@ local function get_contests_for_platform(platform)
end
cache.set_contest_list(platform, contests)
logger.progress(('loaded %d contests'):format(#contests))
logger.log(('loaded %d contests'):format(#contests))
return contests
end
@ -115,6 +114,8 @@ end
---@param contest_id string Contest identifier
---@return cp.ProblemItem[]
local function get_problems_for_contest(platform, contest_id)
local constants = require('cp.constants')
local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform
local problems = {}
cache.load()
@ -130,14 +131,16 @@ local function get_problems_for_contest(platform, contest_id)
return problems
end
logger.log(
('loading %s %s problems...'):format(platform_display_name, contest_id),
vim.log.levels.INFO,
true
)
if not utils.setup_python_env() then
return problems
end
local constants = require('cp.constants')
local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform
logger.progress(('loading %s %s problems...'):format(platform_display_name, contest_id))
local plugin_path = utils.get_plugin_path()
local cmd = {
'uv',

View file

@ -1,68 +0,0 @@
---@class ProblemContext
---@field contest string Contest name (e.g. "atcoder", "codeforces")
---@field contest_id string Contest ID (e.g. "abc123", "1933")
---@field problem_id? string Problem ID for AtCoder/Codeforces (e.g. "a", "b")
---@field source_file string Source filename (e.g. "abc123a.cpp")
---@field binary_file string Binary output path (e.g. "build/abc123a.run")
---@field input_file string Input test file path (e.g. "io/abc123a.in")
---@field output_file string Output file path (e.g. "io/abc123a.out")
---@field expected_file string Expected output path (e.g. "io/abc123a.expected")
---@field problem_name string Canonical problem identifier (e.g. "abc123a")
local M = {}
---@param contest string
---@param contest_id string
---@param problem_id? string
---@param config cp.Config
---@param language? string
---@return ProblemContext
function M.create_context(contest, contest_id, problem_id, config, language)
vim.validate({
contest = { contest, 'string' },
contest_id = { contest_id, 'string' },
problem_id = { problem_id, { 'string', 'nil' }, true },
config = { config, 'table' },
language = { language, { 'string', 'nil' }, true },
})
local contest_config = config.contests[contest]
if not contest_config then
error(("No contest config found for '%s'"):format(contest))
end
local target_language = language or contest_config.default_language
local language_config = contest_config[target_language]
if not language_config then
error(("No language config found for '%s' in contest '%s'"):format(target_language, contest))
end
if not language_config.extension then
error(
("No extension configured for language '%s' in contest '%s'"):format(target_language, contest)
)
end
local base_name
if config.filename then
base_name = config.filename(contest, contest_id, problem_id, config, language)
else
local default_filename = require('cp.config').default_filename
base_name = default_filename(contest_id, problem_id)
end
local source_file = base_name .. '.' .. language_config.extension
return {
contest = contest,
contest_id = contest_id,
problem_id = problem_id,
source_file = source_file,
binary_file = ('build/%s.run'):format(base_name),
input_file = ('io/%s.cpin'):format(base_name),
output_file = ('io/%s.cpout'):format(base_name),
expected_file = ('io/%s.expected'):format(base_name),
problem_name = base_name,
}
end
return M

View file

@ -203,17 +203,22 @@ local function format_output(exec_result, expected_file, is_debug)
return table.concat(output_lines, '') .. '\n' .. table.concat(metadata_lines, '\n')
end
---@param ctx ProblemContext
---@param contest_config ContestConfig
---@param is_debug? boolean
---@return {success: boolean, output: string?}
function M.compile_problem(ctx, contest_config, is_debug)
function M.compile_problem(contest_config, is_debug)
vim.validate({
ctx = { ctx, 'table' },
contest_config = { contest_config, 'table' },
})
local language = get_language_from_file(ctx.source_file, contest_config)
local state = require('cp.state')
local source_file = state.get_source_file()
if not source_file then
logger.log('No source file found', vim.log.levels.ERROR)
return { success = false, output = 'No source file found' }
end
local language = get_language_from_file(source_file, contest_config)
local language_config = contest_config[language]
if not language_config then
@ -221,9 +226,10 @@ function M.compile_problem(ctx, contest_config, is_debug)
return { success = false, output = 'No configuration for language: ' .. language }
end
local binary_file = state.get_binary_file()
local substitutions = {
source = ctx.source_file,
binary = ctx.binary_file,
source = source_file,
binary = binary_file,
version = tostring(language_config.version),
}
@ -244,26 +250,35 @@ function M.compile_problem(ctx, contest_config, is_debug)
return { success = true, output = nil }
end
function M.run_problem(ctx, contest_config, is_debug)
function M.run_problem(contest_config, is_debug)
vim.validate({
ctx = { ctx, 'table' },
contest_config = { contest_config, 'table' },
is_debug = { is_debug, 'boolean' },
})
vim.system({ 'mkdir', '-p', 'build', 'io' }):wait()
local state = require('cp.state')
local source_file = state.get_source_file()
local output_file = state.get_output_file()
local language = get_language_from_file(ctx.source_file, contest_config)
local language_config = contest_config[language]
if not language_config then
vim.fn.writefile({ 'Error: No configuration for language: ' .. language }, ctx.output_file)
if not source_file or not output_file then
logger.log('Missing required file paths', vim.log.levels.ERROR)
return
end
vim.system({ 'mkdir', '-p', 'build', 'io' }):wait()
local language = get_language_from_file(source_file, contest_config)
local language_config = contest_config[language]
if not language_config then
vim.fn.writefile({ 'Error: No configuration for language: ' .. language }, output_file)
return
end
local binary_file = state.get_binary_file()
local substitutions = {
source = ctx.source_file,
binary = ctx.binary_file,
source = source_file,
binary = binary_file,
version = tostring(language_config.version),
}
@ -271,26 +286,31 @@ function M.run_problem(ctx, contest_config, is_debug)
if compile_cmd then
local compile_result = M.compile_generic(language_config, substitutions)
if compile_result.code ~= 0 then
vim.fn.writefile({ compile_result.stderr }, ctx.output_file)
vim.fn.writefile({ compile_result.stderr }, output_file)
return
end
end
local input_file = state.get_input_file()
local input_data = ''
if vim.fn.filereadable(ctx.input_file) == 1 then
input_data = table.concat(vim.fn.readfile(ctx.input_file), '\n') .. '\n'
if input_file and vim.fn.filereadable(input_file) == 1 then
input_data = table.concat(vim.fn.readfile(input_file), '\n') .. '\n'
end
local cache = require('cp.cache')
cache.load()
local timeout_ms, _ = cache.get_constraints(ctx.contest, ctx.contest_id, ctx.problem_id)
local platform = state.get_platform()
local contest_id = state.get_contest_id()
local problem_id = state.get_problem_id()
local timeout_ms, _ = cache.get_constraints(platform, contest_id, problem_id)
timeout_ms = timeout_ms or 2000
local run_cmd = build_command(language_config.test, language_config.executable, substitutions)
local exec_result = execute_command(run_cmd, input_data, timeout_ms)
local formatted_output = format_output(exec_result, ctx.expected_file, is_debug)
local expected_file = state.get_expected_file()
local formatted_output = format_output(exec_result, expected_file, is_debug)
local output_buf = vim.fn.bufnr(ctx.output_file)
local output_buf = vim.fn.bufnr(output_file)
if output_buf ~= -1 then
local was_modifiable = vim.api.nvim_get_option_value('modifiable', { buf = output_buf })
local was_readonly = vim.api.nvim_get_option_value('readonly', { buf = output_buf })
@ -303,7 +323,7 @@ function M.run_problem(ctx, contest_config, is_debug)
vim.cmd.write()
end)
else
vim.fn.writefile(vim.split(formatted_output, '\n'), ctx.output_file)
vim.fn.writefile(vim.split(formatted_output, '\n'), output_file)
end
end

View file

@ -130,12 +130,22 @@ local function load_constraints_from_cache(platform, contest_id, problem_id)
return nil
end
---@param ctx ProblemContext
---@param contest_config ContestConfig
---@param test_case TestCase
---@return table
local function run_single_test_case(ctx, contest_config, cp_config, test_case)
local language = vim.fn.fnamemodify(ctx.source_file, ':e')
local function run_single_test_case(contest_config, cp_config, test_case)
local state = require('cp.state')
local source_file = state.get_source_file()
if not source_file then
return {
status = 'fail',
actual = '',
error = 'No source file found',
time_ms = 0,
}
end
local language = vim.fn.fnamemodify(source_file, ':e')
local language_name = constants.filetype_to_language[language] or contest_config.default_language
local language_config = contest_config[language_name]
@ -168,13 +178,14 @@ local function run_single_test_case(ctx, contest_config, cp_config, test_case)
return cmd
end
local binary_file = state.get_binary_file()
local substitutions = {
source = ctx.source_file,
binary = ctx.binary_file,
source = source_file,
binary = binary_file,
version = tostring(language_config.version or ''),
}
if language_config.compile and vim.fn.filereadable(ctx.binary_file) == 0 then
if language_config.compile and vim.fn.filereadable(binary_file) == 0 then
logger.log('binary not found, compiling first...')
local compile_cmd = substitute_template(language_config.compile, substitutions)
local redirected_cmd = vim.deepcopy(compile_cmd)
@ -282,10 +293,9 @@ local function run_single_test_case(ctx, contest_config, cp_config, test_case)
}
end
---@param ctx ProblemContext
---@param state table
---@return boolean
function M.load_test_cases(ctx, state)
function M.load_test_cases(state)
local test_cases = parse_test_cases_from_cache(
state.get_platform() or '',
state.get_contest_id() or '',
@ -293,7 +303,9 @@ function M.load_test_cases(ctx, state)
)
if #test_cases == 0 then
test_cases = parse_test_cases_from_files(ctx.input_file, ctx.expected_file)
local input_file = state.get_input_file()
local expected_file = state.get_expected_file()
test_cases = parse_test_cases_from_files(input_file, expected_file)
end
run_panel_state.test_cases = test_cases
@ -315,11 +327,10 @@ function M.load_test_cases(ctx, state)
return #test_cases > 0
end
---@param ctx ProblemContext
---@param contest_config ContestConfig
---@param index number
---@return boolean
function M.run_test_case(ctx, contest_config, cp_config, index)
function M.run_test_case(contest_config, cp_config, index)
local test_case = run_panel_state.test_cases[index]
if not test_case then
return false
@ -327,7 +338,7 @@ function M.run_test_case(ctx, contest_config, cp_config, index)
test_case.status = 'running'
local result = run_single_test_case(ctx, contest_config, cp_config, test_case)
local result = run_single_test_case(contest_config, cp_config, test_case)
test_case.status = result.status
test_case.actual = result.actual
@ -343,13 +354,13 @@ function M.run_test_case(ctx, contest_config, cp_config, index)
return true
end
---@param ctx ProblemContext
---@param contest_config ContestConfig
---@param cp_config cp.Config
---@return TestCase[]
function M.run_all_test_cases(ctx, contest_config, cp_config)
function M.run_all_test_cases(contest_config, cp_config)
local results = {}
for i, _ in ipairs(run_panel_state.test_cases) do
M.run_test_case(ctx, contest_config, cp_config, i)
M.run_test_case(contest_config, cp_config, i)
table.insert(results, run_panel_state.test_cases[i])
end
return results

View file

@ -3,7 +3,6 @@ local M = {}
local cache = require('cp.cache')
local config_module = require('cp.config')
local logger = require('cp.log')
local problem = require('cp.problem')
local scraper = require('cp.scraper')
local state = require('cp.state')
@ -42,7 +41,7 @@ function M.setup_contest(platform, contest_id, problem_id, language)
return
end
logger.progress(('fetching contest %s %s...'):format(platform, contest_id))
logger.log(('fetching contest %s %s...'):format(platform, contest_id))
scraper.scrape_contest_metadata(platform, contest_id, function(result)
if not result.success then
@ -59,7 +58,7 @@ function M.setup_contest(platform, contest_id, problem_id, language)
return
end
logger.progress(('found %d problems'):format(#problems))
logger.log(('found %d problems'):format(#problems))
state.set_contest_id(contest_id)
local target_problem = problem_id or problems[1].id
@ -96,16 +95,17 @@ function M.setup_problem(contest_id, problem_id, language)
local config = config_module.get_config()
local platform = state.get_platform() or ''
logger.progress(('setting up problem %s%s...'):format(contest_id, problem_id or ''))
logger.log(('setting up problem %s%s...'):format(contest_id, problem_id or ''))
local ctx = problem.create_context(platform, contest_id, problem_id, config, language)
state.set_contest_id(contest_id)
state.set_problem_id(problem_id)
local cached_tests = cache.get_test_cases(platform, contest_id, problem_id)
if cached_tests then
state.set_test_cases(cached_tests)
logger.log(('using cached test cases (%d)'):format(#cached_tests))
elseif vim.tbl_contains(config.scrapers, platform) then
logger.progress('loading test cases...')
logger.log('loading test cases...')
scraper.scrape_problem_tests(platform, contest_id, problem_id, function(result)
if result.success then
@ -128,15 +128,17 @@ function M.setup_problem(contest_id, problem_id, language)
state.set_test_cases({})
end
state.set_contest_id(contest_id)
state.set_problem_id(problem_id)
state.set_run_panel_active(false)
vim.schedule(function()
local ok, err = pcall(function()
vim.cmd.only({ mods = { silent = true } })
vim.cmd.e(ctx.source_file)
local source_file = state.get_source_file(language)
if not source_file then
error('Failed to generate source file path')
end
vim.cmd.e(source_file)
local source_buf = vim.api.nvim_get_current_buf()
if vim.api.nvim_buf_get_lines(source_buf, 0, -1, true)[1] == '' then
@ -166,12 +168,12 @@ function M.setup_problem(contest_id, problem_id, language)
end
if config.hooks and config.hooks.setup_code then
config.hooks.setup_code(ctx)
config.hooks.setup_code(state)
end
cache.set_file_state(vim.fn.expand('%:p'), platform, contest_id, problem_id, language)
logger.progress(('ready - problem %s'):format(ctx.problem_name))
logger.log(('ready - problem %s'):format(state.get_base_name()))
end)
if not ok then
@ -196,7 +198,7 @@ function M.scrape_remaining_problems(platform, contest_id, problems)
return
end
logger.progress(('caching %d remaining problems...'):format(#missing_problems))
logger.log(('caching %d remaining problems...'):format(#missing_problems))
for _, prob in ipairs(missing_problems) do
scraper.scrape_problem_tests(platform, contest_id, prob.id, function(result)

View file

@ -1,3 +1,26 @@
---@class cp.State
---@field get_platform fun(): string?
---@field set_platform fun(platform: string)
---@field get_contest_id fun(): string?
---@field set_contest_id fun(contest_id: string)
---@field get_problem_id fun(): string?
---@field set_problem_id fun(problem_id: string)
---@field get_test_cases fun(): table[]?
---@field set_test_cases fun(test_cases: table[])
---@field is_run_panel_active fun(): boolean
---@field set_run_panel_active fun(active: boolean)
---@field get_saved_session fun(): table?
---@field set_saved_session fun(session: table)
---@field get_context fun(): {platform: string?, contest_id: string?, problem_id: string?}
---@field has_context fun(): boolean
---@field reset fun()
---@field get_base_name fun(): string?
---@field get_source_file fun(language?: string): string?
---@field get_binary_file fun(): string?
---@field get_input_file fun(): string?
---@field get_output_file fun(): string?
---@field get_expected_file fun(): string?
local M = {}
local state = {
@ -65,6 +88,62 @@ function M.get_context()
}
end
function M.get_base_name()
if not state.contest_id then
return nil
end
local config_module = require('cp.config')
local config = config_module.get_config()
if config.filename then
return config.filename(state.platform or '', state.contest_id, state.problem_id, config)
else
return config_module.default_filename(state.contest_id, state.problem_id)
end
end
function M.get_source_file(language)
local base_name = M.get_base_name()
if not base_name or not state.platform then
return nil
end
local config = require('cp.config').get_config()
local contest_config = config.contests[state.platform]
if not contest_config then
return nil
end
local target_language = language or contest_config.default_language
local language_config = contest_config[target_language]
if not language_config or not language_config.extension then
return nil
end
return base_name .. '.' .. language_config.extension
end
function M.get_binary_file()
local base_name = M.get_base_name()
return base_name and ('build/%s.run'):format(base_name) or nil
end
function M.get_input_file()
local base_name = M.get_base_name()
return base_name and ('io/%s.cpin'):format(base_name) or nil
end
function M.get_output_file()
local base_name = M.get_base_name()
return base_name and ('io/%s.cpout'):format(base_name) or nil
end
function M.get_expected_file()
local base_name = M.get_base_name()
return base_name and ('io/%s.expected'):format(base_name) or nil
end
function M.has_context()
return state.platform and state.contest_id
end

View file

@ -4,7 +4,6 @@ local buffer_utils = require('cp.utils.buffer')
local config_module = require('cp.config')
local layouts = require('cp.ui.layouts')
local logger = require('cp.log')
local problem = require('cp.problem')
local state = require('cp.state')
local current_diff_layout = nil
@ -57,12 +56,12 @@ function M.toggle_run_panel(is_debug)
)
local config = config_module.get_config()
local ctx = problem.create_context(platform or '', contest_id or '', problem_id, config)
local run = require('cp.runner.run')
logger.log(('run panel: checking test cases for %s'):format(ctx.input_file))
local input_file = state.get_input_file()
logger.log(('run panel: checking test cases for %s'):format(input_file or 'none'))
if not run.load_test_cases(ctx, state) then
if not run.load_test_cases(state) then
logger.log('no test cases found', vim.log.levels.WARN)
return
end
@ -170,18 +169,18 @@ function M.toggle_run_panel(is_debug)
setup_keybindings_for_buffer(test_buffers.tab_buf)
if config.hooks and config.hooks.before_run then
config.hooks.before_run(ctx)
config.hooks.before_run(state)
end
if is_debug and config.hooks and config.hooks.before_debug then
config.hooks.before_debug(ctx)
config.hooks.before_debug(state)
end
local execute = require('cp.runner.execute')
local contest_config = config.contests[state.get_platform() or '']
local compile_result = execute.compile_problem(ctx, contest_config, is_debug)
local compile_result = execute.compile_problem(contest_config, is_debug)
if compile_result.success then
run.run_all_test_cases(ctx, contest_config, config)
run.run_all_test_cases(contest_config, config)
else
run.handle_compilation_failure(compile_result.output)
end