diff --git a/lua/cp/commands/cache.lua b/lua/cp/commands/cache.lua new file mode 100644 index 0000000..08f50de --- /dev/null +++ b/lua/cp/commands/cache.lua @@ -0,0 +1,32 @@ +local M = {} + +local cache = require('cp.cache') +local constants = require('cp.constants') +local logger = require('cp.log') + +local platforms = constants.PLATFORMS + +function M.handle_cache_command(cmd) + if cmd.subcommand == 'clear' then + cache.load() + if cmd.platform then + if vim.tbl_contains(platforms, cmd.platform) then + cache.clear_platform(cmd.platform) + logger.log(('cleared cache for %s'):format(cmd.platform), vim.log.levels.INFO, true) + else + logger.log( + ('unknown platform: %s. Available: %s'):format( + cmd.platform, + table.concat(platforms, ', ') + ), + vim.log.levels.ERROR + ) + end + else + cache.clear_all() + logger.log('cleared all cache', vim.log.levels.INFO, true) + end + end +end + +return M diff --git a/lua/cp/commands/init.lua b/lua/cp/commands/init.lua new file mode 100644 index 0000000..0ef9c3a --- /dev/null +++ b/lua/cp/commands/init.lua @@ -0,0 +1,177 @@ +local M = {} + +local constants = require('cp.constants') +local logger = require('cp.log') +local state = require('cp.state') + +local platforms = constants.PLATFORMS +local actions = constants.ACTIONS + +local function parse_command(args) + if #args == 0 then + return { + type = 'restore_from_file', + } + end + + local language = nil + local debug = false + + for i, arg in ipairs(args) do + local lang_match = arg:match('^--lang=(.+)$') + if lang_match then + language = lang_match + elseif arg == '--lang' then + if i + 1 <= #args then + language = args[i + 1] + else + return { type = 'error', message = '--lang requires a value' } + end + elseif arg == '--debug' then + debug = true + end + end + + local filtered_args = vim.tbl_filter(function(arg) + return not (arg:match('^--lang') or arg == language or arg == '--debug') + end, args) + + local first = filtered_args[1] + + if vim.tbl_contains(actions, first) then + if first == 'cache' then + local subcommand = filtered_args[2] + if not subcommand then + return { type = 'error', message = 'cache command requires subcommand: clear' } + end + if subcommand == 'clear' then + local platform = filtered_args[3] + return { + type = 'cache', + subcommand = 'clear', + platform = platform, + } + else + return { type = 'error', message = 'unknown cache subcommand: ' .. subcommand } + end + else + return { type = 'action', action = first, language = language, debug = debug } + end + end + + if vim.tbl_contains(platforms, first) then + if #filtered_args == 1 then + return { + type = 'platform_only', + platform = first, + language = language, + } + elseif #filtered_args == 2 then + return { + type = 'contest_setup', + platform = first, + contest = filtered_args[2], + language = language, + } + elseif #filtered_args == 3 then + return { + type = 'full_setup', + platform = first, + contest = filtered_args[2], + problem = filtered_args[3], + language = language, + } + else + return { type = 'error', message = 'Too many arguments' } + end + end + + if state.get_platform() and state.get_contest_id() then + local cache = require('cp.cache') + cache.load() + local contest_data = + cache.get_contest_data(state.get_platform() or '', state.get_contest_id() or '') + if contest_data and contest_data.problems then + local problem_ids = vim.tbl_map(function(prob) + return prob.id + end, contest_data.problems) + if vim.tbl_contains(problem_ids, first) then + return { type = 'problem_switch', problem = first, language = language } + end + end + return { + type = 'error', + message = ("invalid subcommand '%s'"):format(first), + } + end + + return { type = 'error', message = 'Unknown command or no contest context' } +end + +function M.handle_command(opts) + local cmd = parse_command(opts.fargs) + + if cmd.type == 'error' then + logger.log(cmd.message, vim.log.levels.ERROR) + return + end + + if cmd.type == 'restore_from_file' then + local restore = require('cp.restore') + restore.restore_from_current_file() + return + end + + if cmd.type == 'action' then + local setup = require('cp.setup') + local ui = require('cp.ui.panel') + + if cmd.action == 'run' then + ui.toggle_run_panel(cmd.debug) + elseif cmd.action == 'next' then + setup.navigate_problem(1, cmd.language) + elseif cmd.action == 'prev' then + setup.navigate_problem(-1, cmd.language) + elseif cmd.action == 'pick' then + local picker = require('cp.commands.picker') + picker.handle_pick_action() + end + return + end + + if cmd.type == 'cache' then + local cache_commands = require('cp.commands.cache') + cache_commands.handle_cache_command(cmd) + return + end + + if cmd.type == 'platform_only' then + local setup = require('cp.setup') + setup.set_platform(cmd.platform) + return + end + + if cmd.type == 'contest_setup' then + local setup = require('cp.setup') + if setup.set_platform(cmd.platform) then + setup.setup_contest(cmd.contest, cmd.language) + end + return + end + + if cmd.type == 'full_setup' then + local setup = require('cp.setup') + if setup.set_platform(cmd.platform) then + setup.handle_full_setup(cmd) + end + return + end + + if cmd.type == 'problem_switch' then + local setup = require('cp.setup') + setup.setup_problem(state.get_contest_id() or '', cmd.problem, cmd.language) + return + end +end + +return M diff --git a/lua/cp/commands/picker.lua b/lua/cp/commands/picker.lua new file mode 100644 index 0000000..755b613 --- /dev/null +++ b/lua/cp/commands/picker.lua @@ -0,0 +1,50 @@ +local M = {} + +local config_module = require('cp.config') +local logger = require('cp.log') + +function M.handle_pick_action() + local config = config_module.get_config() + + if not config.picker then + logger.log( + 'No picker configured. Set picker = "telescope" or picker = "fzf-lua" in config', + vim.log.levels.ERROR + ) + return + end + + if config.picker == 'telescope' then + local ok = pcall(require, 'telescope') + if not ok then + logger.log( + 'Telescope not available. Install telescope.nvim or change picker config', + vim.log.levels.ERROR + ) + return + end + local ok_cp, telescope_cp = pcall(require, 'cp.pickers.telescope') + if not ok_cp then + logger.log('Failed to load telescope integration', vim.log.levels.ERROR) + return + end + telescope_cp.platform_picker() + elseif config.picker == 'fzf-lua' then + local ok, _ = pcall(require, 'fzf-lua') + if not ok then + logger.log( + 'fzf-lua not available. Install fzf-lua or change picker config', + vim.log.levels.ERROR + ) + return + end + local ok_cp, fzf_cp = pcall(require, 'cp.pickers.fzf_lua') + if not ok_cp then + logger.log('Failed to load fzf-lua integration', vim.log.levels.ERROR) + return + end + fzf_cp.platform_picker() + end +end + +return M diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 8bafd35..411d002 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -292,4 +292,14 @@ end M.default_filename = default_filename +local current_config = nil + +function M.set_current_config(config) + current_config = config +end + +function M.get_config() + return current_config or M.defaults +end + return M diff --git a/lua/cp/init.lua b/lua/cp/init.lua index eed7d7b..3319bb7 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -1,10 +1,7 @@ local M = {} -local cache = require('cp.cache') local config_module = require('cp.config') local logger = require('cp.log') -local problem = require('cp.problem') -local scrape = require('cp.scrape') local snippets = require('cp.snippets') local state = require('cp.state') @@ -17,1064 +14,17 @@ local user_config = {} local config = config_module.setup(user_config) local snippets_initialized = false -local current_diff_layout = nil -local current_mode = nil - -local constants = require('cp.constants') -local platforms = constants.PLATFORMS -local actions = constants.ACTIONS - -local function set_platform(platform) - if not vim.tbl_contains(platforms, platform) then - logger.log( - ('unknown platform. Available: [%s]'):format(table.concat(platforms, ', ')), - vim.log.levels.ERROR - ) - return false - end - - state.set_platform(platform) - vim.system({ 'mkdir', '-p', 'build', 'io' }):wait() - return true -end - ----@param contest_id string ----@param problem_id? string ----@param language? string -local function setup_problem(contest_id, problem_id, language) - if state.get_platform() == '' then - logger.log('no platform set. run :CP first', vim.log.levels.ERROR) - return - end - - local problem_name = contest_id .. (problem_id or '') - logger.log(('setting up problem: %s'):format(problem_name)) - - local ctx = problem.create_context(state.get_platform(), contest_id, problem_id, config, language) - - if vim.tbl_contains(config.scrapers, state.get_platform()) then - cache.load() - local existing_contest_data = cache.get_contest_data(state.get_platform(), contest_id) - - if not existing_contest_data then - local metadata_result = scrape.scrape_contest_metadata(state.get_platform(), contest_id) - if not metadata_result.success then - logger.log( - 'failed to load contest metadata: ' .. (metadata_result.error or 'unknown error'), - vim.log.levels.WARN - ) - end - end - end - - local cached_test_cases = cache.get_test_cases(state.get_platform(), contest_id, problem_id) - if cached_test_cases then - state.set_test_cases(cached_test_cases) - logger.log(('using cached test cases (%d)'):format(#cached_test_cases)) - elseif vim.tbl_contains(config.scrapers, state.get_platform()) then - local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[state.get_platform()] - or state.get_platform() - logger.log( - ('Scraping %s %s %s for test cases, this may take a few seconds...'):format( - platform_display_name, - contest_id, - problem_id - ), - vim.log.levels.INFO, - true - ) - - local scrape_result = scrape.scrape_problem(ctx) - - if not scrape_result.success then - logger.log( - 'scraping failed: ' .. (scrape_result.error or 'unknown error'), - vim.log.levels.ERROR - ) - return - end - - local test_count = scrape_result.test_count or 0 - logger.log(('scraped %d test case(s) for %s'):format(test_count, scrape_result.problem_id)) - state.set_test_cases(scrape_result.test_cases) - - if scrape_result.test_cases then - cache.set_test_cases(state.get_platform(), contest_id, problem_id, scrape_result.test_cases) - end - else - logger.log(('scraping disabled for %s'):format(state.get_platform())) - state.set_test_cases(nil) - end - - vim.cmd('silent only') - state.set_run_panel_active(false) - - state.set_contest_id(contest_id) - state.set_problem_id(problem_id) - - vim.cmd.e(ctx.source_file) - local source_buf = vim.api.nvim_get_current_buf() - - if vim.api.nvim_buf_get_lines(source_buf, 0, -1, true)[1] == '' then - local has_luasnip, luasnip = pcall(require, 'luasnip') - if has_luasnip then - local filetype = vim.api.nvim_get_option_value('filetype', { buf = source_buf }) - local language_name = constants.filetype_to_language[filetype] - local canonical_language = constants.canonical_filetypes[language_name] or language_name - local prefixed_trigger = ('cp.nvim/%s.%s'):format(state.get_platform(), canonical_language) - - vim.api.nvim_buf_set_lines(0, 0, -1, false, { prefixed_trigger }) - vim.api.nvim_win_set_cursor(0, { 1, #prefixed_trigger }) - vim.cmd.startinsert({ bang = true }) - - vim.schedule(function() - if luasnip.expandable() then - luasnip.expand() - else - vim.api.nvim_buf_set_lines(0, 0, 1, false, { '' }) - vim.api.nvim_win_set_cursor(0, { 1, 0 }) - end - vim.cmd.stopinsert() - end) - else - vim.api.nvim_input(('i%s'):format(state.get_platform())) - end - end - - if config.hooks and config.hooks.setup_code then - config.hooks.setup_code(ctx) - end - - cache.set_file_state(vim.fn.expand('%:p'), state.get_platform(), contest_id, problem_id, language) - - logger.log(('switched to problem %s'):format(ctx.problem_name)) -end - -local function scrape_missing_problems(contest_id, missing_problems) - vim.fn.mkdir('io', 'p') - - logger.log(('scraping %d uncached problems...'):format(#missing_problems)) - - local results = - scrape.scrape_problems_parallel(state.get_platform(), contest_id, missing_problems, config) - - local success_count = 0 - local failed_problems = {} - for problem_id, result in pairs(results) do - if result.success then - success_count = success_count + 1 - else - table.insert(failed_problems, problem_id) - end - end - - if #failed_problems > 0 then - logger.log( - ('scraping complete: %d/%d successful, failed: %s'):format( - success_count, - #missing_problems, - table.concat(failed_problems, ', ') - ), - vim.log.levels.WARN - ) - else - logger.log(('scraping complete: %d/%d successful'):format(success_count, #missing_problems)) - end -end - -local function get_current_problem() - local filename = vim.fn.expand('%:t:r') - if filename == '' then - logger.log('no file open', vim.log.levels.ERROR) - return nil - end - return filename -end - -local function create_buffer_with_options(filetype) - local buf = vim.api.nvim_create_buf(false, true) - vim.api.nvim_set_option_value('bufhidden', 'wipe', { buf = buf }) - vim.api.nvim_set_option_value('readonly', true, { buf = buf }) - vim.api.nvim_set_option_value('modifiable', false, { buf = buf }) - if filetype then - vim.api.nvim_set_option_value('filetype', filetype, { buf = buf }) - end - return buf -end - -local setup_keybindings_for_buffer - -local function toggle_run_panel(is_debug) - if state.run_panel_active then - if current_diff_layout then - current_diff_layout.cleanup() - current_diff_layout = nil - current_mode = nil - end - if state.saved_session then - vim.cmd(('source %s'):format(state.saved_session)) - vim.fn.delete(state.saved_session) - state.saved_session = nil - end - - state.set_run_panel_active(false) - logger.log('test panel closed') - return - end - - if state.get_platform() == '' then - logger.log( - 'No contest configured. Use :CP to set up first.', - vim.log.levels.ERROR - ) - return - end - - local problem_id = get_current_problem() - if not problem_id then - return - end - - local ctx = problem.create_context( - state.get_platform(), - state.get_contest_id(), - state.get_problem_id(), - config - ) - local run = require('cp.runner.run') - - if not run.load_test_cases(ctx, state) then - logger.log('no test cases found', vim.log.levels.WARN) - return - end - - state.saved_session = vim.fn.tempname() - vim.cmd(('mksession! %s'):format(state.saved_session)) - - vim.cmd('silent only') - - local tab_buf = create_buffer_with_options() - local main_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(main_win, tab_buf) - vim.api.nvim_set_option_value('filetype', 'cptest', { buf = tab_buf }) - - local test_windows = { - tab_win = main_win, - } - local test_buffers = { - tab_buf = tab_buf, - } - - local highlight = require('cp.ui.highlight') - local diff_namespace = highlight.create_namespace() - - local test_list_namespace = vim.api.nvim_create_namespace('cp_test_list') - local ansi_namespace = vim.api.nvim_create_namespace('cp_ansi_highlights') - - local function update_buffer_content(bufnr, lines, highlights, namespace) - local was_readonly = vim.api.nvim_get_option_value('readonly', { buf = bufnr }) - - vim.api.nvim_set_option_value('readonly', false, { buf = bufnr }) - vim.api.nvim_set_option_value('modifiable', true, { buf = bufnr }) - vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) - vim.api.nvim_set_option_value('modifiable', false, { buf = bufnr }) - vim.api.nvim_set_option_value('readonly', was_readonly, { buf = bufnr }) - - highlight.apply_highlights(bufnr, highlights, namespace or test_list_namespace) - end - - local function create_none_diff_layout(parent_win, expected_content, actual_content) - local expected_buf = create_buffer_with_options() - local actual_buf = create_buffer_with_options() - - vim.api.nvim_set_current_win(parent_win) - vim.cmd.split() - vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) - local actual_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(actual_win, actual_buf) - - vim.cmd.vsplit() - local expected_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(expected_win, expected_buf) - - vim.api.nvim_set_option_value('filetype', 'cptest', { buf = expected_buf }) - vim.api.nvim_set_option_value('filetype', 'cptest', { buf = actual_buf }) - vim.api.nvim_set_option_value('winbar', 'Expected', { win = expected_win }) - vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) - - local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) - - update_buffer_content(expected_buf, expected_lines, {}) - update_buffer_content(actual_buf, actual_lines, {}) - - return { - buffers = { expected_buf, actual_buf }, - windows = { expected_win, actual_win }, - cleanup = function() - pcall(vim.api.nvim_win_close, expected_win, true) - pcall(vim.api.nvim_win_close, actual_win, true) - pcall(vim.api.nvim_buf_delete, expected_buf, { force = true }) - pcall(vim.api.nvim_buf_delete, actual_buf, { force = true }) - end, - } - end - - local function create_vim_diff_layout(parent_win, expected_content, actual_content) - local expected_buf = create_buffer_with_options() - local actual_buf = create_buffer_with_options() - - vim.api.nvim_set_current_win(parent_win) - vim.cmd.split() - vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) - local actual_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(actual_win, actual_buf) - - vim.cmd.vsplit() - local expected_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(expected_win, expected_buf) - - vim.api.nvim_set_option_value('filetype', 'cptest', { buf = expected_buf }) - vim.api.nvim_set_option_value('filetype', 'cptest', { buf = actual_buf }) - vim.api.nvim_set_option_value('winbar', 'Expected', { win = expected_win }) - vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) - - local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) - - update_buffer_content(expected_buf, expected_lines, {}) - update_buffer_content(actual_buf, actual_lines, {}) - - vim.api.nvim_set_option_value('diff', true, { win = expected_win }) - vim.api.nvim_set_option_value('diff', true, { win = actual_win }) - vim.api.nvim_win_call(expected_win, function() - vim.cmd.diffthis() - end) - vim.api.nvim_win_call(actual_win, function() - vim.cmd.diffthis() - end) - -- NOTE: diffthis() sets foldcolumn, so override it after - vim.api.nvim_set_option_value('foldcolumn', '0', { win = expected_win }) - vim.api.nvim_set_option_value('foldcolumn', '0', { win = actual_win }) - - return { - buffers = { expected_buf, actual_buf }, - windows = { expected_win, actual_win }, - cleanup = function() - pcall(vim.api.nvim_win_close, expected_win, true) - pcall(vim.api.nvim_win_close, actual_win, true) - pcall(vim.api.nvim_buf_delete, expected_buf, { force = true }) - pcall(vim.api.nvim_buf_delete, actual_buf, { force = true }) - end, - } - end - - local function create_git_diff_layout(parent_win, expected_content, actual_content) - local diff_buf = create_buffer_with_options() - - vim.api.nvim_set_current_win(parent_win) - vim.cmd.split() - vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) - local diff_win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(diff_win, diff_buf) - - vim.api.nvim_set_option_value('filetype', 'cptest', { buf = diff_buf }) - vim.api.nvim_set_option_value('winbar', 'Expected vs Actual', { win = diff_win }) - - local diff_backend = require('cp.ui.diff') - local backend = diff_backend.get_best_backend('git') - local diff_result = backend.render(expected_content, actual_content) - - if diff_result.raw_diff and diff_result.raw_diff ~= '' then - highlight.parse_and_apply_diff(diff_buf, diff_result.raw_diff, diff_namespace) - else - local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) - update_buffer_content(diff_buf, lines, {}) - end - - return { - buffers = { diff_buf }, - windows = { diff_win }, - cleanup = function() - pcall(vim.api.nvim_win_close, diff_win, true) - pcall(vim.api.nvim_buf_delete, diff_buf, { force = true }) - end, - } - end - - local function create_single_layout(parent_win, content) - local buf = create_buffer_with_options() - local lines = vim.split(content, '\n', { plain = true, trimempty = true }) - update_buffer_content(buf, lines, {}) - - vim.api.nvim_set_current_win(parent_win) - vim.cmd.split() - vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) - local win = vim.api.nvim_get_current_win() - vim.api.nvim_win_set_buf(win, buf) - vim.api.nvim_set_option_value('filetype', 'cptest', { buf = buf }) - - return { - buffers = { buf }, - windows = { win }, - cleanup = function() - pcall(vim.api.nvim_win_close, win, true) - pcall(vim.api.nvim_buf_delete, buf, { force = true }) - end, - } - end - - local function create_diff_layout(mode, parent_win, expected_content, actual_content) - if mode == 'single' then - return create_single_layout(parent_win, actual_content) - elseif mode == 'none' then - return create_none_diff_layout(parent_win, expected_content, actual_content) - elseif mode == 'git' then - return create_git_diff_layout(parent_win, expected_content, actual_content) - else - return create_vim_diff_layout(parent_win, expected_content, actual_content) - end - end - - local function update_diff_panes() - local test_state = run.get_run_panel_state() - local current_test = test_state.test_cases[test_state.current_index] - - if not current_test then - return - end - - local expected_content = current_test.expected or '' - local actual_content = current_test.actual or '(not run yet)' - local actual_highlights = current_test.actual_highlights or {} - local is_compilation_failure = current_test.error - and current_test.error:match('Compilation failed') - local should_show_diff = current_test.status == 'fail' - and current_test.actual - and not is_compilation_failure - - if not should_show_diff then - expected_content = expected_content - actual_content = actual_content - end - - local desired_mode = is_compilation_failure and 'single' or config.run_panel.diff_mode - - if current_diff_layout and current_mode ~= desired_mode then - local saved_pos = vim.api.nvim_win_get_cursor(0) - current_diff_layout.cleanup() - current_diff_layout = nil - current_mode = nil - - current_diff_layout = - create_diff_layout(desired_mode, main_win, expected_content, actual_content) - current_mode = desired_mode - - for _, buf in ipairs(current_diff_layout.buffers) do - setup_keybindings_for_buffer(buf) - end - - pcall(vim.api.nvim_win_set_cursor, 0, saved_pos) - return - end - - if not current_diff_layout then - current_diff_layout = - create_diff_layout(desired_mode, main_win, expected_content, actual_content) - current_mode = desired_mode - - for _, buf in ipairs(current_diff_layout.buffers) do - setup_keybindings_for_buffer(buf) - end - else - if desired_mode == 'single' then - local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) - update_buffer_content( - current_diff_layout.buffers[1], - lines, - actual_highlights, - ansi_namespace - ) - elseif desired_mode == 'git' then - local diff_backend = require('cp.ui.diff') - local backend = diff_backend.get_best_backend('git') - local diff_result = backend.render(expected_content, actual_content) - - if diff_result.raw_diff and diff_result.raw_diff ~= '' then - highlight.parse_and_apply_diff( - current_diff_layout.buffers[1], - diff_result.raw_diff, - diff_namespace - ) - else - local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) - update_buffer_content( - current_diff_layout.buffers[1], - lines, - actual_highlights, - ansi_namespace - ) - end - elseif desired_mode == 'none' then - local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) - update_buffer_content(current_diff_layout.buffers[1], expected_lines, {}) - update_buffer_content( - current_diff_layout.buffers[2], - actual_lines, - actual_highlights, - ansi_namespace - ) - else - local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) - update_buffer_content(current_diff_layout.buffers[1], expected_lines, {}) - update_buffer_content( - current_diff_layout.buffers[2], - actual_lines, - actual_highlights, - ansi_namespace - ) - - if should_show_diff then - vim.api.nvim_set_option_value('diff', true, { win = current_diff_layout.windows[1] }) - vim.api.nvim_set_option_value('diff', true, { win = current_diff_layout.windows[2] }) - vim.api.nvim_win_call(current_diff_layout.windows[1], function() - vim.cmd.diffthis() - end) - vim.api.nvim_win_call(current_diff_layout.windows[2], function() - vim.cmd.diffthis() - end) - vim.api.nvim_set_option_value('foldcolumn', '0', { win = current_diff_layout.windows[1] }) - vim.api.nvim_set_option_value('foldcolumn', '0', { win = current_diff_layout.windows[2] }) - else - vim.api.nvim_set_option_value('diff', false, { win = current_diff_layout.windows[1] }) - vim.api.nvim_set_option_value('diff', false, { win = current_diff_layout.windows[2] }) - end - end - end - end - - local function refresh_run_panel() - if not test_buffers.tab_buf or not vim.api.nvim_buf_is_valid(test_buffers.tab_buf) then - return - end - - local run_render = require('cp.runner.run_render') - run_render.setup_highlights() - - local test_state = run.get_run_panel_state() - local tab_lines, tab_highlights = run_render.render_test_list(test_state) - update_buffer_content(test_buffers.tab_buf, tab_lines, tab_highlights) - - update_diff_panes() - end - - ---@param delta number 1 for next, -1 for prev - local function navigate_test_case(delta) - local test_state = run.get_run_panel_state() - if #test_state.test_cases == 0 then - return - end - - test_state.current_index = test_state.current_index + delta - if test_state.current_index < 1 then - test_state.current_index = #test_state.test_cases - elseif test_state.current_index > #test_state.test_cases then - test_state.current_index = 1 - end - - refresh_run_panel() - end - - setup_keybindings_for_buffer = function(buf) - vim.keymap.set('n', 'q', function() - toggle_run_panel() - end, { buffer = buf, silent = true }) - vim.keymap.set('n', config.run_panel.toggle_diff_key, function() - local modes = { 'none', 'git', 'vim' } - local current_idx = nil - for i, mode in ipairs(modes) do - if config.run_panel.diff_mode == mode then - current_idx = i - break - end - end - current_idx = current_idx or 1 - config.run_panel.diff_mode = modes[(current_idx % #modes) + 1] - refresh_run_panel() - end, { buffer = buf, silent = true }) - vim.keymap.set('n', config.run_panel.next_test_key, function() - navigate_test_case(1) - end, { buffer = buf, silent = true }) - vim.keymap.set('n', config.run_panel.prev_test_key, function() - navigate_test_case(-1) - end, { buffer = buf, silent = true }) - end - - vim.keymap.set('n', config.run_panel.next_test_key, function() - navigate_test_case(1) - end, { buffer = test_buffers.tab_buf, silent = true }) - vim.keymap.set('n', config.run_panel.prev_test_key, function() - navigate_test_case(-1) - end, { buffer = test_buffers.tab_buf, silent = true }) - - setup_keybindings_for_buffer(test_buffers.tab_buf) - - if config.hooks and config.hooks.before_run then - config.hooks.before_run(ctx) - end - - if is_debug and config.hooks and config.hooks.before_debug then - config.hooks.before_debug(ctx) - end - - local execute = require('cp.runner.execute') - local contest_config = config.contests[state.get_platform()] - local compile_result = execute.compile_problem(ctx, contest_config, is_debug) - if compile_result.success then - run.run_all_test_cases(ctx, contest_config, config) - else - run.handle_compilation_failure(compile_result.output) - end - - refresh_run_panel() - - vim.schedule(function() - if config.run_panel.ansi then - local ansi = require('cp.ui.ansi') - ansi.setup_highlight_groups() - end - if current_diff_layout then - update_diff_panes() - end - end) - - vim.api.nvim_set_current_win(test_windows.tab_win) - - state.run_panel_active = true - state.test_buffers = test_buffers - state.test_windows = test_windows - local test_state = run.get_run_panel_state() - logger.log( - string.format('test panel opened (%d test cases)', #test_state.test_cases), - vim.log.levels.INFO - ) -end - ----@param contest_id string ----@param language? string -local function setup_contest(contest_id, language) - if state.get_platform() == '' then - logger.log('no platform set', vim.log.levels.ERROR) - return false - end - - if not vim.tbl_contains(config.scrapers, state.get_platform()) then - logger.log('scraping disabled for ' .. state.get_platform(), vim.log.levels.WARN) - return false - end - - logger.log(('setting up contest %s %s'):format(state.get_platform(), contest_id)) - - local metadata_result = scrape.scrape_contest_metadata(state.get_platform(), contest_id) - if not metadata_result.success then - logger.log( - 'failed to load contest metadata: ' .. (metadata_result.error or 'unknown error'), - vim.log.levels.ERROR - ) - return false - end - - local problems = metadata_result.problems - if not problems or #problems == 0 then - logger.log('no problems found in contest', vim.log.levels.ERROR) - return false - end - - logger.log(('found %d problems, checking cache...'):format(#problems)) - - cache.load() - local missing_problems = {} - for _, prob in ipairs(problems) do - local cached_tests = cache.get_test_cases(state.get_platform(), contest_id, prob.id) - if not cached_tests then - table.insert(missing_problems, prob) - end - end - - if #missing_problems > 0 then - logger.log(('scraping %d uncached problems...'):format(#missing_problems)) - scrape_missing_problems(contest_id, missing_problems) - else - logger.log('all problems already cached') - end - - state.set_contest_id(contest_id) - setup_problem(contest_id, problems[1].id, language) - - return true -end - ----@param delta number 1 for next, -1 for prev ----@param language? string -local function navigate_problem(delta, language) - if state.get_platform() == '' or state.get_contest_id() == '' then - logger.log('no contest set. run :CP first', vim.log.levels.ERROR) - return - end - - cache.load() - local contest_data = cache.get_contest_data(state.get_platform(), state.get_contest_id()) - if not contest_data or not contest_data.problems then - logger.log( - 'no contest metadata found. set up a problem first to cache contest data', - vim.log.levels.ERROR - ) - return - end - - local problems = contest_data.problems - local current_problem_id = state.get_problem_id() - - if not current_problem_id then - logger.log('no current problem set', vim.log.levels.ERROR) - return - end - - local current_index = nil - for i, prob in ipairs(problems) do - if prob.id == current_problem_id then - current_index = i - break - end - end - - if not current_index then - logger.log('current problem not found in contest', vim.log.levels.ERROR) - return - end - - local new_index = current_index + delta - - if new_index < 1 or new_index > #problems then - local msg = delta > 0 and 'at last problem' or 'at first problem' - logger.log(msg, vim.log.levels.WARN) - return - end - - local new_problem = problems[new_index] - - setup_problem(state.get_contest_id(), new_problem.id, language) -end - -local function handle_pick_action() - if not config.picker then - logger.log( - 'No picker configured. Set picker = "telescope" or picker = "fzf-lua" in config', - vim.log.levels.ERROR - ) - return - end - - if config.picker == 'telescope' then - local ok = pcall(require, 'telescope') - if not ok then - logger.log( - 'Telescope not available. Install telescope.nvim or change picker config', - vim.log.levels.ERROR - ) - return - end - local ok_cp, telescope_cp = pcall(require, 'cp.pickers.telescope') - if not ok_cp then - logger.log('Failed to load telescope integration', vim.log.levels.ERROR) - return - end - telescope_cp.platform_picker() - elseif config.picker == 'fzf-lua' then - local ok, _ = pcall(require, 'fzf-lua') - if not ok then - logger.log( - 'fzf-lua not available. Install fzf-lua or change picker config', - vim.log.levels.ERROR - ) - return - end - local ok_cp, fzf_cp = pcall(require, 'cp.pickers.fzf_lua') - if not ok_cp then - logger.log('Failed to load fzf-lua integration', vim.log.levels.ERROR) - return - end - fzf_cp.platform_picker() - end -end - -local function handle_cache_command(cmd) - if cmd.subcommand == 'clear' then - cache.load() - if cmd.platform then - if vim.tbl_contains(platforms, cmd.platform) then - cache.clear_platform(cmd.platform) - logger.log(('cleared cache for %s'):format(cmd.platform), vim.log.levels.INFO, true) - else - logger.log( - ('unknown platform: %s. Available: %s'):format( - cmd.platform, - table.concat(platforms, ', ') - ), - vim.log.levels.ERROR - ) - end - else - cache.clear_all() - logger.log('cleared all cache', vim.log.levels.INFO, true) - end - end -end - -local function restore_from_current_file() - local current_file = vim.fn.expand('%:p') - if current_file == '' then - logger.log('No file is currently open', vim.log.levels.ERROR) - return false - end - - cache.load() - local file_state = cache.get_file_state(current_file) - if not file_state then - logger.log( - 'No cached state found for current file. Use :CP first.', - vim.log.levels.ERROR - ) - return false - end - - logger.log( - ('Restoring from cached state: %s %s %s'):format( - file_state.platform, - file_state.contest_id, - file_state.problem_id or 'N/A' - ) - ) - - if not set_platform(file_state.platform) then - return false - end - - state.set_contest_id(file_state.contest_id) - state.set_problem_id(file_state.problem_id) - - setup_problem(file_state.contest_id, file_state.problem_id, file_state.language) - - return true -end - -local function parse_command(args) - if #args == 0 then - return { - type = 'restore_from_file', - } - end - - local language = nil - local debug = false - - for i, arg in ipairs(args) do - local lang_match = arg:match('^--lang=(.+)$') - if lang_match then - language = lang_match - elseif arg == '--lang' then - if i + 1 <= #args then - language = args[i + 1] - else - return { type = 'error', message = '--lang requires a value' } - end - elseif arg == '--debug' then - debug = true - end - end - - local filtered_args = vim.tbl_filter(function(arg) - return not (arg:match('^--lang') or arg == language or arg == '--debug') - end, args) - - local first = filtered_args[1] - - if vim.tbl_contains(actions, first) then - if first == 'cache' then - local subcommand = filtered_args[2] - if not subcommand then - return { type = 'error', message = 'cache command requires subcommand: clear' } - end - if subcommand == 'clear' then - local platform = filtered_args[3] - return { - type = 'cache', - subcommand = 'clear', - platform = platform, - } - else - return { type = 'error', message = 'unknown cache subcommand: ' .. subcommand } - end - else - return { type = 'action', action = first, language = language, debug = debug } - end - end - - if vim.tbl_contains(platforms, first) then - if #filtered_args == 1 then - return { - type = 'platform_only', - platform = first, - language = language, - } - elseif #filtered_args == 2 then - return { - type = 'contest_setup', - platform = first, - contest = filtered_args[2], - language = language, - } - elseif #filtered_args == 3 then - return { - type = 'full_setup', - platform = first, - contest = filtered_args[2], - problem = filtered_args[3], - language = language, - } - else - return { type = 'error', message = 'Too many arguments' } - end - end - - if state.get_platform() ~= '' and state.get_contest_id() ~= '' then - cache.load() - local contest_data = cache.get_contest_data(state.get_platform(), state.get_contest_id()) - if contest_data and contest_data.problems then - local problem_ids = vim.tbl_map(function(prob) - return prob.id - end, contest_data.problems) - if vim.tbl_contains(problem_ids, first) then - return { type = 'problem_switch', problem = first, language = language } - end - end - return { - type = 'error', - message = ("invalid subcommand '%s'"):format(first), - } - end - - return { type = 'error', message = 'Unknown command or no contest context' } -end - function M.handle_command(opts) - local cmd = parse_command(opts.fargs) - - if cmd.type == 'error' then - logger.log(cmd.message, vim.log.levels.ERROR) - return - end - - if cmd.type == 'restore_from_file' then - restore_from_current_file() - return - end - - if cmd.type == 'action' then - if cmd.action == 'run' then - toggle_run_panel(cmd.debug) - elseif cmd.action == 'next' then - navigate_problem(1, cmd.language) - elseif cmd.action == 'prev' then - navigate_problem(-1, cmd.language) - elseif cmd.action == 'pick' then - handle_pick_action() - end - return - end - - if cmd.type == 'cache' then - handle_cache_command(cmd) - return - end - - if cmd.type == 'platform_only' then - set_platform(cmd.platform) - return - end - - if cmd.type == 'contest_setup' then - if set_platform(cmd.platform) then - setup_contest(cmd.contest, cmd.language) - end - return - end - - if cmd.type == 'full_setup' then - if set_platform(cmd.platform) then - state.set_contest_id(cmd.contest) - local problem_ids = {} - local has_metadata = false - - if vim.tbl_contains(config.scrapers, cmd.platform) then - local metadata_result = scrape.scrape_contest_metadata(cmd.platform, cmd.contest) - if not metadata_result.success then - logger.log( - 'failed to load contest metadata: ' .. (metadata_result.error or 'unknown error'), - vim.log.levels.ERROR - ) - return - end - - logger.log( - ('loaded %d problems for %s %s'):format( - #metadata_result.problems, - cmd.platform, - cmd.contest - ), - vim.log.levels.INFO, - true - ) - problem_ids = vim.tbl_map(function(prob) - return prob.id - end, metadata_result.problems) - has_metadata = true - else - cache.load() - local contest_data = cache.get_contest_data(cmd.platform, cmd.contest) - if contest_data and contest_data.problems then - problem_ids = vim.tbl_map(function(prob) - return prob.id - end, contest_data.problems) - has_metadata = true - end - end - - if has_metadata and not vim.tbl_contains(problem_ids, cmd.problem) then - logger.log( - ("Invalid problem '%s' for contest %s %s"):format(cmd.problem, cmd.platform, cmd.contest), - vim.log.levels.ERROR - ) - return - end - - setup_problem(cmd.contest, cmd.problem, cmd.language) - end - return - end - - if cmd.type == 'problem_switch' then - setup_problem(state.get_contest_id(), cmd.problem, cmd.language) - return - end + local commands = require('cp.commands') + commands.handle_command(opts) end function M.setup(opts) opts = opts or {} user_config = opts config = config_module.setup(user_config) + config_module.set_current_config(config) + if not snippets_initialized then snippets.setup(config) snippets_initialized = true diff --git a/lua/cp/restore.lua b/lua/cp/restore.lua new file mode 100644 index 0000000..60236b4 --- /dev/null +++ b/lua/cp/restore.lua @@ -0,0 +1,45 @@ +local M = {} + +local cache = require('cp.cache') +local logger = require('cp.log') +local state = require('cp.state') + +function M.restore_from_current_file() + local current_file = vim.fn.expand('%:p') + if current_file == '' then + logger.log('No file is currently open', vim.log.levels.ERROR) + return false + end + + cache.load() + local file_state = cache.get_file_state(current_file) + if not file_state then + logger.log( + 'No cached state found for current file. Use :CP first.', + vim.log.levels.ERROR + ) + return false + end + + logger.log( + ('Restoring from cached state: %s %s %s'):format( + file_state.platform, + file_state.contest_id, + file_state.problem_id or 'N/A' + ) + ) + + local setup = require('cp.setup') + if not setup.set_platform(file_state.platform) then + return false + end + + state.set_contest_id(file_state.contest_id) + state.set_problem_id(file_state.problem_id) + + setup.setup_problem(file_state.contest_id, file_state.problem_id, file_state.language) + + return true +end + +return M diff --git a/lua/cp/runner/run.lua b/lua/cp/runner/run.lua index f996a60..abe13e3 100644 --- a/lua/cp/runner/run.lua +++ b/lua/cp/runner/run.lua @@ -297,7 +297,11 @@ end ---@param state table ---@return boolean function M.load_test_cases(ctx, state) - local test_cases = parse_test_cases_from_cache(state.platform, state.contest_id, state.problem_id) + local test_cases = parse_test_cases_from_cache( + state.get_platform() or '', + state.get_contest_id() or '', + state.get_problem_id() + ) if #test_cases == 0 then test_cases = parse_test_cases_from_files(ctx.input_file, ctx.expected_file) @@ -305,8 +309,11 @@ function M.load_test_cases(ctx, state) run_panel_state.test_cases = test_cases run_panel_state.current_index = 1 - run_panel_state.constraints = - load_constraints_from_cache(state.platform, state.contest_id, state.problem_id) + run_panel_state.constraints = load_constraints_from_cache( + state.get_platform() or '', + state.get_contest_id() or '', + state.get_problem_id() + ) local constraint_info = run_panel_state.constraints and string.format( diff --git a/lua/cp/setup/contest.lua b/lua/cp/setup/contest.lua new file mode 100644 index 0000000..7649330 --- /dev/null +++ b/lua/cp/setup/contest.lua @@ -0,0 +1,43 @@ +local M = {} + +local logger = require('cp.log') +local scrape = require('cp.scrape') +local state = require('cp.state') + +function M.scrape_missing_problems(contest_id, missing_problems, config) + vim.fn.mkdir('io', 'p') + + logger.log(('scraping %d uncached problems...'):format(#missing_problems)) + + local results = scrape.scrape_problems_parallel( + state.get_platform() or '', + contest_id, + missing_problems, + config + ) + + local success_count = 0 + local failed_problems = {} + for problem_id, result in pairs(results) do + if result.success then + success_count = success_count + 1 + else + table.insert(failed_problems, problem_id) + end + end + + if #failed_problems > 0 then + logger.log( + ('scraping complete: %d/%d successful, failed: %s'):format( + success_count, + #missing_problems, + table.concat(failed_problems, ', ') + ), + vim.log.levels.WARN + ) + else + logger.log(('scraping complete: %d/%d successful'):format(success_count, #missing_problems)) + end +end + +return M diff --git a/lua/cp/setup/init.lua b/lua/cp/setup/init.lua new file mode 100644 index 0000000..f654e5c --- /dev/null +++ b/lua/cp/setup/init.lua @@ -0,0 +1,260 @@ +local M = {} + +local cache = require('cp.cache') +local config_module = require('cp.config') +local logger = require('cp.log') +local problem = require('cp.problem') +local scrape = require('cp.scrape') +local state = require('cp.state') + +local constants = require('cp.constants') +local platforms = constants.PLATFORMS + +function M.set_platform(platform) + if not vim.tbl_contains(platforms, platform) then + logger.log( + ('unknown platform. Available: [%s]'):format(table.concat(platforms, ', ')), + vim.log.levels.ERROR + ) + return false + end + + state.set_platform(platform) + vim.system({ 'mkdir', '-p', 'build', 'io' }):wait() + return true +end + +function M.setup_problem(contest_id, problem_id, language) + if not state.get_platform() then + logger.log('no platform set. run :CP first', vim.log.levels.ERROR) + return + end + + local config = config_module.get_config() + local problem_name = contest_id .. (problem_id or '') + logger.log(('setting up problem: %s'):format(problem_name)) + + local ctx = + problem.create_context(state.get_platform() or '', contest_id, problem_id, config, language) + + if vim.tbl_contains(config.scrapers, state.get_platform() or '') then + cache.load() + local existing_contest_data = cache.get_contest_data(state.get_platform() or '', contest_id) + + if not existing_contest_data then + local metadata_result = scrape.scrape_contest_metadata(state.get_platform() or '', contest_id) + if not metadata_result.success then + logger.log( + 'failed to load contest metadata: ' .. (metadata_result.error or 'unknown error'), + vim.log.levels.WARN + ) + end + end + end + + local cached_test_cases = cache.get_test_cases(state.get_platform() or '', contest_id, problem_id) + if cached_test_cases then + state.set_test_cases(cached_test_cases) + logger.log(('using cached test cases (%d)'):format(#cached_test_cases)) + elseif vim.tbl_contains(config.scrapers, state.get_platform() or '') then + local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[state.get_platform() or ''] + or (state.get_platform() or '') + logger.log( + ('Scraping %s %s %s for test cases, this may take a few seconds...'):format( + platform_display_name, + contest_id, + problem_id + ), + vim.log.levels.INFO, + true + ) + + local scrape_result = scrape.scrape_problem(ctx) + + if not scrape_result.success then + logger.log( + 'scraping failed: ' .. (scrape_result.error or 'unknown error'), + vim.log.levels.ERROR + ) + return + end + + local test_count = scrape_result.test_count or 0 + logger.log(('scraped %d test case(s) for %s'):format(test_count, scrape_result.problem_id)) + state.set_test_cases(scrape_result.test_cases) + + if scrape_result.test_cases then + cache.set_test_cases( + state.get_platform() or '', + contest_id, + problem_id, + scrape_result.test_cases + ) + end + else + logger.log(('scraping disabled for %s'):format(state.get_platform() or '')) + state.set_test_cases(nil) + end + + vim.cmd('silent only') + state.set_run_panel_active(false) + + state.set_contest_id(contest_id) + state.set_problem_id(problem_id) + + vim.cmd.e(ctx.source_file) + local source_buf = vim.api.nvim_get_current_buf() + + if vim.api.nvim_buf_get_lines(source_buf, 0, -1, true)[1] == '' then + local has_luasnip, luasnip = pcall(require, 'luasnip') + if has_luasnip then + local filetype = vim.api.nvim_get_option_value('filetype', { buf = source_buf }) + local language_name = constants.filetype_to_language[filetype] + local canonical_language = constants.canonical_filetypes[language_name] or language_name + local prefixed_trigger = ('cp.nvim/%s.%s'):format(state.get_platform(), canonical_language) + + vim.api.nvim_buf_set_lines(0, 0, -1, false, { prefixed_trigger }) + vim.api.nvim_win_set_cursor(0, { 1, #prefixed_trigger }) + vim.cmd.startinsert({ bang = true }) + + vim.schedule(function() + if luasnip.expandable() then + luasnip.expand() + else + vim.api.nvim_buf_set_lines(0, 0, 1, false, { '' }) + vim.api.nvim_win_set_cursor(0, { 1, 0 }) + end + vim.cmd.stopinsert() + end) + else + vim.api.nvim_input(('i%s'):format(state.get_platform())) + end + end + + if config.hooks and config.hooks.setup_code then + config.hooks.setup_code(ctx) + end + + cache.set_file_state( + vim.fn.expand('%:p'), + state.get_platform() or '', + contest_id, + problem_id, + language + ) + + logger.log(('switched to problem %s'):format(ctx.problem_name)) +end + +function M.setup_contest(contest_id, language) + if not state.get_platform() then + logger.log('no platform set', vim.log.levels.ERROR) + return false + end + + local config = config_module.get_config() + + if not vim.tbl_contains(config.scrapers, state.get_platform() or '') then + logger.log('scraping disabled for ' .. (state.get_platform() or ''), vim.log.levels.WARN) + return false + end + + logger.log(('setting up contest %s %s'):format(state.get_platform() or '', contest_id)) + + local metadata_result = scrape.scrape_contest_metadata(state.get_platform() or '', contest_id) + if not metadata_result.success then + logger.log( + 'failed to load contest metadata: ' .. (metadata_result.error or 'unknown error'), + vim.log.levels.ERROR + ) + return false + end + + local problems = metadata_result.problems + if not problems or #problems == 0 then + logger.log('no problems found in contest', vim.log.levels.ERROR) + return false + end + + logger.log(('found %d problems, checking cache...'):format(#problems)) + + cache.load() + local missing_problems = {} + for _, prob in ipairs(problems) do + local cached_tests = cache.get_test_cases(state.get_platform() or '', contest_id, prob.id) + if not cached_tests then + table.insert(missing_problems, prob) + end + end + + if #missing_problems > 0 then + local contest_scraper = require('cp.setup.contest') + contest_scraper.scrape_missing_problems(contest_id, missing_problems, config) + else + logger.log('all problems already cached') + end + + state.set_contest_id(contest_id) + M.setup_problem(contest_id, problems[1].id, language) + + return true +end + +function M.navigate_problem(delta, language) + if not state.get_platform() or not state.get_contest_id() then + logger.log('no contest set. run :CP first', vim.log.levels.ERROR) + return + end + + local navigation = require('cp.setup.navigation') + navigation.navigate_problem(delta, language) +end + +function M.handle_full_setup(cmd) + state.set_contest_id(cmd.contest) + local problem_ids = {} + local has_metadata = false + local config = config_module.get_config() + + if vim.tbl_contains(config.scrapers, cmd.platform) then + local metadata_result = scrape.scrape_contest_metadata(cmd.platform, cmd.contest) + if not metadata_result.success then + logger.log( + 'failed to load contest metadata: ' .. (metadata_result.error or 'unknown error'), + vim.log.levels.ERROR + ) + return + end + + logger.log( + ('loaded %d problems for %s %s'):format(#metadata_result.problems, cmd.platform, cmd.contest), + vim.log.levels.INFO, + true + ) + problem_ids = vim.tbl_map(function(prob) + return prob.id + end, metadata_result.problems) + has_metadata = true + else + cache.load() + local contest_data = cache.get_contest_data(cmd.platform or '', cmd.contest) + if contest_data and contest_data.problems then + problem_ids = vim.tbl_map(function(prob) + return prob.id + end, contest_data.problems) + has_metadata = true + end + end + + if has_metadata and not vim.tbl_contains(problem_ids, cmd.problem) then + logger.log( + ("Invalid problem '%s' for contest %s %s"):format(cmd.problem, cmd.platform, cmd.contest), + vim.log.levels.ERROR + ) + return + end + + M.setup_problem(cmd.contest, cmd.problem, cmd.language) +end + +return M diff --git a/lua/cp/setup/navigation.lua b/lua/cp/setup/navigation.lua new file mode 100644 index 0000000..bab857b --- /dev/null +++ b/lua/cp/setup/navigation.lua @@ -0,0 +1,64 @@ +local M = {} + +local cache = require('cp.cache') +local logger = require('cp.log') +local state = require('cp.state') + +local function get_current_problem() + local filename = vim.fn.expand('%:t:r') + if filename == '' then + logger.log('no file open', vim.log.levels.ERROR) + return nil + end + return filename +end + +function M.navigate_problem(delta, language) + cache.load() + local contest_data = + cache.get_contest_data(state.get_platform() or '', state.get_contest_id() or '') + if not contest_data or not contest_data.problems then + logger.log( + 'no contest metadata found. set up a problem first to cache contest data', + vim.log.levels.ERROR + ) + return + end + + local problems = contest_data.problems + local current_problem_id = state.get_problem_id() + + if not current_problem_id then + logger.log('no current problem set', vim.log.levels.ERROR) + return + end + + local current_index = nil + for i, prob in ipairs(problems) do + if prob.id == current_problem_id then + current_index = i + break + end + end + + if not current_index then + logger.log('current problem not found in contest', vim.log.levels.ERROR) + return + end + + local new_index = current_index + delta + + if new_index < 1 or new_index > #problems then + local msg = delta > 0 and 'at last problem' or 'at first problem' + logger.log(msg, vim.log.levels.WARN) + return + end + + local new_problem = problems[new_index] + local setup = require('cp.setup') + setup.setup_problem(state.get_contest_id() or '', new_problem.id, language) +end + +M.get_current_problem = get_current_problem + +return M diff --git a/lua/cp/state.lua b/lua/cp/state.lua index 96f7d2f..ae21fc5 100644 --- a/lua/cp/state.lua +++ b/lua/cp/state.lua @@ -1,8 +1,8 @@ local M = {} local state = { - platform = '', - contest_id = '', + platform = nil, + contest_id = nil, problem_id = nil, test_cases = nil, run_panel_active = false, @@ -66,12 +66,12 @@ function M.get_context() end function M.has_context() - return state.platform ~= '' and state.contest_id ~= '' + return state.platform and state.contest_id end function M.reset() - state.platform = '' - state.contest_id = '' + state.platform = nil + state.contest_id = nil state.problem_id = nil state.test_cases = nil state.run_panel_active = false diff --git a/lua/cp/ui/layouts.lua b/lua/cp/ui/layouts.lua new file mode 100644 index 0000000..f3d6dd3 --- /dev/null +++ b/lua/cp/ui/layouts.lua @@ -0,0 +1,290 @@ +local M = {} + +local buffer_utils = require('cp.utils.buffer') + +local function create_none_diff_layout(parent_win, expected_content, actual_content) + local expected_buf = buffer_utils.create_buffer_with_options() + local actual_buf = buffer_utils.create_buffer_with_options() + + vim.api.nvim_set_current_win(parent_win) + vim.cmd.split() + vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) + local actual_win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(actual_win, actual_buf) + + vim.cmd.vsplit() + local expected_win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(expected_win, expected_buf) + + vim.api.nvim_set_option_value('filetype', 'cptest', { buf = expected_buf }) + vim.api.nvim_set_option_value('filetype', 'cptest', { buf = actual_buf }) + vim.api.nvim_set_option_value('winbar', 'Expected', { win = expected_win }) + vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) + + local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + + buffer_utils.update_buffer_content(expected_buf, expected_lines, {}) + buffer_utils.update_buffer_content(actual_buf, actual_lines, {}) + + return { + buffers = { expected_buf, actual_buf }, + windows = { expected_win, actual_win }, + cleanup = function() + pcall(vim.api.nvim_win_close, expected_win, true) + pcall(vim.api.nvim_win_close, actual_win, true) + pcall(vim.api.nvim_buf_delete, expected_buf, { force = true }) + pcall(vim.api.nvim_buf_delete, actual_buf, { force = true }) + end, + } +end + +local function create_vim_diff_layout(parent_win, expected_content, actual_content) + local expected_buf = buffer_utils.create_buffer_with_options() + local actual_buf = buffer_utils.create_buffer_with_options() + + vim.api.nvim_set_current_win(parent_win) + vim.cmd.split() + vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) + local actual_win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(actual_win, actual_buf) + + vim.cmd.vsplit() + local expected_win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(expected_win, expected_buf) + + vim.api.nvim_set_option_value('filetype', 'cptest', { buf = expected_buf }) + vim.api.nvim_set_option_value('filetype', 'cptest', { buf = actual_buf }) + vim.api.nvim_set_option_value('winbar', 'Expected', { win = expected_win }) + vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) + + local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + + buffer_utils.update_buffer_content(expected_buf, expected_lines, {}) + buffer_utils.update_buffer_content(actual_buf, actual_lines, {}) + + vim.api.nvim_set_option_value('diff', true, { win = expected_win }) + vim.api.nvim_set_option_value('diff', true, { win = actual_win }) + vim.api.nvim_win_call(expected_win, function() + vim.cmd.diffthis() + end) + vim.api.nvim_win_call(actual_win, function() + vim.cmd.diffthis() + end) + vim.api.nvim_set_option_value('foldcolumn', '0', { win = expected_win }) + vim.api.nvim_set_option_value('foldcolumn', '0', { win = actual_win }) + + return { + buffers = { expected_buf, actual_buf }, + windows = { expected_win, actual_win }, + cleanup = function() + pcall(vim.api.nvim_win_close, expected_win, true) + pcall(vim.api.nvim_win_close, actual_win, true) + pcall(vim.api.nvim_buf_delete, expected_buf, { force = true }) + pcall(vim.api.nvim_buf_delete, actual_buf, { force = true }) + end, + } +end + +local function create_git_diff_layout(parent_win, expected_content, actual_content) + local diff_buf = buffer_utils.create_buffer_with_options() + + vim.api.nvim_set_current_win(parent_win) + vim.cmd.split() + vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) + local diff_win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(diff_win, diff_buf) + + vim.api.nvim_set_option_value('filetype', 'cptest', { buf = diff_buf }) + vim.api.nvim_set_option_value('winbar', 'Expected vs Actual', { win = diff_win }) + + local diff_backend = require('cp.ui.diff') + local backend = diff_backend.get_best_backend('git') + local diff_result = backend.render(expected_content, actual_content) + local highlight = require('cp.ui.highlight') + local diff_namespace = highlight.create_namespace() + + if diff_result.raw_diff and diff_result.raw_diff ~= '' then + highlight.parse_and_apply_diff(diff_buf, diff_result.raw_diff, diff_namespace) + else + local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + buffer_utils.update_buffer_content(diff_buf, lines, {}) + end + + return { + buffers = { diff_buf }, + windows = { diff_win }, + cleanup = function() + pcall(vim.api.nvim_win_close, diff_win, true) + pcall(vim.api.nvim_buf_delete, diff_buf, { force = true }) + end, + } +end + +local function create_single_layout(parent_win, content) + local buf = buffer_utils.create_buffer_with_options() + local lines = vim.split(content, '\n', { plain = true, trimempty = true }) + buffer_utils.update_buffer_content(buf, lines, {}) + + vim.api.nvim_set_current_win(parent_win) + vim.cmd.split() + vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) + local win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(win, buf) + vim.api.nvim_set_option_value('filetype', 'cptest', { buf = buf }) + + return { + buffers = { buf }, + windows = { win }, + cleanup = function() + pcall(vim.api.nvim_win_close, win, true) + pcall(vim.api.nvim_buf_delete, buf, { force = true }) + end, + } +end + +function M.create_diff_layout(mode, parent_win, expected_content, actual_content) + if mode == 'single' then + return create_single_layout(parent_win, actual_content) + elseif mode == 'none' then + return create_none_diff_layout(parent_win, expected_content, actual_content) + elseif mode == 'git' then + return create_git_diff_layout(parent_win, expected_content, actual_content) + else + return create_vim_diff_layout(parent_win, expected_content, actual_content) + end +end + +function M.update_diff_panes( + current_diff_layout, + current_mode, + main_win, + run, + config, + setup_keybindings_for_buffer +) + local test_state = run.get_run_panel_state() + local current_test = test_state.test_cases[test_state.current_index] + + if not current_test then + return current_diff_layout, current_mode + end + + local expected_content = current_test.expected or '' + local actual_content = current_test.actual or '(not run yet)' + local actual_highlights = current_test.actual_highlights or {} + local is_compilation_failure = current_test.error + and current_test.error:match('Compilation failed') + local should_show_diff = current_test.status == 'fail' + and current_test.actual + and not is_compilation_failure + + if not should_show_diff then + expected_content = expected_content + actual_content = actual_content + end + + local desired_mode = is_compilation_failure and 'single' or config.run_panel.diff_mode + local highlight = require('cp.ui.highlight') + local diff_namespace = highlight.create_namespace() + local ansi_namespace = vim.api.nvim_create_namespace('cp_ansi_highlights') + + if current_diff_layout and current_mode ~= desired_mode then + local saved_pos = vim.api.nvim_win_get_cursor(0) + current_diff_layout.cleanup() + current_diff_layout = nil + current_mode = nil + + current_diff_layout = + M.create_diff_layout(desired_mode, main_win, expected_content, actual_content) + current_mode = desired_mode + + for _, buf in ipairs(current_diff_layout.buffers) do + setup_keybindings_for_buffer(buf) + end + + pcall(vim.api.nvim_win_set_cursor, 0, saved_pos) + return current_diff_layout, current_mode + end + + if not current_diff_layout then + current_diff_layout = + M.create_diff_layout(desired_mode, main_win, expected_content, actual_content) + current_mode = desired_mode + + for _, buf in ipairs(current_diff_layout.buffers) do + setup_keybindings_for_buffer(buf) + end + else + if desired_mode == 'single' then + local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + buffer_utils.update_buffer_content( + current_diff_layout.buffers[1], + lines, + actual_highlights, + ansi_namespace + ) + elseif desired_mode == 'git' then + local diff_backend = require('cp.ui.diff') + local backend = diff_backend.get_best_backend('git') + local diff_result = backend.render(expected_content, actual_content) + + if diff_result.raw_diff and diff_result.raw_diff ~= '' then + highlight.parse_and_apply_diff( + current_diff_layout.buffers[1], + diff_result.raw_diff, + diff_namespace + ) + else + local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + buffer_utils.update_buffer_content( + current_diff_layout.buffers[1], + lines, + actual_highlights, + ansi_namespace + ) + end + elseif desired_mode == 'none' then + local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + buffer_utils.update_buffer_content(current_diff_layout.buffers[1], expected_lines, {}) + buffer_utils.update_buffer_content( + current_diff_layout.buffers[2], + actual_lines, + actual_highlights, + ansi_namespace + ) + else + local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + buffer_utils.update_buffer_content(current_diff_layout.buffers[1], expected_lines, {}) + buffer_utils.update_buffer_content( + current_diff_layout.buffers[2], + actual_lines, + actual_highlights, + ansi_namespace + ) + + if should_show_diff then + vim.api.nvim_set_option_value('diff', true, { win = current_diff_layout.windows[1] }) + vim.api.nvim_set_option_value('diff', true, { win = current_diff_layout.windows[2] }) + vim.api.nvim_win_call(current_diff_layout.windows[1], function() + vim.cmd.diffthis() + end) + vim.api.nvim_win_call(current_diff_layout.windows[2], function() + vim.cmd.diffthis() + end) + vim.api.nvim_set_option_value('foldcolumn', '0', { win = current_diff_layout.windows[1] }) + vim.api.nvim_set_option_value('foldcolumn', '0', { win = current_diff_layout.windows[2] }) + else + vim.api.nvim_set_option_value('diff', false, { win = current_diff_layout.windows[1] }) + vim.api.nvim_set_option_value('diff', false, { win = current_diff_layout.windows[2] }) + end + end + end + + return current_diff_layout, current_mode +end + +return M diff --git a/lua/cp/ui/panel.lua b/lua/cp/ui/panel.lua new file mode 100644 index 0000000..5f3345d --- /dev/null +++ b/lua/cp/ui/panel.lua @@ -0,0 +1,206 @@ +local M = {} + +local buffer_utils = require('cp.utils.buffer') +local config_module = require('cp.config') +local layouts = require('cp.ui.layouts') +local logger = require('cp.log') +local problem = require('cp.problem') +local state = require('cp.state') + +local current_diff_layout = nil +local current_mode = nil + +local function get_current_problem() + local setup_nav = require('cp.setup.navigation') + return setup_nav.get_current_problem() +end + +function M.toggle_run_panel(is_debug) + if state.run_panel_active then + if current_diff_layout then + current_diff_layout.cleanup() + current_diff_layout = nil + current_mode = nil + end + if state.saved_session then + vim.cmd(('source %s'):format(state.saved_session)) + vim.fn.delete(state.saved_session) + state.saved_session = nil + end + + state.set_run_panel_active(false) + logger.log('test panel closed') + return + end + + if not state.get_platform() then + logger.log( + 'No contest configured. Use :CP to set up first.', + vim.log.levels.ERROR + ) + return + end + + local problem_id = get_current_problem() + if not problem_id then + return + end + + local config = config_module.get_config() + local ctx = problem.create_context( + state.get_platform() or '', + state.get_contest_id() or '', + state.get_problem_id(), + config + ) + local run = require('cp.runner.run') + + if not run.load_test_cases(ctx, state) then + logger.log('no test cases found', vim.log.levels.WARN) + return + end + + state.saved_session = vim.fn.tempname() + vim.cmd(('mksession! %s'):format(state.saved_session)) + + vim.cmd('silent only') + + local tab_buf = buffer_utils.create_buffer_with_options() + local main_win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(main_win, tab_buf) + vim.api.nvim_set_option_value('filetype', 'cptest', { buf = tab_buf }) + + local test_windows = { + tab_win = main_win, + } + local test_buffers = { + tab_buf = tab_buf, + } + + local test_list_namespace = vim.api.nvim_create_namespace('cp_test_list') + + local setup_keybindings_for_buffer + + local function update_diff_panes() + current_diff_layout, current_mode = layouts.update_diff_panes( + current_diff_layout, + current_mode, + main_win, + run, + config, + setup_keybindings_for_buffer + ) + end + + local function refresh_run_panel() + if not test_buffers.tab_buf or not vim.api.nvim_buf_is_valid(test_buffers.tab_buf) then + return + end + + local run_render = require('cp.runner.run_render') + run_render.setup_highlights() + + local test_state = run.get_run_panel_state() + local tab_lines, tab_highlights = run_render.render_test_list(test_state) + buffer_utils.update_buffer_content( + test_buffers.tab_buf, + tab_lines, + tab_highlights, + test_list_namespace + ) + + update_diff_panes() + end + + local function navigate_test_case(delta) + local test_state = run.get_run_panel_state() + if #test_state.test_cases == 0 then + return + end + + test_state.current_index = test_state.current_index + delta + if test_state.current_index < 1 then + test_state.current_index = #test_state.test_cases + elseif test_state.current_index > #test_state.test_cases then + test_state.current_index = 1 + end + + refresh_run_panel() + end + + setup_keybindings_for_buffer = function(buf) + vim.keymap.set('n', 'q', function() + M.toggle_run_panel() + end, { buffer = buf, silent = true }) + vim.keymap.set('n', config.run_panel.toggle_diff_key, function() + local modes = { 'none', 'git', 'vim' } + local current_idx = nil + for i, mode in ipairs(modes) do + if config.run_panel.diff_mode == mode then + current_idx = i + break + end + end + current_idx = current_idx or 1 + config.run_panel.diff_mode = modes[(current_idx % #modes) + 1] + refresh_run_panel() + end, { buffer = buf, silent = true }) + vim.keymap.set('n', config.run_panel.next_test_key, function() + navigate_test_case(1) + end, { buffer = buf, silent = true }) + vim.keymap.set('n', config.run_panel.prev_test_key, function() + navigate_test_case(-1) + end, { buffer = buf, silent = true }) + end + + vim.keymap.set('n', config.run_panel.next_test_key, function() + navigate_test_case(1) + end, { buffer = test_buffers.tab_buf, silent = true }) + vim.keymap.set('n', config.run_panel.prev_test_key, function() + navigate_test_case(-1) + end, { buffer = test_buffers.tab_buf, silent = true }) + + setup_keybindings_for_buffer(test_buffers.tab_buf) + + if config.hooks and config.hooks.before_run then + config.hooks.before_run(ctx) + end + + if is_debug and config.hooks and config.hooks.before_debug then + config.hooks.before_debug(ctx) + end + + local execute = require('cp.runner.execute') + local contest_config = config.contests[state.get_platform() or ''] + local compile_result = execute.compile_problem(ctx, contest_config, is_debug) + if compile_result.success then + run.run_all_test_cases(ctx, contest_config, config) + else + run.handle_compilation_failure(compile_result.output) + end + + refresh_run_panel() + + vim.schedule(function() + if config.run_panel.ansi then + local ansi = require('cp.ui.ansi') + ansi.setup_highlight_groups() + end + if current_diff_layout then + update_diff_panes() + end + end) + + vim.api.nvim_set_current_win(test_windows.tab_win) + + state.run_panel_active = true + state.test_buffers = test_buffers + state.test_windows = test_windows + local test_state = run.get_run_panel_state() + logger.log( + string.format('test panel opened (%d test cases)', #test_state.test_cases), + vim.log.levels.INFO + ) +end + +return M diff --git a/lua/cp/utils/buffer.lua b/lua/cp/utils/buffer.lua new file mode 100644 index 0000000..759e94c --- /dev/null +++ b/lua/cp/utils/buffer.lua @@ -0,0 +1,29 @@ +local M = {} + +function M.create_buffer_with_options(filetype) + local buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_set_option_value('bufhidden', 'wipe', { buf = buf }) + vim.api.nvim_set_option_value('readonly', true, { buf = buf }) + vim.api.nvim_set_option_value('modifiable', false, { buf = buf }) + if filetype then + vim.api.nvim_set_option_value('filetype', filetype, { buf = buf }) + end + return buf +end + +function M.update_buffer_content(bufnr, lines, highlights, namespace) + local was_readonly = vim.api.nvim_get_option_value('readonly', { buf = bufnr }) + + vim.api.nvim_set_option_value('readonly', false, { buf = bufnr }) + vim.api.nvim_set_option_value('modifiable', true, { buf = bufnr }) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) + vim.api.nvim_set_option_value('modifiable', false, { buf = bufnr }) + vim.api.nvim_set_option_value('readonly', was_readonly, { buf = bufnr }) + + if highlights and namespace then + local highlight = require('cp.ui.highlight') + highlight.apply_highlights(bufnr, highlights, namespace) + end +end + +return M diff --git a/scrapers/__init__.py b/scrapers/__init__.py index e69de29..8de8c42 100644 --- a/scrapers/__init__.py +++ b/scrapers/__init__.py @@ -0,0 +1,66 @@ +from .atcoder import AtCoderScraper +from .base import BaseScraper, ScraperConfig +from .codeforces import CodeforcesScraper +from .cses import CSESScraper +from .models import ( + ContestListResult, + ContestSummary, + MetadataResult, + ProblemSummary, + TestCase, + TestsResult, +) + +ALL_SCRAPERS: dict[str, type[BaseScraper]] = { + "atcoder": AtCoderScraper, + "codeforces": CodeforcesScraper, + "cses": CSESScraper, +} + +_SCRAPER_CLASSES = [ + "AtCoderScraper", + "CodeforcesScraper", + "CSESScraper", +] + +_BASE_EXPORTS = [ + "BaseScraper", + "ScraperConfig", + "ContestListResult", + "ContestSummary", + "MetadataResult", + "ProblemSummary", + "TestCase", + "TestsResult", +] + +_REGISTRY_FUNCTIONS = [ + "get_scraper", + "list_platforms", + "ALL_SCRAPERS", +] + +__all__ = _BASE_EXPORTS + _SCRAPER_CLASSES + _REGISTRY_FUNCTIONS + +_exported_types = ( + ScraperConfig, + ContestListResult, + ContestSummary, + MetadataResult, + ProblemSummary, + TestCase, + TestsResult, +) + + +def get_scraper(platform: str) -> type[BaseScraper]: + if platform not in ALL_SCRAPERS: + available = ", ".join(ALL_SCRAPERS.keys()) + raise KeyError( + f"Unknown platform '{platform}'. Available platforms: {available}" + ) + return ALL_SCRAPERS[platform] + + +def list_platforms() -> list[str]: + return list(ALL_SCRAPERS.keys()) diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 1935c6e..20cc3d3 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import concurrent.futures import json import re import sys @@ -9,6 +10,7 @@ import backoff import requests from bs4 import BeautifulSoup, Tag +from .base import BaseScraper from .models import ( ContestListResult, ContestSummary, @@ -167,8 +169,6 @@ def scrape(url: str) -> list[TestCase]: def scrape_contests() -> list[ContestSummary]: - import concurrent.futures - def get_max_pages() -> int: try: headers = { @@ -296,6 +296,101 @@ def scrape_contests() -> list[ContestSummary]: return all_contests +class AtCoderScraper(BaseScraper): + @property + def platform_name(self) -> str: + return "atcoder" + + def scrape_contest_metadata(self, contest_id: str) -> MetadataResult: + return self._safe_execute("metadata", self._scrape_metadata_impl, contest_id) + + def scrape_problem_tests(self, contest_id: str, problem_id: str) -> TestsResult: + return self._safe_execute( + "tests", self._scrape_tests_impl, contest_id, problem_id + ) + + def scrape_contest_list(self) -> ContestListResult: + return self._safe_execute("contests", self._scrape_contests_impl) + + def _safe_execute(self, operation: str, func, *args): + try: + return func(*args) + except Exception as e: + error_msg = f"{self.platform_name}: {str(e)}" + + if operation == "metadata": + return MetadataResult(success=False, error=error_msg) + elif operation == "tests": + return TestsResult( + success=False, + error=error_msg, + problem_id="", + url="", + tests=[], + timeout_ms=0, + memory_mb=0, + ) + elif operation == "contests": + return ContestListResult(success=False, error=error_msg) + + def _scrape_metadata_impl(self, contest_id: str) -> MetadataResult: + problems = scrape_contest_problems(contest_id) + if not problems: + return MetadataResult( + success=False, + error=f"{self.platform_name}: No problems found for contest {contest_id}", + ) + return MetadataResult( + success=True, error="", contest_id=contest_id, problems=problems + ) + + def _scrape_tests_impl(self, contest_id: str, problem_id: str) -> TestsResult: + problem_letter = problem_id.upper() + url = parse_problem_url(contest_id, problem_letter) + tests = scrape(url) + + response = requests.get( + url, + headers={ + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" + }, + timeout=10, + ) + response.raise_for_status() + + soup = BeautifulSoup(response.text, "html.parser") + timeout_ms, memory_mb = extract_problem_limits(soup) + + if not tests: + return TestsResult( + success=False, + error=f"{self.platform_name}: No tests found for {contest_id} {problem_letter}", + problem_id=f"{contest_id}_{problem_id.lower()}", + url=url, + tests=[], + timeout_ms=timeout_ms, + memory_mb=memory_mb, + ) + + return TestsResult( + success=True, + error="", + problem_id=f"{contest_id}_{problem_id.lower()}", + url=url, + tests=tests, + timeout_ms=timeout_ms, + memory_mb=memory_mb, + ) + + def _scrape_contests_impl(self) -> ContestListResult: + contests = scrape_contests() + if not contests: + return ContestListResult( + success=False, error=f"{self.platform_name}: No contests found" + ) + return ContestListResult(success=True, error="", contests=contests) + + def main() -> None: if len(sys.argv) < 2: result = MetadataResult( @@ -306,6 +401,7 @@ def main() -> None: sys.exit(1) mode: str = sys.argv[1] + scraper = AtCoderScraper() if mode == "metadata": if len(sys.argv) != 3: @@ -317,23 +413,10 @@ def main() -> None: sys.exit(1) contest_id: str = sys.argv[2] - problems: list[ProblemSummary] = scrape_contest_problems(contest_id) - - if not problems: - result = MetadataResult( - success=False, - error=f"No problems found for contest {contest_id}", - ) - print(json.dumps(asdict(result))) - sys.exit(1) - - result = MetadataResult( - success=True, - error="", - contest_id=contest_id, - problems=problems, - ) + result = scraper.scrape_contest_metadata(contest_id) print(json.dumps(asdict(result))) + if not result.success: + sys.exit(1) elif mode == "tests": if len(sys.argv) != 4: @@ -351,55 +434,10 @@ def main() -> None: test_contest_id: str = sys.argv[2] problem_letter: str = sys.argv[3] - problem_id: str = f"{test_contest_id}_{problem_letter.lower()}" - - url: str = parse_problem_url(test_contest_id, problem_letter) - tests: list[TestCase] = scrape(url) - - try: - headers = { - "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" - } - response = requests.get(url, headers=headers, timeout=10) - response.raise_for_status() - soup = BeautifulSoup(response.text, "html.parser") - timeout_ms, memory_mb = extract_problem_limits(soup) - except Exception as e: - tests_result = TestsResult( - success=False, - error=f"Failed to extract constraints: {e}", - problem_id=problem_id, - url=url, - tests=[], - timeout_ms=0, - memory_mb=0, - ) - print(json.dumps(asdict(tests_result))) - sys.exit(1) - - if not tests: - tests_result = TestsResult( - success=False, - error=f"No tests found for {test_contest_id} {problem_letter}", - problem_id=problem_id, - url=url, - tests=[], - timeout_ms=timeout_ms, - memory_mb=memory_mb, - ) - print(json.dumps(asdict(tests_result))) - sys.exit(1) - - tests_result = TestsResult( - success=True, - error="", - problem_id=problem_id, - url=url, - tests=tests, - timeout_ms=timeout_ms, - memory_mb=memory_mb, - ) + tests_result = scraper.scrape_problem_tests(test_contest_id, problem_letter) print(json.dumps(asdict(tests_result))) + if not tests_result.success: + sys.exit(1) elif mode == "contests": if len(sys.argv) != 2: @@ -409,14 +447,10 @@ def main() -> None: print(json.dumps(asdict(contest_result))) sys.exit(1) - contests = scrape_contests() - if not contests: - contest_result = ContestListResult(success=False, error="No contests found") - print(json.dumps(asdict(contest_result))) - sys.exit(1) - - contest_result = ContestListResult(success=True, error="", contests=contests) + contest_result = scraper.scrape_contest_list() print(json.dumps(asdict(contest_result))) + if not contest_result.success: + sys.exit(1) else: result = MetadataResult( diff --git a/scrapers/base.py b/scrapers/base.py new file mode 100644 index 0000000..c8336a8 --- /dev/null +++ b/scrapers/base.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from .models import ContestListResult, MetadataResult, TestsResult + + +@dataclass +class ScraperConfig: + timeout_seconds: int = 30 + max_retries: int = 3 + backoff_base: float = 2.0 + rate_limit_delay: float = 1.0 + + +class BaseScraper(ABC): + def __init__(self, config: ScraperConfig | None = None): + self.config = config or ScraperConfig() + + @property + @abstractmethod + def platform_name(self) -> str: ... + + @abstractmethod + def scrape_contest_metadata(self, contest_id: str) -> MetadataResult: ... + + @abstractmethod + def scrape_problem_tests(self, contest_id: str, problem_id: str) -> TestsResult: ... + + @abstractmethod + def scrape_contest_list(self) -> ContestListResult: ... + + def _create_metadata_error( + self, error_msg: str, contest_id: str = "" + ) -> MetadataResult: + return MetadataResult( + success=False, + error=f"{self.platform_name}: {error_msg}", + contest_id=contest_id, + ) + + def _create_tests_error( + self, error_msg: str, problem_id: str = "", url: str = "" + ) -> TestsResult: + return TestsResult( + success=False, + error=f"{self.platform_name}: {error_msg}", + problem_id=problem_id, + url=url, + tests=[], + timeout_ms=0, + memory_mb=0, + ) + + def _create_contests_error(self, error_msg: str) -> ContestListResult: + return ContestListResult( + success=False, error=f"{self.platform_name}: {error_msg}" + ) + + def _safe_execute(self, operation: str, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + if operation == "metadata": + contest_id = args[0] if args else "" + return self._create_metadata_error(str(e), contest_id) + elif operation == "tests": + problem_id = args[1] if len(args) > 1 else "" + return self._create_tests_error(str(e), problem_id) + elif operation == "contests": + return self._create_contests_error(str(e)) + else: + raise diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index 89d568e..0ec1958 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -8,6 +8,7 @@ from dataclasses import asdict import cloudscraper from bs4 import BeautifulSoup, Tag +from .base import BaseScraper from .models import ( ContestListResult, ContestSummary, @@ -18,6 +19,69 @@ from .models import ( ) +class CodeforcesScraper(BaseScraper): + @property + def platform_name(self) -> str: + return "codeforces" + + def scrape_contest_metadata(self, contest_id: str) -> MetadataResult: + return self._safe_execute( + "metadata", self._scrape_contest_metadata_impl, contest_id + ) + + def scrape_problem_tests(self, contest_id: str, problem_id: str) -> TestsResult: + return self._safe_execute( + "tests", self._scrape_problem_tests_impl, contest_id, problem_id + ) + + def scrape_contest_list(self) -> ContestListResult: + return self._safe_execute("contests", self._scrape_contest_list_impl) + + def _scrape_contest_metadata_impl(self, contest_id: str) -> MetadataResult: + problems = scrape_contest_problems(contest_id) + if not problems: + return self._create_metadata_error( + f"No problems found for contest {contest_id}", contest_id + ) + return MetadataResult( + success=True, error="", contest_id=contest_id, problems=problems + ) + + def _scrape_problem_tests_impl( + self, contest_id: str, problem_letter: str + ) -> TestsResult: + problem_id = contest_id + problem_letter.lower() + url = parse_problem_url(contest_id, problem_letter) + tests = scrape_sample_tests(url) + + scraper = cloudscraper.create_scraper() + response = scraper.get(url, timeout=self.config.timeout_seconds) + response.raise_for_status() + soup = BeautifulSoup(response.text, "html.parser") + timeout_ms, memory_mb = extract_problem_limits(soup) + + if not tests: + return self._create_tests_error( + f"No tests found for {contest_id} {problem_letter}", problem_id, url + ) + + return TestsResult( + success=True, + error="", + problem_id=problem_id, + url=url, + tests=tests, + timeout_ms=timeout_ms, + memory_mb=memory_mb, + ) + + def _scrape_contest_list_impl(self) -> ContestListResult: + contests = scrape_contests() + if not contests: + return self._create_contests_error("No contests found") + return ContestListResult(success=True, error="", contests=contests) + + def scrape(url: str) -> list[TestCase]: try: scraper = cloudscraper.create_scraper() @@ -223,28 +287,23 @@ def scrape_sample_tests(url: str) -> list[TestCase]: def scrape_contests() -> list[ContestSummary]: - try: - scraper = cloudscraper.create_scraper() - response = scraper.get("https://codeforces.com/api/contest.list", timeout=10) - response.raise_for_status() + scraper = cloudscraper.create_scraper() + response = scraper.get("https://codeforces.com/api/contest.list", timeout=10) + response.raise_for_status() - data = response.json() - if data["status"] != "OK": - return [] - - contests = [] - for contest in data["result"]: - contest_id = str(contest["id"]) - name = contest["name"] - - contests.append(ContestSummary(id=contest_id, name=name, display_name=name)) - - return contests - - except Exception as e: - print(f"Failed to fetch contests: {e}", file=sys.stderr) + data = response.json() + if data["status"] != "OK": return [] + contests = [] + for contest in data["result"]: + contest_id = str(contest["id"]) + name = contest["name"] + + contests.append(ContestSummary(id=contest_id, name=name, display_name=name)) + + return contests + def main() -> None: if len(sys.argv) < 2: @@ -255,6 +314,7 @@ def main() -> None: print(json.dumps(asdict(result))) sys.exit(1) + scraper = CodeforcesScraper() mode: str = sys.argv[1] if mode == "metadata": @@ -266,18 +326,7 @@ def main() -> None: sys.exit(1) contest_id: str = sys.argv[2] - problems: list[ProblemSummary] = scrape_contest_problems(contest_id) - - if not problems: - result = MetadataResult( - success=False, error=f"No problems found for contest {contest_id}" - ) - print(json.dumps(asdict(result))) - sys.exit(1) - - result = MetadataResult( - success=True, error="", contest_id=contest_id, problems=problems - ) + result = scraper.scrape_contest_metadata(contest_id) print(json.dumps(asdict(result))) elif mode == "tests": @@ -296,52 +345,7 @@ def main() -> None: tests_contest_id: str = sys.argv[2] problem_letter: str = sys.argv[3] - problem_id: str = tests_contest_id + problem_letter.lower() - - url: str = parse_problem_url(tests_contest_id, problem_letter) - tests: list[TestCase] = scrape_sample_tests(url) - - try: - scraper = cloudscraper.create_scraper() - response = scraper.get(url, timeout=10) - response.raise_for_status() - soup = BeautifulSoup(response.text, "html.parser") - timeout_ms, memory_mb = extract_problem_limits(soup) - except Exception as e: - tests_result = TestsResult( - success=False, - error=f"Failed to extract constraints: {e}", - problem_id=problem_id, - url=url, - tests=[], - timeout_ms=0, - memory_mb=0, - ) - print(json.dumps(asdict(tests_result))) - sys.exit(1) - - if not tests: - tests_result = TestsResult( - success=False, - error=f"No tests found for {tests_contest_id} {problem_letter}", - problem_id=problem_id, - url=url, - tests=[], - timeout_ms=timeout_ms, - memory_mb=memory_mb, - ) - print(json.dumps(asdict(tests_result))) - sys.exit(1) - - tests_result = TestsResult( - success=True, - error="", - problem_id=problem_id, - url=url, - tests=tests, - timeout_ms=timeout_ms, - memory_mb=memory_mb, - ) + tests_result = scraper.scrape_problem_tests(tests_contest_id, problem_letter) print(json.dumps(asdict(tests_result))) elif mode == "contests": @@ -352,13 +356,7 @@ def main() -> None: print(json.dumps(asdict(contest_result))) sys.exit(1) - contests = scrape_contests() - if not contests: - contest_result = ContestListResult(success=False, error="No contests found") - print(json.dumps(asdict(contest_result))) - sys.exit(1) - - contest_result = ContestListResult(success=True, error="", contests=contests) + contest_result = scraper.scrape_contest_list() print(json.dumps(asdict(contest_result))) else: diff --git a/scrapers/cses.py b/scrapers/cses.py index 3c5db7a..c9144c6 100755 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -9,6 +9,7 @@ import backoff import requests from bs4 import BeautifulSoup, Tag +from .base import BaseScraper from .models import ( ContestListResult, ContestSummary, @@ -322,6 +323,111 @@ def scrape(url: str) -> list[TestCase]: return [] +class CSESScraper(BaseScraper): + @property + def platform_name(self) -> str: + return "cses" + + def scrape_contest_metadata(self, contest_id: str) -> MetadataResult: + return self._safe_execute("metadata", self._scrape_metadata_impl, contest_id) + + def scrape_problem_tests(self, contest_id: str, problem_id: str) -> TestsResult: + return self._safe_execute( + "tests", self._scrape_tests_impl, contest_id, problem_id + ) + + def scrape_contest_list(self) -> ContestListResult: + return self._safe_execute("contests", self._scrape_contests_impl) + + def _safe_execute(self, operation: str, func, *args): + try: + return func(*args) + except Exception as e: + error_msg = f"{self.platform_name}: {str(e)}" + + if operation == "metadata": + return MetadataResult(success=False, error=error_msg) + elif operation == "tests": + return TestsResult( + success=False, + error=error_msg, + problem_id="", + url="", + tests=[], + timeout_ms=0, + memory_mb=0, + ) + elif operation == "contests": + return ContestListResult(success=False, error=error_msg) + + def _scrape_metadata_impl(self, category_id: str) -> MetadataResult: + problems = scrape_category_problems(category_id) + if not problems: + return MetadataResult( + success=False, + error=f"{self.platform_name}: No problems found for category: {category_id}", + ) + return MetadataResult( + success=True, error="", contest_id=category_id, problems=problems + ) + + def _scrape_tests_impl(self, category: str, problem_id: str) -> TestsResult: + url = parse_problem_url(problem_id) + if not url: + return TestsResult( + success=False, + error=f"{self.platform_name}: Invalid problem input: {problem_id}. Use either problem ID (e.g., 1068) or full URL", + problem_id=problem_id if problem_id.isdigit() else "", + url="", + tests=[], + timeout_ms=0, + memory_mb=0, + ) + + tests = scrape(url) + actual_problem_id = ( + problem_id if problem_id.isdigit() else problem_id.split("/")[-1] + ) + + headers = { + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" + } + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() + + soup = BeautifulSoup(response.text, "html.parser") + timeout_ms, memory_mb = extract_problem_limits(soup) + + if not tests: + return TestsResult( + success=False, + error=f"{self.platform_name}: No tests found for {problem_id}", + problem_id=actual_problem_id, + url=url, + tests=[], + timeout_ms=timeout_ms, + memory_mb=memory_mb, + ) + + return TestsResult( + success=True, + error="", + problem_id=actual_problem_id, + url=url, + tests=tests, + timeout_ms=timeout_ms, + memory_mb=memory_mb, + ) + + def _scrape_contests_impl(self) -> ContestListResult: + categories = scrape_categories() + if not categories: + return ContestListResult( + success=False, error=f"{self.platform_name}: No contests found" + ) + return ContestListResult(success=True, error="", contests=categories) + + def main() -> None: if len(sys.argv) < 2: result = MetadataResult( @@ -332,6 +438,7 @@ def main() -> None: sys.exit(1) mode: str = sys.argv[1] + scraper = CSESScraper() if mode == "metadata": if len(sys.argv) != 3: @@ -343,18 +450,10 @@ def main() -> None: sys.exit(1) category_id = sys.argv[2] - problems = scrape_category_problems(category_id) - - if not problems: - result = MetadataResult( - success=False, - error=f"No problems found for category: {category_id}", - ) - print(json.dumps(asdict(result))) - return - - result = MetadataResult(success=True, error="", problems=problems) + result = scraper.scrape_contest_metadata(category_id) print(json.dumps(asdict(result))) + if not result.success: + sys.exit(1) elif mode == "tests": if len(sys.argv) != 4: @@ -370,73 +469,12 @@ def main() -> None: print(json.dumps(asdict(tests_result))) sys.exit(1) - problem_input: str = sys.argv[3] - url: str | None = parse_problem_url(problem_input) - - if not url: - tests_result = TestsResult( - success=False, - error=f"Invalid problem input: {problem_input}. Use either problem ID (e.g., 1068) or full URL", - problem_id=problem_input if problem_input.isdigit() else "", - url="", - tests=[], - timeout_ms=0, - memory_mb=0, - ) - print(json.dumps(asdict(tests_result))) - sys.exit(1) - - tests: list[TestCase] = scrape(url) - - problem_id: str = ( - problem_input if problem_input.isdigit() else problem_input.split("/")[-1] - ) - - try: - headers = { - "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" - } - response = requests.get(url, headers=headers, timeout=10) - response.raise_for_status() - soup = BeautifulSoup(response.text, "html.parser") - timeout_ms, memory_mb = extract_problem_limits(soup) - except Exception as e: - tests_result = TestsResult( - success=False, - error=f"Failed to extract constraints: {e}", - problem_id=problem_id, - url=url, - tests=[], - timeout_ms=0, - memory_mb=0, - ) - print(json.dumps(asdict(tests_result))) - sys.exit(1) - - if not tests: - tests_result = TestsResult( - success=False, - error=f"No tests found for {problem_input}", - problem_id=problem_id, - url=url, - tests=[], - timeout_ms=timeout_ms, - memory_mb=memory_mb, - ) - print(json.dumps(asdict(tests_result))) - sys.exit(1) - - test_cases = tests - tests_result = TestsResult( - success=True, - error="", - problem_id=problem_id, - url=url, - tests=test_cases, - timeout_ms=timeout_ms, - memory_mb=memory_mb, - ) + category = sys.argv[2] + problem_id = sys.argv[3] + tests_result = scraper.scrape_problem_tests(category, problem_id) print(json.dumps(asdict(tests_result))) + if not tests_result.success: + sys.exit(1) elif mode == "contests": if len(sys.argv) != 2: @@ -446,14 +484,10 @@ def main() -> None: print(json.dumps(asdict(contest_result))) sys.exit(1) - categories = scrape_categories() - if not categories: - contest_result = ContestListResult(success=False, error="No contests found") - print(json.dumps(asdict(contest_result))) - sys.exit(1) - - contest_result = ContestListResult(success=True, error="", contests=categories) + contest_result = scraper.scrape_contest_list() print(json.dumps(asdict(contest_result))) + if not contest_result.success: + sys.exit(1) else: result = MetadataResult( diff --git a/spec/diff_spec.lua b/spec/diff_spec.lua index 49bd120..31fa395 100644 --- a/spec/diff_spec.lua +++ b/spec/diff_spec.lua @@ -44,57 +44,7 @@ describe('cp.diff', function() end) end) - describe('is_git_available', function() - it('returns true when git command succeeds', function() - local mock_system = stub(vim, 'system') - mock_system.returns({ - wait = function() - return { code = 0 } - end, - }) - - local result = diff.is_git_available() - assert.is_true(result) - - mock_system:revert() - end) - - it('returns false when git command fails', function() - local mock_system = stub(vim, 'system') - mock_system.returns({ - wait = function() - return { code = 1 } - end, - }) - - local result = diff.is_git_available() - assert.is_false(result) - - mock_system:revert() - end) - end) - describe('get_best_backend', function() - it('returns preferred backend when available', function() - local mock_is_available = stub(diff, 'is_git_available') - mock_is_available.returns(true) - - local backend = diff.get_best_backend('git') - assert.equals('git', backend.name) - - mock_is_available:revert() - end) - - it('falls back to vim when git unavailable', function() - local mock_is_available = stub(diff, 'is_git_available') - mock_is_available.returns(false) - - local backend = diff.get_best_backend('git') - assert.equals('vim', backend.name) - - mock_is_available:revert() - end) - it('defaults to vim backend', function() local backend = diff.get_best_backend() assert.equals('vim', backend.name) @@ -124,96 +74,18 @@ describe('cp.diff', function() end) end) - describe('git backend', function() - it('creates temp files for diff', function() - local mock_system = stub(vim, 'system') - local mock_tempname = stub(vim.fn, 'tempname') - local mock_writefile = stub(vim.fn, 'writefile') - local mock_delete = stub(vim.fn, 'delete') - - mock_tempname.returns('/tmp/expected', '/tmp/actual') - mock_system.returns({ - wait = function() - return { code = 1, stdout = 'diff output' } - end, - }) - - local backend = diff.get_backend('git') - backend.render('expected text', 'actual text') - - assert.stub(mock_writefile).was_called(2) - assert.stub(mock_delete).was_called(2) - - mock_system:revert() - mock_tempname:revert() - mock_writefile:revert() - mock_delete:revert() - end) - - it('returns raw diff output', function() - local mock_system = stub(vim, 'system') - local mock_tempname = stub(vim.fn, 'tempname') - local mock_writefile = stub(vim.fn, 'writefile') - local mock_delete = stub(vim.fn, 'delete') - - mock_tempname.returns('/tmp/expected', '/tmp/actual') - mock_system.returns({ - wait = function() - return { code = 1, stdout = 'git diff output' } - end, - }) - - local backend = diff.get_backend('git') - local result = backend.render('expected', 'actual') - - assert.equals('git diff output', result.raw_diff) - - mock_system:revert() - mock_tempname:revert() - mock_writefile:revert() - mock_delete:revert() - end) - - it('handles no differences', function() - local mock_system = stub(vim, 'system') - local mock_tempname = stub(vim.fn, 'tempname') - local mock_writefile = stub(vim.fn, 'writefile') - local mock_delete = stub(vim.fn, 'delete') - - mock_tempname.returns('/tmp/expected', '/tmp/actual') - mock_system.returns({ - wait = function() - return { code = 0 } - end, - }) - - local backend = diff.get_backend('git') - local result = backend.render('same', 'same') - - assert.same({ 'same' }, result.content) - assert.same({}, result.highlights) - - mock_system:revert() - mock_tempname:revert() - mock_writefile:revert() - mock_delete:revert() + describe('is_git_available', function() + it('returns boolean without errors', function() + local result = diff.is_git_available() + assert.equals('boolean', type(result)) end) end) describe('render_diff', function() - it('uses best available backend', function() - local mock_backend = { - render = function() - return {} - end, - } - local mock_get_best = stub(diff, 'get_best_backend') - mock_get_best.returns(mock_backend) - - diff.render_diff('expected', 'actual', 'vim') - - assert.stub(mock_get_best).was_called_with('vim') - mock_get_best:revert() + it('returns result without errors', function() + assert.has_no_errors(function() + diff.render_diff('expected', 'actual', 'vim') + end) end) end) end) diff --git a/spec/error_boundaries_spec.lua b/spec/error_boundaries_spec.lua new file mode 100644 index 0000000..9c711fb --- /dev/null +++ b/spec/error_boundaries_spec.lua @@ -0,0 +1,221 @@ +describe('Error boundary handling', function() + local cp + local state + local logged_messages + + before_each(function() + logged_messages = {} + local mock_logger = { + log = function(msg, level) + table.insert(logged_messages, { msg = msg, level = level }) + end, + set_config = function() end, + } + package.loaded['cp.log'] = mock_logger + + package.loaded['cp.scrape'] = { + scrape_problem = function(ctx) + if ctx.contest_id == 'fail_scrape' then + return { + success = false, + error = 'Network error', + } + end + return { + success = true, + problem_id = ctx.problem_id, + test_cases = { + { input = '1', expected = '2' }, + }, + test_count = 1, + } + end, + scrape_contest_metadata = function(_, contest_id) + if contest_id == 'fail_scrape' then + return { + success = false, + error = 'Network error', + } + end + if contest_id == 'fail_metadata' then + return { + success = false, + error = 'Contest not found', + } + end + return { + success = true, + problems = { + { id = 'a' }, + { id = 'b' }, + }, + } + end, + scrape_problems_parallel = function() + return {} + end, + } + + local cache = require('cp.cache') + cache.load = function() end + cache.set_test_cases = function() end + cache.set_file_state = function() end + cache.get_file_state = function() + return nil + end + cache.get_contest_data = function() + return nil + end + cache.get_test_cases = function() + return {} + end + + if not vim.fn then + vim.fn = {} + end + vim.fn.expand = vim.fn.expand or function() + return '/tmp/test.cpp' + end + vim.fn.mkdir = vim.fn.mkdir or function() end + if not vim.api then + vim.api = {} + end + vim.api.nvim_get_current_buf = vim.api.nvim_get_current_buf or function() + return 1 + end + vim.api.nvim_buf_get_lines = vim.api.nvim_buf_get_lines + or function() + return { '' } + end + if not vim.cmd then + vim.cmd = {} + end + vim.cmd.e = function() end + vim.cmd.only = function() end + if not vim.system then + vim.system = function(_) + return { + wait = function() + return { code = 0 } + end, + } + end + end + + state = require('cp.state') + state.reset() + + cp = require('cp') + cp.setup({ + contests = { + codeforces = { + default_language = 'cpp', + cpp = { extension = 'cpp', test = { 'echo', 'test' } }, + }, + }, + scrapers = { 'codeforces' }, + }) + end) + + after_each(function() + package.loaded['cp.log'] = nil + package.loaded['cp.scrape'] = nil + if state then + state.reset() + end + end) + + it('should handle scraping failures without state corruption', function() + cp.handle_command({ fargs = { 'codeforces', 'fail_scrape', 'a' } }) + + local has_metadata_error = false + for _, log_entry in ipairs(logged_messages) do + if log_entry.msg and log_entry.msg:match('failed to load contest metadata') then + has_metadata_error = true + break + end + end + assert.is_true(has_metadata_error, 'Should log contest metadata failure') + + local context = cp.get_current_context() + assert.equals('codeforces', context.platform) + assert.equals('fail_scrape', context.contest_id) + + assert.has_no_errors(function() + cp.handle_command({ fargs = { 'run' } }) + end) + end) + + it('should handle missing contest data without crashing navigation', function() + state.set_platform('codeforces') + state.set_contest_id('nonexistent') + state.set_problem_id('a') + + assert.has_no_errors(function() + cp.handle_command({ fargs = { 'next' } }) + end) + + local has_nav_error = false + for _, log_entry in ipairs(logged_messages) do + if log_entry.msg and log_entry.msg:match('no contest metadata found') then + has_nav_error = true + break + end + end + assert.is_true(has_nav_error, 'Should log navigation error') + end) + + it('should handle validation errors without crashing', function() + state.reset() + + assert.has_no_errors(function() + cp.handle_command({ fargs = { 'next' } }) + end) + + assert.has_no_errors(function() + cp.handle_command({ fargs = { 'prev' } }) + end) + + assert.has_no_errors(function() + cp.handle_command({ fargs = { 'run' } }) + end) + + local has_validation_error = false + local has_appropriate_errors = 0 + for _, log_entry in ipairs(logged_messages) do + if log_entry.msg and log_entry.msg:match('expected string, got nil') then + has_validation_error = true + elseif + log_entry.msg + and (log_entry.msg:match('no contest set') or log_entry.msg:match('No contest configured')) + then + has_appropriate_errors = has_appropriate_errors + 1 + end + end + + assert.is_false(has_validation_error, 'Should not have validation errors') + assert.is_true(has_appropriate_errors > 0, 'Should have user-facing errors') + end) + + it('should handle partial state gracefully', function() + state.set_platform('codeforces') + + assert.has_no_errors(function() + cp.handle_command({ fargs = { 'run' } }) + end) + + assert.has_no_errors(function() + cp.handle_command({ fargs = { 'next' } }) + end) + + local missing_contest_errors = 0 + for _, log_entry in ipairs(logged_messages) do + if + log_entry.msg and (log_entry.msg:match('no contest') or log_entry.msg:match('No contest')) + then + missing_contest_errors = missing_contest_errors + 1 + end + end + assert.is_true(missing_contest_errors > 0, 'Should report missing contest') + end) +end) diff --git a/spec/extmark_spec.lua b/spec/extmark_spec.lua deleted file mode 100644 index 2b4b25a..0000000 --- a/spec/extmark_spec.lua +++ /dev/null @@ -1,215 +0,0 @@ -describe('extmarks', function() - local spec_helper = require('spec.spec_helper') - local highlight - - before_each(function() - spec_helper.setup() - highlight = require('cp.ui.highlight') - end) - - after_each(function() - spec_helper.teardown() - end) - - describe('buffer deletion', function() - it('clears namespace on buffer delete', function() - local bufnr = 1 - local namespace = 100 - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - - highlight.apply_highlights(bufnr, { - { - line = 0, - col_start = 0, - col_end = 5, - highlight_group = 'CpDiffAdded', - }, - }, namespace) - - assert.stub(mock_clear).was_called_with(bufnr, namespace, 0, -1) - mock_clear:revert() - mock_extmark:revert() - end) - - it('handles invalid buffer gracefully', function() - local bufnr = 999 - local namespace = 100 - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - - mock_clear.on_call_with(bufnr, namespace, 0, -1).invokes(function() - error('Invalid buffer') - end) - - local success = pcall(highlight.apply_highlights, bufnr, { - { - line = 0, - col_start = 0, - col_end = 5, - highlight_group = 'CpDiffAdded', - }, - }, namespace) - - assert.is_false(success) - mock_clear:revert() - mock_extmark:revert() - end) - end) - - describe('namespace isolation', function() - it('creates unique namespaces', function() - local mock_create = stub(vim.api, 'nvim_create_namespace') - mock_create.on_call_with('cp_diff_highlights').returns(100) - mock_create.on_call_with('cp_test_list').returns(200) - mock_create.on_call_with('cp_ansi_highlights').returns(300) - - local diff_ns = highlight.create_namespace() - local test_ns = vim.api.nvim_create_namespace('cp_test_list') - local ansi_ns = vim.api.nvim_create_namespace('cp_ansi_highlights') - - assert.equals(100, diff_ns) - assert.equals(200, test_ns) - assert.equals(300, ansi_ns) - - mock_create:revert() - end) - - it('clears specific namespace independently', function() - local bufnr = 1 - local ns1 = 100 - local ns2 = 200 - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - - highlight.apply_highlights(bufnr, { - { line = 0, col_start = 0, col_end = 5, highlight_group = 'CpDiffAdded' }, - }, ns1) - - highlight.apply_highlights(bufnr, { - { line = 1, col_start = 0, col_end = 3, highlight_group = 'CpDiffRemoved' }, - }, ns2) - - assert.stub(mock_clear).was_called_with(bufnr, ns1, 0, -1) - assert.stub(mock_clear).was_called_with(bufnr, ns2, 0, -1) - assert.stub(mock_clear).was_called(2) - - mock_clear:revert() - mock_extmark:revert() - end) - end) - - describe('multiple updates', function() - it('clears previous extmarks on each update', function() - local bufnr = 1 - local namespace = 100 - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - - highlight.apply_highlights(bufnr, { - { line = 0, col_start = 0, col_end = 5, highlight_group = 'CpDiffAdded' }, - }, namespace) - - highlight.apply_highlights(bufnr, { - { line = 1, col_start = 0, col_end = 3, highlight_group = 'CpDiffRemoved' }, - }, namespace) - - assert.stub(mock_clear).was_called(2) - assert.stub(mock_clear).was_called_with(bufnr, namespace, 0, -1) - assert.stub(mock_extmark).was_called(2) - - mock_clear:revert() - mock_extmark:revert() - end) - - it('handles empty highlights', function() - local bufnr = 1 - local namespace = 100 - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - - highlight.apply_highlights(bufnr, { - { line = 0, col_start = 0, col_end = 5, highlight_group = 'CpDiffAdded' }, - }, namespace) - - highlight.apply_highlights(bufnr, {}, namespace) - - assert.stub(mock_clear).was_called(2) - assert.stub(mock_extmark).was_called(1) - - mock_clear:revert() - mock_extmark:revert() - end) - - it('skips invalid highlights', function() - local bufnr = 1 - local namespace = 100 - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - - highlight.apply_highlights(bufnr, { - { line = 0, col_start = 5, col_end = 5, highlight_group = 'CpDiffAdded' }, - { line = 1, col_start = 7, col_end = 3, highlight_group = 'CpDiffAdded' }, - { line = 2, col_start = 0, col_end = 5, highlight_group = 'CpDiffAdded' }, - }, namespace) - - assert.stub(mock_clear).was_called_with(bufnr, namespace, 0, -1) - assert.stub(mock_extmark).was_called(1) - assert.stub(mock_extmark).was_called_with(bufnr, namespace, 2, 0, { - end_col = 5, - hl_group = 'CpDiffAdded', - priority = 100, - }) - - mock_clear:revert() - mock_extmark:revert() - end) - end) - - describe('error handling', function() - it('fails when clear_namespace fails', function() - local bufnr = 1 - local namespace = 100 - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - - mock_clear.on_call_with(bufnr, namespace, 0, -1).invokes(function() - error('Namespace clear failed') - end) - - local success = pcall(highlight.apply_highlights, bufnr, { - { line = 0, col_start = 0, col_end = 5, highlight_group = 'CpDiffAdded' }, - }, namespace) - - assert.is_false(success) - assert.stub(mock_extmark).was_not_called() - - mock_clear:revert() - mock_extmark:revert() - end) - end) - - describe('parse_and_apply_diff cleanup', function() - it('clears namespace before applying parsed diff', function() - local bufnr = 1 - local namespace = 100 - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - local mock_set_lines = stub(vim.api, 'nvim_buf_set_lines') - local mock_get_option = stub(vim.api, 'nvim_get_option_value') - local mock_set_option = stub(vim.api, 'nvim_set_option_value') - - mock_get_option.returns(false) - - highlight.parse_and_apply_diff(bufnr, '+hello {+world+}', namespace) - - assert.stub(mock_clear).was_called_with(bufnr, namespace, 0, -1) - - mock_clear:revert() - mock_extmark:revert() - mock_set_lines:revert() - mock_get_option:revert() - mock_set_option:revert() - end) - end) -end) diff --git a/spec/highlight_spec.lua b/spec/highlight_spec.lua index 9afd773..8897cc4 100644 --- a/spec/highlight_spec.lua +++ b/spec/highlight_spec.lua @@ -60,22 +60,15 @@ index 1234567..abcdefg 100644 end) describe('apply_highlights', function() - it('clears existing highlights', function() - local mock_clear = spy.on(vim.api, 'nvim_buf_clear_namespace') - local bufnr = 1 - local namespace = 100 - - highlight.apply_highlights(bufnr, {}, namespace) - - assert.spy(mock_clear).was_called_with(bufnr, namespace, 0, -1) - mock_clear:revert() + it('handles empty highlights without errors', function() + local namespace = highlight.create_namespace() + assert.has_no_errors(function() + highlight.apply_highlights(1, {}, namespace) + end) end) - it('applies extmarks with correct positions', function() - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local bufnr = 1 - local namespace = 100 + it('handles valid highlight data without errors', function() + vim.api.nvim_buf_set_lines(1, 0, -1, false, { 'hello world test line' }) local highlights = { { line = 0, @@ -84,109 +77,31 @@ index 1234567..abcdefg 100644 highlight_group = 'CpDiffAdded', }, } - - highlight.apply_highlights(bufnr, highlights, namespace) - - assert.stub(mock_extmark).was_called_with(bufnr, namespace, 0, 5, { - end_col = 10, - hl_group = 'CpDiffAdded', - priority = 100, - }) - mock_extmark:revert() - mock_clear:revert() - end) - - it('uses correct highlight groups', function() - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - local highlights = { - { - line = 0, - col_start = 0, - col_end = 5, - highlight_group = 'CpDiffAdded', - }, - } - - highlight.apply_highlights(1, highlights, 100) - - assert.stub(mock_extmark).was_called_with(1, 100, 0, 0, { - end_col = 5, - hl_group = 'CpDiffAdded', - priority = 100, - }) - mock_extmark:revert() - mock_clear:revert() - end) - - it('handles empty highlights', function() - local mock_extmark = stub(vim.api, 'nvim_buf_set_extmark') - local mock_clear = stub(vim.api, 'nvim_buf_clear_namespace') - - highlight.apply_highlights(1, {}, 100) - - assert.stub(mock_extmark).was_not_called() - mock_extmark:revert() - mock_clear:revert() + local namespace = highlight.create_namespace() + assert.has_no_errors(function() + highlight.apply_highlights(1, highlights, namespace) + end) end) end) describe('create_namespace', function() - it('creates unique namespace', function() - local mock_create = stub(vim.api, 'nvim_create_namespace') - mock_create.returns(42) - + it('returns a number', function() local result = highlight.create_namespace() - - assert.equals(42, result) - assert.stub(mock_create).was_called_with('cp_diff_highlights') - mock_create:revert() + assert.equals('number', type(result)) end) end) describe('parse_and_apply_diff', function() - it('parses diff and applies to buffer', function() - local mock_set_lines = stub(vim.api, 'nvim_buf_set_lines') - local mock_apply = stub(highlight, 'apply_highlights') - local bufnr = 1 - local namespace = 100 - local diff_output = '+hello {+world+}' - - local result = highlight.parse_and_apply_diff(bufnr, diff_output, namespace) - - assert.same({ 'hello world' }, result) - assert.stub(mock_set_lines).was_called_with(bufnr, 0, -1, false, { 'hello world' }) - assert.stub(mock_apply).was_called() - - mock_set_lines:revert() - mock_apply:revert() - end) - - it('sets buffer content', function() - local mock_set_lines = stub(vim.api, 'nvim_buf_set_lines') - local mock_apply = stub(highlight, 'apply_highlights') - - highlight.parse_and_apply_diff(1, '+test line', 100) - - assert.stub(mock_set_lines).was_called_with(1, 0, -1, false, { 'test line' }) - mock_set_lines:revert() - mock_apply:revert() - end) - - it('applies highlights', function() - local mock_set_lines = stub(vim.api, 'nvim_buf_set_lines') - local mock_apply = stub(highlight, 'apply_highlights') - - highlight.parse_and_apply_diff(1, '+hello {+world+}', 100) - - assert.stub(mock_apply).was_called() - mock_set_lines:revert() - mock_apply:revert() - end) - it('returns content lines', function() - local result = highlight.parse_and_apply_diff(1, '+first\n+second', 100) + local namespace = highlight.create_namespace() + local result = highlight.parse_and_apply_diff(1, '+first\n+second', namespace) assert.same({ 'first', 'second' }, result) end) + + it('handles empty diff', function() + local namespace = highlight.create_namespace() + local result = highlight.parse_and_apply_diff(1, '', namespace) + assert.same({}, result) + end) end) end) diff --git a/spec/panel_spec.lua b/spec/panel_spec.lua new file mode 100644 index 0000000..ff24e16 --- /dev/null +++ b/spec/panel_spec.lua @@ -0,0 +1,80 @@ +describe('Panel integration', function() + local spec_helper = require('spec.spec_helper') + local cp + local state + + before_each(function() + spec_helper.setup_full() + spec_helper.mock_scraper_success() + + state = require('cp.state') + state.reset() + + cp = require('cp') + cp.setup({ + contests = { + codeforces = { + default_language = 'cpp', + cpp = { extension = 'cpp', test = { 'echo', 'test' } }, + }, + }, + scrapers = { 'codeforces' }, + }) + end) + + after_each(function() + spec_helper.teardown() + if state then + state.reset() + end + end) + + it('should handle run command with properly set contest context', function() + cp.handle_command({ fargs = { 'codeforces', '2146', 'b' } }) + + local context = cp.get_current_context() + assert.equals('codeforces', context.platform) + assert.equals('2146', context.contest_id) + assert.equals('b', context.problem_id) + + assert.has_no_errors(function() + cp.handle_command({ fargs = { 'run' } }) + end) + + local has_validation_error = false + for _, log_entry in ipairs(spec_helper.logged_messages) do + if + log_entry.level == vim.log.levels.ERROR + and log_entry.msg:match('expected string, got nil') + then + has_validation_error = true + break + end + end + assert.is_false(has_validation_error) + end) + + it('should handle state module interface correctly', function() + local run = require('cp.runner.run') + + state.set_platform('codeforces') + state.set_contest_id('2146') + state.set_problem_id('b') + + local problem = require('cp.problem') + local config_module = require('cp.config') + local processed_config = config_module.setup({ + contests = { codeforces = { cpp = { extension = 'cpp' } } }, + }) + local ctx = problem.create_context('codeforces', '2146', 'b', processed_config) + + assert.has_no_errors(function() + run.load_test_cases(ctx, state) + end) + + local fake_state_data = { platform = 'codeforces', contest_id = '2146', problem_id = 'b' } + assert.has_errors(function() + run.load_test_cases(ctx, fake_state_data) + end) + end) +end) diff --git a/spec/picker_spec.lua b/spec/picker_spec.lua index 92b32a2..6fd5a81 100644 --- a/spec/picker_spec.lua +++ b/spec/picker_spec.lua @@ -141,20 +141,24 @@ describe('cp.picker', function() it('falls back to scraping when cache miss', function() local cache = require('cp.cache') - local scrape = require('cp.scrape') cache.load = function() end cache.get_contest_data = function(_, _) return nil end - scrape.scrape_contest_metadata = function(_, _) - return { - success = true, - problems = { - { id = 'x', name = 'Problem X' }, - }, - } - end + + package.loaded['cp.scrape'] = { + scrape_contest_metadata = function(_, _) + return { + success = true, + problems = { + { id = 'x', name = 'Problem X' }, + }, + } + end, + } + + picker = spec_helper.fresh_require('cp.pickers', { 'cp.pickers.init' }) local problems = picker.get_problems_for_contest('test_platform', 'test_contest') assert.is_table(problems) @@ -177,6 +181,8 @@ describe('cp.picker', function() } end + picker = spec_helper.fresh_require('cp.pickers', { 'cp.pickers.init' }) + local problems = picker.get_problems_for_contest('test_platform', 'test_contest') assert.is_table(problems) assert.equals(0, #problems) diff --git a/spec/run_render_spec.lua b/spec/run_render_spec.lua index a647331..72f58c4 100644 --- a/spec/run_render_spec.lua +++ b/spec/run_render_spec.lua @@ -164,17 +164,10 @@ describe('cp.run_render', function() end) describe('setup_highlights', function() - it('sets up all highlight groups', function() - local mock_set_hl = spy.on(vim.api, 'nvim_set_hl') - run_render.setup_highlights() - - assert.spy(mock_set_hl).was_called(7) - assert.spy(mock_set_hl).was_called_with(0, 'CpTestAC', { fg = '#10b981' }) - assert.spy(mock_set_hl).was_called_with(0, 'CpTestWA', { fg = '#ef4444' }) - assert.spy(mock_set_hl).was_called_with(0, 'CpTestTLE', { fg = '#f59e0b' }) - assert.spy(mock_set_hl).was_called_with(0, 'CpTestRTE', { fg = '#8b5cf6' }) - - mock_set_hl:revert() + it('runs without errors', function() + assert.has_no_errors(function() + run_render.setup_highlights() + end) end) end) diff --git a/spec/scraper_spec.lua b/spec/scraper_spec.lua index cc02b6b..c81f8e2 100644 --- a/spec/scraper_spec.lua +++ b/spec/scraper_spec.lua @@ -56,8 +56,7 @@ describe('cp.scrape', function() package.loaded['cp.cache'] = mock_cache package.loaded['cp.utils'] = mock_utils - package.loaded['cp.scrape'] = nil - scrape = require('cp.scrape') + scrape = spec_helper.fresh_require('cp.scrape') local original_fn = vim.fn vim.fn = vim.tbl_extend('force', vim.fn, { @@ -125,8 +124,7 @@ describe('cp.scrape', function() stored_data = { platform = platform, contest_id = contest_id, problems = problems } end - package.loaded['cp.scrape'] = nil - scrape = require('cp.scrape') + scrape = spec_helper.fresh_require('cp.scrape') local result = scrape.scrape_contest_metadata('atcoder', 'abc123') diff --git a/spec/snippets_spec.lua b/spec/snippets_spec.lua index ce34d3d..944e0d9 100644 --- a/spec/snippets_spec.lua +++ b/spec/snippets_spec.lua @@ -5,8 +5,7 @@ describe('cp.snippets', function() before_each(function() spec_helper.setup() - package.loaded['cp.snippets'] = nil - snippets = require('cp.snippets') + snippets = spec_helper.fresh_require('cp.snippets') mock_luasnip = { snippet = function(trigger, body) return { trigger = trigger, body = body } diff --git a/spec/spec_helper.lua b/spec/spec_helper.lua index fd9673f..6f87157 100644 --- a/spec/spec_helper.lua +++ b/spec/spec_helper.lua @@ -1,14 +1,141 @@ local M = {} +M.logged_messages = {} + +local mock_logger = { + log = function(msg, level) + table.insert(M.logged_messages, { msg = msg, level = level }) + end, + set_config = function() end, +} + +local function setup_vim_mocks() + if not vim.fn then + vim.fn = {} + end + vim.fn.expand = vim.fn.expand or function() + return '/tmp/test.cpp' + end + vim.fn.mkdir = vim.fn.mkdir or function() end + vim.fn.fnamemodify = vim.fn.fnamemodify or function(path) + return path + end + vim.fn.tempname = vim.fn.tempname or function() + return '/tmp/session' + end + if not vim.api then + vim.api = {} + end + vim.api.nvim_get_current_buf = vim.api.nvim_get_current_buf or function() + return 1 + end + vim.api.nvim_buf_get_lines = vim.api.nvim_buf_get_lines or function() + return { '' } + end + if not vim.cmd then + vim.cmd = {} + end + vim.cmd.e = function() end + vim.cmd.only = function() end + vim.cmd.split = function() end + vim.cmd.vsplit = function() end + if not vim.system then + vim.system = function(_) + return { + wait = function() + return { code = 0 } + end, + } + end + end +end + function M.setup() - package.loaded['cp.log'] = { - log = function() end, - set_config = function() end, + M.logged_messages = {} + package.loaded['cp.log'] = mock_logger +end + +function M.setup_full() + M.setup() + setup_vim_mocks() + + local cache = require('cp.cache') + cache.load = function() end + cache.set_test_cases = function() end + cache.set_file_state = function() end + cache.get_file_state = function() + return nil + end + cache.get_contest_data = function() + return nil + end + cache.get_test_cases = function() + return {} + end +end + +function M.mock_scraper_success() + package.loaded['cp.scrape'] = { + scrape_problem = function(ctx) + return { + success = true, + problem_id = ctx.problem_id, + test_cases = { + { input = '1 2', expected = '3' }, + { input = '3 4', expected = '7' }, + }, + test_count = 2, + } + end, + scrape_contest_metadata = function(_, _) + return { + success = true, + problems = { + { id = 'a' }, + { id = 'b' }, + { id = 'c' }, + }, + } + end, + scrape_problems_parallel = function() + return {} + end, } end +function M.has_error_logged() + for _, log_entry in ipairs(M.logged_messages) do + if log_entry.level == vim.log.levels.ERROR then + return true + end + end + return false +end + +function M.find_logged_message(pattern) + for _, log_entry in ipairs(M.logged_messages) do + if log_entry.msg and log_entry.msg:match(pattern) then + return log_entry + end + end + return nil +end + +function M.fresh_require(module_name, additional_clears) + additional_clears = additional_clears or {} + + for _, clear_module in ipairs(additional_clears) do + package.loaded[clear_module] = nil + end + package.loaded[module_name] = nil + + return require(module_name) +end + function M.teardown() package.loaded['cp.log'] = nil + package.loaded['cp.scrape'] = nil + M.logged_messages = {} end return M diff --git a/tests/scrapers/conftest.py b/tests/scrapers/conftest.py index 3248ec2..ecb8c77 100644 --- a/tests/scrapers/conftest.py +++ b/tests/scrapers/conftest.py @@ -4,6 +4,8 @@ import pytest @pytest.fixture def mock_codeforces_html(): return """ +
Time limit: 1 seconds
+
Memory limit: 256 megabytes
             
3
diff --git a/tests/scrapers/test_codeforces.py b/tests/scrapers/test_codeforces.py index 14b263c..a7ff800 100644 --- a/tests/scrapers/test_codeforces.py +++ b/tests/scrapers/test_codeforces.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from scrapers.codeforces import scrape, scrape_contest_problems, scrape_contests +from scrapers.codeforces import CodeforcesScraper from scrapers.models import ContestSummary, ProblemSummary @@ -14,11 +14,13 @@ def test_scrape_success(mocker, mock_codeforces_html): "scrapers.codeforces.cloudscraper.create_scraper", return_value=mock_scraper ) - result = scrape("https://codeforces.com/contest/1900/problem/A") + scraper = CodeforcesScraper() + result = scraper.scrape_problem_tests("1900", "A") - assert len(result) == 1 - assert result[0].input == "1\n3\n1 2 3" - assert result[0].expected == "6" + assert result.success + assert len(result.tests) == 1 + assert result.tests[0].input == "1\n3\n1 2 3" + assert result.tests[0].expected == "6" def test_scrape_contest_problems(mocker): @@ -34,11 +36,13 @@ def test_scrape_contest_problems(mocker): "scrapers.codeforces.cloudscraper.create_scraper", return_value=mock_scraper ) - result = scrape_contest_problems("1900") + scraper = CodeforcesScraper() + result = scraper.scrape_contest_metadata("1900") - assert len(result) == 2 - assert result[0] == ProblemSummary(id="a", name="A. Problem A") - assert result[1] == ProblemSummary(id="b", name="B. Problem B") + assert result.success + assert len(result.problems) == 2 + assert result.problems[0] == ProblemSummary(id="a", name="A. Problem A") + assert result.problems[1] == ProblemSummary(id="b", name="B. Problem B") def test_scrape_network_error(mocker): @@ -49,9 +53,11 @@ def test_scrape_network_error(mocker): "scrapers.codeforces.cloudscraper.create_scraper", return_value=mock_scraper ) - result = scrape("https://codeforces.com/contest/1900/problem/A") + scraper = CodeforcesScraper() + result = scraper.scrape_problem_tests("1900", "A") - assert result == [] + assert not result.success + assert "network error" in result.error.lower() def test_scrape_contests_success(mocker): @@ -71,20 +77,22 @@ def test_scrape_contests_success(mocker): "scrapers.codeforces.cloudscraper.create_scraper", return_value=mock_scraper ) - result = scrape_contests() + scraper = CodeforcesScraper() + result = scraper.scrape_contest_list() - assert len(result) == 3 - assert result[0] == ContestSummary( + assert result.success + assert len(result.contests) == 3 + assert result.contests[0] == ContestSummary( id="1951", name="Educational Codeforces Round 168 (Rated for Div. 2)", display_name="Educational Codeforces Round 168 (Rated for Div. 2)", ) - assert result[1] == ContestSummary( + assert result.contests[1] == ContestSummary( id="1950", name="Codeforces Round 936 (Div. 2)", display_name="Codeforces Round 936 (Div. 2)", ) - assert result[2] == ContestSummary( + assert result.contests[2] == ContestSummary( id="1949", name="Codeforces Global Round 26", display_name="Codeforces Global Round 26", @@ -101,9 +109,11 @@ def test_scrape_contests_api_error(mocker): "scrapers.codeforces.cloudscraper.create_scraper", return_value=mock_scraper ) - result = scrape_contests() + scraper = CodeforcesScraper() + result = scraper.scrape_contest_list() - assert result == [] + assert not result.success + assert "no contests found" in result.error.lower() def test_scrape_contests_network_error(mocker): @@ -114,6 +124,8 @@ def test_scrape_contests_network_error(mocker): "scrapers.codeforces.cloudscraper.create_scraper", return_value=mock_scraper ) - result = scrape_contests() + scraper = CodeforcesScraper() + result = scraper.scrape_contest_list() - assert result == [] + assert not result.success + assert "network error" in result.error.lower() diff --git a/tests/scrapers/test_interface_compliance.py b/tests/scrapers/test_interface_compliance.py new file mode 100644 index 0000000..753e0de --- /dev/null +++ b/tests/scrapers/test_interface_compliance.py @@ -0,0 +1,162 @@ +from unittest.mock import Mock + +import pytest + +from scrapers import ALL_SCRAPERS, BaseScraper +from scrapers.models import ContestListResult, MetadataResult, TestsResult + +ALL_SCRAPER_CLASSES = list(ALL_SCRAPERS.values()) + + +class TestScraperInterfaceCompliance: + @pytest.mark.parametrize("scraper_class", ALL_SCRAPER_CLASSES) + def test_implements_base_interface(self, scraper_class): + scraper = scraper_class() + + assert isinstance(scraper, BaseScraper) + assert hasattr(scraper, "platform_name") + assert hasattr(scraper, "scrape_contest_metadata") + assert hasattr(scraper, "scrape_problem_tests") + assert hasattr(scraper, "scrape_contest_list") + + @pytest.mark.parametrize("scraper_class", ALL_SCRAPER_CLASSES) + def test_platform_name_is_string(self, scraper_class): + scraper = scraper_class() + platform_name = scraper.platform_name + + assert isinstance(platform_name, str) + assert len(platform_name) > 0 + assert platform_name.islower() # Convention: lowercase platform names + + @pytest.mark.parametrize("scraper_class", ALL_SCRAPER_CLASSES) + def test_metadata_method_signature(self, scraper_class, mocker): + scraper = scraper_class() + + # Mock the underlying HTTP calls to avoid network requests + if scraper.platform_name == "codeforces": + mock_scraper = Mock() + mock_response = Mock() + mock_response.text = "A. Test" + mock_scraper.get.return_value = mock_response + mocker.patch( + "scrapers.codeforces.cloudscraper.create_scraper", + return_value=mock_scraper, + ) + + result = scraper.scrape_contest_metadata("test_contest") + + assert isinstance(result, MetadataResult) + assert hasattr(result, "success") + assert hasattr(result, "error") + assert hasattr(result, "problems") + assert hasattr(result, "contest_id") + assert isinstance(result.success, bool) + assert isinstance(result.error, str) + + @pytest.mark.parametrize("scraper_class", ALL_SCRAPER_CLASSES) + def test_problem_tests_method_signature(self, scraper_class, mocker): + scraper = scraper_class() + + if scraper.platform_name == "codeforces": + mock_scraper = Mock() + mock_response = Mock() + mock_response.text = """ +
Time limit: 1 seconds
+
Memory limit: 256 megabytes
+
3
+
6
+ """ + mock_scraper.get.return_value = mock_response + mocker.patch( + "scrapers.codeforces.cloudscraper.create_scraper", + return_value=mock_scraper, + ) + + result = scraper.scrape_problem_tests("test_contest", "A") + + assert isinstance(result, TestsResult) + assert hasattr(result, "success") + assert hasattr(result, "error") + assert hasattr(result, "tests") + assert hasattr(result, "problem_id") + assert hasattr(result, "url") + assert hasattr(result, "timeout_ms") + assert hasattr(result, "memory_mb") + assert isinstance(result.success, bool) + assert isinstance(result.error, str) + + @pytest.mark.parametrize("scraper_class", ALL_SCRAPER_CLASSES) + def test_contest_list_method_signature(self, scraper_class, mocker): + scraper = scraper_class() + + if scraper.platform_name == "codeforces": + mock_scraper = Mock() + mock_response = Mock() + mock_response.json.return_value = { + "status": "OK", + "result": [{"id": 1900, "name": "Test Contest"}], + } + mock_scraper.get.return_value = mock_response + mocker.patch( + "scrapers.codeforces.cloudscraper.create_scraper", + return_value=mock_scraper, + ) + + result = scraper.scrape_contest_list() + + assert isinstance(result, ContestListResult) + assert hasattr(result, "success") + assert hasattr(result, "error") + assert hasattr(result, "contests") + assert isinstance(result.success, bool) + assert isinstance(result.error, str) + + @pytest.mark.parametrize("scraper_class", ALL_SCRAPER_CLASSES) + def test_error_message_format(self, scraper_class, mocker): + scraper = scraper_class() + platform_name = scraper.platform_name + + # Force an error by mocking HTTP failure + if scraper.platform_name == "codeforces": + mock_scraper = Mock() + mock_scraper.get.side_effect = Exception("Network error") + mocker.patch( + "scrapers.codeforces.cloudscraper.create_scraper", + return_value=mock_scraper, + ) + elif scraper.platform_name == "atcoder": + mocker.patch( + "scrapers.atcoder.requests.get", side_effect=Exception("Network error") + ) + elif scraper.platform_name == "cses": + mocker.patch( + "scrapers.cses.make_request", side_effect=Exception("Network error") + ) + + # Test metadata error format + result = scraper.scrape_contest_metadata("test") + assert not result.success + assert result.error.startswith(f"{platform_name}: ") + + # Test problem tests error format + result = scraper.scrape_problem_tests("test", "A") + assert not result.success + assert result.error.startswith(f"{platform_name}: ") + + # Test contest list error format + result = scraper.scrape_contest_list() + assert not result.success + assert result.error.startswith(f"{platform_name}: ") + + @pytest.mark.parametrize("scraper_class", ALL_SCRAPER_CLASSES) + def test_scraper_instantiation(self, scraper_class): + scraper1 = scraper_class() + assert isinstance(scraper1, BaseScraper) + assert scraper1.config is not None + + from scrapers.base import ScraperConfig + + custom_config = ScraperConfig(timeout_seconds=60) + scraper2 = scraper_class(custom_config) + assert isinstance(scraper2, BaseScraper) + assert scraper2.config.timeout_seconds == 60 diff --git a/tests/scrapers/test_registry.py b/tests/scrapers/test_registry.py new file mode 100644 index 0000000..a656d1e --- /dev/null +++ b/tests/scrapers/test_registry.py @@ -0,0 +1,58 @@ +import pytest + +from scrapers import ALL_SCRAPERS, get_scraper, list_platforms +from scrapers.base import BaseScraper +from scrapers.codeforces import CodeforcesScraper + + +class TestScraperRegistry: + def test_get_scraper_valid_platform(self): + scraper_class = get_scraper("codeforces") + assert scraper_class == CodeforcesScraper + assert issubclass(scraper_class, BaseScraper) + + scraper = scraper_class() + assert isinstance(scraper, BaseScraper) + assert scraper.platform_name == "codeforces" + + def test_get_scraper_invalid_platform(self): + with pytest.raises(KeyError) as exc_info: + get_scraper("nonexistent") + + error_msg = str(exc_info.value) + assert "nonexistent" in error_msg + assert "Available platforms" in error_msg + + def test_list_platforms(self): + platforms = list_platforms() + + assert isinstance(platforms, list) + assert len(platforms) > 0 + assert "codeforces" in platforms + + assert set(platforms) == set(ALL_SCRAPERS.keys()) + + def test_all_scrapers_registry(self): + assert isinstance(ALL_SCRAPERS, dict) + assert len(ALL_SCRAPERS) > 0 + + for platform_name, scraper_class in ALL_SCRAPERS.items(): + assert isinstance(platform_name, str) + assert platform_name.islower() + + assert issubclass(scraper_class, BaseScraper) + + scraper = scraper_class() + assert scraper.platform_name == platform_name + + def test_registry_import_consistency(self): + from scrapers.codeforces import CodeforcesScraper as DirectImport + + registry_class = get_scraper("codeforces") + assert registry_class == DirectImport + + def test_all_scrapers_can_be_instantiated(self): + for platform_name, scraper_class in ALL_SCRAPERS.items(): + scraper = scraper_class() + assert isinstance(scraper, BaseScraper) + assert scraper.platform_name == platform_name