diff --git a/README.md b/README.md index 2e9ee89..b31febd 100644 --- a/README.md +++ b/README.md @@ -28,11 +28,11 @@ cp.nvim follows a simple principle: **solve locally, submit remotely**. ### Basic Usage -1. **Find a problem** on the judge website -2. **Set up locally** with `:CP ` +1. **Find a contest or problem** on the judge website +2. **Set up locally** with `:CP []` ``` - :CP codeforces 1848 A + :CP codeforces 1848 ``` 3. **Code and test** with instant feedback and rich diffs @@ -62,7 +62,3 @@ See [my config](https://github.com/barrett-ruth/dots/blob/main/nvim/lua/plugins/ - [competitest.nvim](https://github.com/xeluxee/competitest.nvim) - [assistant.nvim](https://github.com/A7Lavinraj/assistant.nvim) - -## TODO - -- Windows support diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 04d7dbd..c59fc61 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -342,4 +342,24 @@ function M.clear_contest_list(platform) end end +function M.clear_all() + cache_data = {} + M.save() +end + +---@param platform string +function M.clear_platform(platform) + vim.validate({ + platform = { platform, 'string' }, + }) + + if cache_data[platform] then + cache_data[platform] = nil + end + if cache_data.contest_lists and cache_data.contest_lists[platform] then + cache_data.contest_lists[platform] = nil + end + M.save() +end + return M diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 838107b..9386b75 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -101,7 +101,7 @@ M.defaults = { filename = nil, run_panel = { ansi = true, - diff_mode = 'vim', + diff_mode = 'git', next_test_key = '', prev_test_key = '', toggle_diff_key = 't', @@ -178,7 +178,6 @@ function M.setup(user_config) local config = vim.tbl_deep_extend('force', M.defaults, user_config or {}) - -- Validate merged config values vim.validate({ before_run = { config.hooks.before_run, @@ -267,12 +266,8 @@ function M.setup(user_config) error('No language configurations found') end - if vim.tbl_contains(available_langs, 'cpp') then - contest_config.default_language = 'cpp' - else - table.sort(available_langs) - contest_config.default_language = available_langs[1] - end + table.sort(available_langs) + contest_config.default_language = available_langs[1] end end diff --git a/lua/cp/constants.lua b/lua/cp/constants.lua index c14569b..7544435 100644 --- a/lua/cp/constants.lua +++ b/lua/cp/constants.lua @@ -1,7 +1,7 @@ local M = {} M.PLATFORMS = { 'atcoder', 'codeforces', 'cses' } -M.ACTIONS = { 'run', 'next', 'prev', 'pick' } +M.ACTIONS = { 'run', 'next', 'prev', 'pick', 'cache' } M.PLATFORM_DISPLAY_NAMES = { atcoder = 'AtCoder', diff --git a/lua/cp/init.lua b/lua/cp/init.lua index 1c21880..c50316e 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -246,7 +246,7 @@ local function toggle_run_panel(is_debug) end local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) - local run = require('cp.run') + local run = require('cp.runner.run') if not run.load_test_cases(ctx, state) then logger.log('no test cases found', vim.log.levels.WARN) @@ -270,7 +270,7 @@ local function toggle_run_panel(is_debug) tab_buf = tab_buf, } - local highlight = require('cp.highlight') + local highlight = require('cp.ui.highlight') local diff_namespace = highlight.create_namespace() local test_list_namespace = vim.api.nvim_create_namespace('cp_test_list') @@ -294,6 +294,7 @@ local function toggle_run_panel(is_debug) vim.api.nvim_set_current_win(parent_win) vim.cmd.split() + vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) local actual_win = vim.api.nvim_get_current_win() vim.api.nvim_win_set_buf(actual_win, actual_buf) @@ -341,13 +342,14 @@ local function toggle_run_panel(is_debug) vim.api.nvim_set_current_win(parent_win) vim.cmd.split() + vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) local diff_win = vim.api.nvim_get_current_win() vim.api.nvim_win_set_buf(diff_win, diff_buf) vim.api.nvim_set_option_value('filetype', 'cptest', { buf = diff_buf }) vim.api.nvim_set_option_value('winbar', 'Expected vs Actual', { win = diff_win }) - local diff_backend = require('cp.diff') + local diff_backend = require('cp.ui.diff') local backend = diff_backend.get_best_backend('git') local diff_result = backend.render(expected_content, actual_content) @@ -375,6 +377,7 @@ local function toggle_run_panel(is_debug) vim.api.nvim_set_current_win(parent_win) vim.cmd.split() + vim.cmd('resize ' .. math.floor(vim.o.lines * 0.35)) local win = vim.api.nvim_get_current_win() vim.api.nvim_win_set_buf(win, buf) vim.api.nvim_set_option_value('filetype', 'cptest', { buf = buf }) @@ -459,7 +462,7 @@ local function toggle_run_panel(is_debug) ansi_namespace ) elseif desired_mode == 'git' then - local diff_backend = require('cp.diff') + local diff_backend = require('cp.ui.diff') local backend = diff_backend.get_best_backend('git') local diff_result = backend.render(expected_content, actual_content) @@ -513,7 +516,7 @@ local function toggle_run_panel(is_debug) return end - local run_render = require('cp.run_render') + local run_render = require('cp.runner.run_render') run_render.setup_highlights() local test_state = run.get_run_panel_state() @@ -573,7 +576,7 @@ local function toggle_run_panel(is_debug) config.hooks.before_debug(ctx) end - local execute = require('cp.execute') + local execute = require('cp.runner.execute') local contest_config = config.contests[state.platform] local compile_result = execute.compile_problem(ctx, contest_config, is_debug) if compile_result.success then @@ -586,7 +589,7 @@ local function toggle_run_panel(is_debug) vim.schedule(function() if config.run_panel.ansi then - local ansi = require('cp.ansi') + local ansi = require('cp.ui.ansi') ansi.setup_highlight_groups() end if current_diff_layout then @@ -600,7 +603,10 @@ local function toggle_run_panel(is_debug) state.test_buffers = test_buffers state.test_windows = test_windows local test_state = run.get_run_panel_state() - logger.log(string.format('test panel opened (%d test cases)', #test_state.test_cases)) + logger.log( + string.format('test panel opened (%d test cases)', #test_state.test_cases), + vim.log.levels.INFO + ) end ---@param contest_id string @@ -751,6 +757,29 @@ local function handle_pick_action() end end +local function handle_cache_command(cmd) + if cmd.subcommand == 'clear' then + cache.load() + if cmd.platform then + if vim.tbl_contains(platforms, cmd.platform) then + cache.clear_platform(cmd.platform) + logger.log(('cleared cache for %s'):format(cmd.platform), vim.log.levels.INFO, true) + else + logger.log( + ('unknown platform: %s. Available: %s'):format( + cmd.platform, + table.concat(platforms, ', ') + ), + vim.log.levels.ERROR + ) + end + else + cache.clear_all() + logger.log('cleared all cache', vim.log.levels.INFO, true) + end + end +end + local function restore_from_current_file() local current_file = vim.fn.expand('%:p') if current_file == '' then @@ -820,7 +849,24 @@ local function parse_command(args) local first = filtered_args[1] if vim.tbl_contains(actions, first) then - return { type = 'action', action = first, language = language, debug = debug } + if first == 'cache' then + local subcommand = filtered_args[2] + if not subcommand then + return { type = 'error', message = 'cache command requires subcommand: clear' } + end + if subcommand == 'clear' then + local platform = filtered_args[3] + return { + type = 'cache', + subcommand = 'clear', + platform = platform, + } + else + return { type = 'error', message = 'unknown cache subcommand: ' .. subcommand } + end + else + return { type = 'action', action = first, language = language, debug = debug } + end end if vim.tbl_contains(platforms, first) then @@ -896,6 +942,11 @@ function M.handle_command(opts) return end + if cmd.type == 'cache' then + handle_cache_command(cmd) + return + end + if cmd.type == 'platform_only' then set_platform(cmd.platform) return @@ -929,7 +980,9 @@ function M.handle_command(opts) #metadata_result.problems, cmd.platform, cmd.contest - ) + ), + vim.log.levels.INFO, + true ) problem_ids = vim.tbl_map(function(prob) return prob.id diff --git a/lua/cp/pickers/init.lua b/lua/cp/pickers/init.lua index 5a1b644..b981b59 100644 --- a/lua/cp/pickers/init.lua +++ b/lua/cp/pickers/init.lua @@ -159,8 +159,10 @@ end ---@param contest_id string Contest identifier ---@param problem_id string Problem identifier local function setup_problem(platform, contest_id, problem_id) - local cp = require('cp') - cp.handle_command({ fargs = { platform, contest_id, problem_id } }) + vim.schedule(function() + local cp = require('cp') + cp.handle_command({ fargs = { platform, contest_id, problem_id } }) + end) end M.get_platforms = get_platforms diff --git a/lua/cp/execute.lua b/lua/cp/runner/execute.lua similarity index 97% rename from lua/cp/execute.lua rename to lua/cp/runner/execute.lua index 1c433d6..62c0f99 100644 --- a/lua/cp/execute.lua +++ b/lua/cp/runner/execute.lua @@ -89,12 +89,12 @@ function M.compile_generic(language_config, substitutions) :wait() local compile_time = (vim.uv.hrtime() - start_time) / 1000000 - local ansi = require('cp.ansi') + local ansi = require('cp.ui.ansi') result.stdout = ansi.bytes_to_string(result.stdout or '') result.stderr = ansi.bytes_to_string(result.stderr or '') if result.code == 0 then - logger.log(('compilation successful (%.1fms)'):format(compile_time)) + logger.log(('compilation successful (%.1fms)'):format(compile_time), vim.log.levels.INFO) else logger.log(('compilation failed (%.1fms)'):format(compile_time)) end @@ -235,7 +235,10 @@ function M.compile_problem(ctx, contest_config, is_debug) if compile_result.code ~= 0 then return { success = false, output = compile_result.stdout or 'unknown error' } end - logger.log(('compilation successful (%s)'):format(is_debug and 'debug mode' or 'test mode')) + logger.log( + ('compilation successful (%s)'):format(is_debug and 'debug mode' or 'test mode'), + vim.log.levels.INFO + ) end return { success = true, output = nil } diff --git a/lua/cp/run.lua b/lua/cp/runner/run.lua similarity index 98% rename from lua/cp/run.lua rename to lua/cp/runner/run.lua index ddd2346..f996a60 100644 --- a/lua/cp/run.lua +++ b/lua/cp/runner/run.lua @@ -194,7 +194,7 @@ local function run_single_test_case(ctx, contest_config, cp_config, test_case) .system({ 'sh', '-c', table.concat(redirected_cmd, ' ') }, { text = false }) :wait() - local ansi = require('cp.ansi') + local ansi = require('cp.ui.ansi') compile_result.stdout = ansi.bytes_to_string(compile_result.stdout or '') compile_result.stderr = ansi.bytes_to_string(compile_result.stderr or '') @@ -234,7 +234,7 @@ local function run_single_test_case(ctx, contest_config, cp_config, test_case) :wait() local execution_time = (vim.uv.hrtime() - start_time) / 1000000 - local ansi = require('cp.ansi') + local ansi = require('cp.ui.ansi') local stdout_str = ansi.bytes_to_string(result.stdout or '') local actual_output = stdout_str:gsub('\n$', '') @@ -315,7 +315,7 @@ function M.load_test_cases(ctx, state) run_panel_state.constraints.memory_mb ) or '' - logger.log(('loaded %d test case(s)%s'):format(#test_cases, constraint_info)) + logger.log(('loaded %d test case(s)%s'):format(#test_cases, constraint_info), vim.log.levels.INFO) return #test_cases > 0 end @@ -365,7 +365,7 @@ function M.get_run_panel_state() end function M.handle_compilation_failure(compilation_output) - local ansi = require('cp.ansi') + local ansi = require('cp.ui.ansi') local config = require('cp.config').setup() local clean_text diff --git a/lua/cp/run_render.lua b/lua/cp/runner/run_render.lua similarity index 100% rename from lua/cp/run_render.lua rename to lua/cp/runner/run_render.lua diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index 5ecaf43..f7a48c8 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -17,7 +17,7 @@ local problem = require('cp.problem') local utils = require('cp.utils') local function check_internet_connectivity() - local result = vim.system({ 'ping', '-c', '1', '-W', '3', '8.8.8.8' }, { text = true }):wait() + local result = vim.system({ 'ping', '-c', '5', '-W', '3', '8.8.8.8' }, { text = true }):wait() return result.code == 0 end diff --git a/lua/cp/ansi.lua b/lua/cp/ui/ansi.lua similarity index 100% rename from lua/cp/ansi.lua rename to lua/cp/ui/ansi.lua diff --git a/lua/cp/diff.lua b/lua/cp/ui/diff.lua similarity index 100% rename from lua/cp/diff.lua rename to lua/cp/ui/diff.lua diff --git a/lua/cp/highlight.lua b/lua/cp/ui/highlight.lua similarity index 100% rename from lua/cp/highlight.lua rename to lua/cp/ui/highlight.lua diff --git a/lua/cp/window.lua b/lua/cp/window.lua deleted file mode 100644 index b150965..0000000 --- a/lua/cp/window.lua +++ /dev/null @@ -1,145 +0,0 @@ ----@class WindowState ----@field windows table ----@field current_win integer ----@field layout string - ----@class WindowData ----@field bufnr integer ----@field view table ----@field width integer ----@field height integer - -local M = {} -local constants = require('cp.constants') - ----@return WindowState -function M.save_layout() - local windows = {} - for _, win in ipairs(vim.api.nvim_list_wins()) do - if vim.api.nvim_win_is_valid(win) then - local bufnr = vim.api.nvim_win_get_buf(win) - windows[win] = { - bufnr = bufnr, - view = vim.fn.winsaveview(), - width = vim.api.nvim_win_get_width(win), - height = vim.api.nvim_win_get_height(win), - } - end - end - - return { - windows = windows, - current_win = vim.api.nvim_get_current_win(), - layout = vim.fn.winrestcmd(), - } -end - ----@param state? WindowState ----@param tile_fn? fun(source_buf: integer, input_buf: integer, output_buf: integer) -function M.restore_layout(state, tile_fn) - vim.validate({ - state = { state, { 'table', 'nil' }, true }, - tile_fn = { tile_fn, { 'function', 'nil' }, true }, - }) - - if not state then - return - end - - vim.cmd.diffoff() - - local problem_id = vim.fn.expand('%:t:r') - if problem_id == '' then - for win, win_state in pairs(state.windows) do - if vim.api.nvim_win_is_valid(win) and vim.api.nvim_buf_is_valid(win_state.bufnr) then - local bufname = vim.api.nvim_buf_get_name(win_state.bufnr) - if - not bufname:match('%.in$') - and not bufname:match('%.out$') - and not bufname:match('%.expected$') - then - problem_id = vim.fn.fnamemodify(bufname, ':t:r') - break - end - end - end - end - - if problem_id ~= '' then - vim.cmd('silent only') - - local base_fp = vim.fn.getcwd() - local input_file = ('%s/io/%s.in'):format(base_fp, problem_id) - local output_file = ('%s/io/%s.out'):format(base_fp, problem_id) - local source_files = vim.fn.glob(problem_id .. '.*') - local source_file - if source_files ~= '' then - local files = vim.split(source_files, '\n') - -- Prefer known extensions first, but accept any extension - local known_extensions = vim.tbl_keys(constants.filetype_to_language) - for _, file in ipairs(files) do - local ext = vim.fn.fnamemodify(file, ':e') - if vim.tbl_contains(known_extensions, ext) then - source_file = file - break - end - end - source_file = source_file or files[1] - end - - if not source_file or vim.fn.filereadable(source_file) == 0 then - return - end - - vim.cmd.edit(source_file) - local source_buf = vim.api.nvim_get_current_buf() - local input_buf = vim.fn.bufnr(input_file, true) - local output_buf = vim.fn.bufnr(output_file, true) - - if tile_fn then - tile_fn(source_buf, input_buf, output_buf) - else - M.default_tile(source_buf, input_buf, output_buf) - end - else - vim.cmd(state.layout) - - for win, win_state in pairs(state.windows) do - if vim.api.nvim_win_is_valid(win) then - vim.api.nvim_set_current_win(win) - if vim.api.nvim_get_current_buf() == win_state.bufnr then - vim.fn.winrestview(win_state.view) - end - end - end - - if vim.api.nvim_win_is_valid(state.current_win) then - vim.api.nvim_set_current_win(state.current_win) - end - end -end - ----@param source_buf integer ----@param input_buf integer ----@param output_buf integer -local function default_tile(source_buf, input_buf, output_buf) - vim.validate({ - source_buf = { source_buf, 'number' }, - input_buf = { input_buf, 'number' }, - output_buf = { output_buf, 'number' }, - }) - - vim.api.nvim_set_current_buf(source_buf) - vim.cmd.vsplit() - vim.api.nvim_set_current_buf(output_buf) - vim.bo.filetype = 'cp' - vim.cmd(('vertical resize %d'):format(math.floor(vim.o.columns * 0.3))) - vim.cmd.split() - vim.api.nvim_set_current_buf(input_buf) - vim.bo.filetype = 'cp' - vim.cmd.wincmd('h') -end - -M.default_tile = default_tile - -return M diff --git a/plugin/cp.lua b/plugin/cp.lua index 2bf4707..da193dc 100644 --- a/plugin/cp.lua +++ b/plugin/cp.lua @@ -36,12 +36,24 @@ end, { end else vim.list_extend(candidates, platforms) + table.insert(candidates, 'cache') + table.insert(candidates, 'pick') end return vim.tbl_filter(function(cmd) return cmd:find(ArgLead, 1, true) == 1 end, candidates) + elseif num_args == 3 then + if args[2] == 'cache' then + return vim.tbl_filter(function(cmd) + return cmd:find(ArgLead, 1, true) == 1 + end, { 'clear' }) + end elseif num_args == 4 then - if vim.tbl_contains(platforms, args[2]) then + if args[2] == 'cache' and args[3] == 'clear' then + return vim.tbl_filter(function(cmd) + return cmd:find(ArgLead, 1, true) == 1 + end, platforms) + elseif vim.tbl_contains(platforms, args[2]) then local cache = require('cp.cache') cache.load() local contest_data = cache.get_contest_data(args[2], args[3]) diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 83e1cc1..1935c6e 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -272,75 +272,7 @@ def scrape_contests() -> list[ContestSummary]: r"[\uff01-\uff5e]", lambda m: chr(ord(m.group()) - 0xFEE0), name ) - def generate_display_name_from_id(contest_id: str) -> str: - parts = contest_id.replace("-", " ").replace("_", " ") - - parts = re.sub( - r"\b(jsc|JSC)\b", - "Japanese Student Championship", - parts, - flags=re.IGNORECASE, - ) - parts = re.sub( - r"\b(wtf|WTF)\b", - "World Tour Finals", - parts, - flags=re.IGNORECASE, - ) - parts = re.sub( - r"\b(ahc)(\d+)\b", - r"Heuristic Contest \2 (AHC)", - parts, - flags=re.IGNORECASE, - ) - parts = re.sub( - r"\b(arc)(\d+)\b", - r"Regular Contest \2 (ARC)", - parts, - flags=re.IGNORECASE, - ) - parts = re.sub( - r"\b(abc)(\d+)\b", - r"Beginner Contest \2 (ABC)", - parts, - flags=re.IGNORECASE, - ) - parts = re.sub( - r"\b(agc)(\d+)\b", - r"Grand Contest \2 (AGC)", - parts, - flags=re.IGNORECASE, - ) - - return parts.title() - - english_chars = sum(1 for c in name if c.isascii() and c.isalpha()) - total_chars = len(re.sub(r"\s+", "", name)) - - if total_chars > 0 and english_chars / total_chars < 0.3: - display_name = generate_display_name_from_id(contest_id) - else: - display_name = name - if "AtCoder Beginner Contest" in name: - match = re.search(r"AtCoder Beginner Contest (\d+)", name) - if match: - display_name = f"Beginner Contest {match.group(1)} (ABC)" - elif "AtCoder Regular Contest" in name: - match = re.search(r"AtCoder Regular Contest (\d+)", name) - if match: - display_name = f"Regular Contest {match.group(1)} (ARC)" - elif "AtCoder Grand Contest" in name: - match = re.search(r"AtCoder Grand Contest (\d+)", name) - if match: - display_name = f"Grand Contest {match.group(1)} (AGC)" - elif "AtCoder Heuristic Contest" in name: - match = re.search(r"AtCoder Heuristic Contest (\d+)", name) - if match: - display_name = f"Heuristic Contest {match.group(1)} (AHC)" - - contests.append( - ContestSummary(id=contest_id, name=name, display_name=display_name) - ) + contests.append(ContestSummary(id=contest_id, name=name, display_name=name)) return contests diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index 1402827..89d568e 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -237,45 +237,9 @@ def scrape_contests() -> list[ContestSummary]: contest_id = str(contest["id"]) name = contest["name"] - display_name = name - if "Educational Codeforces Round" in name: - match = re.search(r"Educational Codeforces Round (\d+)", name) - if match: - display_name = f"Educational Round {match.group(1)}" - elif "Codeforces Global Round" in name: - match = re.search(r"Codeforces Global Round (\d+)", name) - if match: - display_name = f"Global Round {match.group(1)}" - elif "Codeforces Round" in name: - div_match = re.search(r"Codeforces Round (\d+) \(Div\. (\d+)\)", name) - if div_match: - display_name = ( - f"Round {div_match.group(1)} (Div. {div_match.group(2)})" - ) - else: - combined_match = re.search( - r"Codeforces Round (\d+) \(Div\. 1 \+ Div\. 2\)", name - ) - if combined_match: - display_name = ( - f"Round {combined_match.group(1)} (Div. 1 + Div. 2)" - ) - else: - single_div_match = re.search( - r"Codeforces Round (\d+) \(Div\. 1\)", name - ) - if single_div_match: - display_name = f"Round {single_div_match.group(1)} (Div. 1)" - else: - round_match = re.search(r"Codeforces Round (\d+)", name) - if round_match: - display_name = f"Round {round_match.group(1)}" + contests.append(ContestSummary(id=contest_id, name=name, display_name=name)) - contests.append( - ContestSummary(id=contest_id, name=name, display_name=display_name) - ) - - return contests[:100] + return contests except Exception as e: print(f"Failed to fetch contests: {e}", file=sys.stderr) diff --git a/spec/ansi_spec.lua b/spec/ansi_spec.lua index ff7a96b..af62c36 100644 --- a/spec/ansi_spec.lua +++ b/spec/ansi_spec.lua @@ -1,5 +1,5 @@ describe('ansi parser', function() - local ansi = require('cp.ansi') + local ansi = require('cp.ui.ansi') describe('bytes_to_string', function() it('returns string as-is', function() @@ -224,7 +224,6 @@ describe('ansi parser', function() ansi.setup_highlight_groups() local highlight = vim.api.nvim_get_hl(0, { name = 'CpAnsiRed' }) - -- When 'NONE' is set, nvim_get_hl returns nil for that field assert.is_nil(highlight.fg) for i = 0, 15 do diff --git a/spec/cache_spec.lua b/spec/cache_spec.lua index 5b06911..a72946a 100644 --- a/spec/cache_spec.lua +++ b/spec/cache_spec.lua @@ -156,4 +156,38 @@ describe('cp.cache', function() assert.equals('python', result.language) end) end) + + describe('cache management', function() + it('clears all cache data', function() + cache.set_contest_data('atcoder', 'test_contest', { { id = 'A' } }) + cache.set_contest_data('codeforces', 'test_contest', { { id = 'B' } }) + cache.set_file_state('/tmp/test.cpp', 'atcoder', 'abc123', 'a', 'cpp') + + cache.clear_all() + + assert.is_nil(cache.get_contest_data('atcoder', 'test_contest')) + assert.is_nil(cache.get_contest_data('codeforces', 'test_contest')) + assert.is_nil(cache.get_file_state('/tmp/test.cpp')) + end) + + it('clears cache for specific platform', function() + cache.set_contest_data('atcoder', 'test_contest', { { id = 'A' } }) + cache.set_contest_data('codeforces', 'test_contest', { { id = 'B' } }) + cache.set_contest_list('atcoder', { { id = '123', name = 'Test' } }) + cache.set_contest_list('codeforces', { { id = '456', name = 'Test' } }) + + cache.clear_platform('atcoder') + + assert.is_nil(cache.get_contest_data('atcoder', 'test_contest')) + assert.is_nil(cache.get_contest_list('atcoder')) + assert.is_not_nil(cache.get_contest_data('codeforces', 'test_contest')) + assert.is_not_nil(cache.get_contest_list('codeforces')) + end) + + it('handles clear platform for non-existent platform', function() + assert.has_no_errors(function() + cache.clear_platform('nonexistent') + end) + end) + end) end) diff --git a/spec/command_parsing_spec.lua b/spec/command_parsing_spec.lua index 2d856c6..693f2b2 100644 --- a/spec/command_parsing_spec.lua +++ b/spec/command_parsing_spec.lua @@ -293,4 +293,345 @@ describe('cp command parsing', function() end end) end) + + describe('cache commands', function() + it('handles cache clear without platform', function() + local opts = { fargs = { 'cache', 'clear' } } + + assert.has_no_errors(function() + cp.handle_command(opts) + end) + + local success_logged = false + for _, log_entry in ipairs(logged_messages) do + if log_entry.msg and log_entry.msg:match('cleared all cache') then + success_logged = true + break + end + end + assert.is_true(success_logged) + end) + + it('handles cache clear with valid platform', function() + local opts = { fargs = { 'cache', 'clear', 'atcoder' } } + + assert.has_no_errors(function() + cp.handle_command(opts) + end) + + local success_logged = false + for _, log_entry in ipairs(logged_messages) do + if log_entry.msg and log_entry.msg:match('cleared cache for atcoder') then + success_logged = true + break + end + end + assert.is_true(success_logged) + end) + + it('logs error for cache clear with invalid platform', function() + local opts = { fargs = { 'cache', 'clear', 'invalid_platform' } } + + cp.handle_command(opts) + + local error_logged = false + for _, log_entry in ipairs(logged_messages) do + if log_entry.level == vim.log.levels.ERROR and log_entry.msg:match('unknown platform') then + error_logged = true + break + end + end + assert.is_true(error_logged) + end) + + it('logs error for cache command without subcommand', function() + local opts = { fargs = { 'cache' } } + + cp.handle_command(opts) + + local error_logged = false + for _, log_entry in ipairs(logged_messages) do + if + log_entry.level == vim.log.levels.ERROR + and log_entry.msg:match('cache command requires subcommand') + then + error_logged = true + break + end + end + assert.is_true(error_logged) + end) + + it('logs error for invalid cache subcommand', function() + local opts = { fargs = { 'cache', 'invalid' } } + + cp.handle_command(opts) + + local error_logged = false + for _, log_entry in ipairs(logged_messages) do + if + log_entry.level == vim.log.levels.ERROR + and log_entry.msg:match('unknown cache subcommand') + then + error_logged = true + break + end + end + assert.is_true(error_logged) + end) + end) + + describe('CP command completion', function() + local complete_fn + + before_each(function() + package.loaded['cp'] = nil + package.loaded['cp.cache'] = nil + + complete_fn = function(ArgLead, CmdLine, _) + local constants = require('cp.constants') + local platforms = constants.PLATFORMS + local actions = constants.ACTIONS + + local args = vim.split(vim.trim(CmdLine), '%s+') + local num_args = #args + if CmdLine:sub(-1) == ' ' then + num_args = num_args + 1 + end + + if num_args == 2 then + local candidates = {} + local cp_mod = require('cp') + local context = cp_mod.get_current_context() + if context.platform and context.contest_id then + vim.list_extend(candidates, actions) + local cache = require('cp.cache') + cache.load() + local contest_data = cache.get_contest_data(context.platform, context.contest_id) + if contest_data and contest_data.problems then + for _, problem in ipairs(contest_data.problems) do + table.insert(candidates, problem.id) + end + end + else + vim.list_extend(candidates, platforms) + table.insert(candidates, 'cache') + table.insert(candidates, 'pick') + end + return vim.tbl_filter(function(cmd) + return cmd:find(ArgLead, 1, true) == 1 + end, candidates) + elseif num_args == 3 then + if args[2] == 'cache' then + return vim.tbl_filter(function(cmd) + return cmd:find(ArgLead, 1, true) == 1 + end, { 'clear' }) + end + elseif num_args == 4 then + if args[2] == 'cache' and args[3] == 'clear' then + return vim.tbl_filter(function(cmd) + return cmd:find(ArgLead, 1, true) == 1 + end, platforms) + elseif vim.tbl_contains(platforms, args[2]) then + local cache = require('cp.cache') + cache.load() + local contest_data = cache.get_contest_data(args[2], args[3]) + if contest_data and contest_data.problems then + local candidates = {} + for _, problem in ipairs(contest_data.problems) do + table.insert(candidates, problem.id) + end + return vim.tbl_filter(function(cmd) + return cmd:find(ArgLead, 1, true) == 1 + end, candidates) + end + end + end + return {} + end + + package.loaded['cp'] = { + get_current_context = function() + return { platform = nil, contest_id = nil } + end, + } + + package.loaded['cp.cache'] = { + load = function() end, + get_contest_data = function() + return nil + end, + } + end) + + after_each(function() + package.loaded['cp'] = nil + package.loaded['cp.cache'] = nil + end) + + it('completes platforms and global actions when no contest context', function() + local result = complete_fn('', 'CP ', 3) + + assert.is_table(result) + + local has_atcoder = false + local has_codeforces = false + local has_cses = false + local has_cache = false + local has_pick = false + local has_run = false + local has_next = false + local has_prev = false + + for _, item in ipairs(result) do + if item == 'atcoder' then + has_atcoder = true + end + if item == 'codeforces' then + has_codeforces = true + end + if item == 'cses' then + has_cses = true + end + if item == 'cache' then + has_cache = true + end + if item == 'pick' then + has_pick = true + end + if item == 'run' then + has_run = true + end + if item == 'next' then + has_next = true + end + if item == 'prev' then + has_prev = true + end + end + + assert.is_true(has_atcoder) + assert.is_true(has_codeforces) + assert.is_true(has_cses) + assert.is_true(has_cache) + assert.is_true(has_pick) + assert.is_false(has_run) + assert.is_false(has_next) + assert.is_false(has_prev) + end) + + it('completes all actions and problems when contest context exists', function() + package.loaded['cp'] = { + get_current_context = function() + return { platform = 'atcoder', contest_id = 'abc350' } + end, + } + package.loaded['cp.cache'] = { + load = function() end, + get_contest_data = function() + return { + problems = { + { id = 'a' }, + { id = 'b' }, + { id = 'c' }, + }, + } + end, + } + + local result = complete_fn('', 'CP ', 3) + + assert.is_table(result) + + local items = {} + for _, item in ipairs(result) do + items[item] = true + end + + assert.is_true(items['run']) + assert.is_true(items['next']) + assert.is_true(items['prev']) + assert.is_true(items['pick']) + assert.is_true(items['cache']) + + assert.is_true(items['a']) + assert.is_true(items['b']) + assert.is_true(items['c']) + end) + + it('completes cache subcommands', function() + local result = complete_fn('c', 'CP cache c', 10) + + assert.is_table(result) + assert.equals(1, #result) + assert.equals('clear', result[1]) + end) + + it('completes cache subcommands with exact match', function() + local result = complete_fn('clear', 'CP cache clear', 14) + + assert.is_table(result) + assert.equals(1, #result) + assert.equals('clear', result[1]) + end) + + it('completes platforms for cache clear', function() + local result = complete_fn('a', 'CP cache clear a', 16) + + assert.is_table(result) + + local has_atcoder = false + local has_cache = false + + for _, item in ipairs(result) do + if item == 'atcoder' then + has_atcoder = true + end + if item == 'cache' then + has_cache = true + end + end + + assert.is_true(has_atcoder) + assert.is_false(has_cache) + end) + + it('filters completions based on current input', function() + local result = complete_fn('at', 'CP at', 5) + + assert.is_table(result) + assert.equals(1, #result) + assert.equals('atcoder', result[1]) + end) + + it('returns empty array when no matches', function() + local result = complete_fn('xyz', 'CP xyz', 6) + + assert.is_table(result) + assert.equals(0, #result) + end) + + it('handles problem completion for platform contest', function() + package.loaded['cp.cache'] = { + load = function() end, + get_contest_data = function(platform, contest) + if platform == 'atcoder' and contest == 'abc350' then + return { + problems = { + { id = 'a' }, + { id = 'b' }, + }, + } + end + return nil + end, + } + + local result = complete_fn('a', 'CP atcoder abc350 a', 18) + + assert.is_table(result) + assert.equals(1, #result) + assert.equals('a', result[1]) + end) + end) end) diff --git a/spec/config_spec.lua b/spec/config_spec.lua index f3f3738..9724832 100644 --- a/spec/config_spec.lua +++ b/spec/config_spec.lua @@ -169,7 +169,7 @@ describe('cp.config', function() assert.equals('cpp', result.contests.test.default_language) end) - it('sets default_language to first available when cpp not present', function() + it('sets default_language to single available language when only one configured', function() local user_config = { contests = { test = { @@ -183,6 +183,38 @@ describe('cp.config', function() assert.equals('python', result.contests.test.default_language) end) + it('sets default_language to single available language even when not cpp', function() + local user_config = { + contests = { + test = { + rust = { + test = { './target/release/solution' }, + extension = 'rs', + }, + }, + }, + } + + local result = config.setup(user_config) + + assert.equals('rust', result.contests.test.default_language) + end) + + it('uses first available language when multiple configured', function() + local user_config = { + contests = { + test = { + python = { test = { 'python3' } }, + cpp = { compile = { 'g++' } }, + }, + }, + } + + local result = config.setup(user_config) + + assert.is_true(vim.tbl_contains({ 'cpp', 'python' }, result.contests.test.default_language)) + end) + it('preserves explicit default_language', function() local user_config = { contests = { diff --git a/spec/diff_spec.lua b/spec/diff_spec.lua index 9bb8b63..b50d05e 100644 --- a/spec/diff_spec.lua +++ b/spec/diff_spec.lua @@ -4,7 +4,7 @@ describe('cp.diff', function() before_each(function() spec_helper.setup() - diff = require('cp.diff') + diff = require('cp.ui.diff') end) after_each(function() diff --git a/spec/execute_spec.lua b/spec/execute_spec.lua index 12e7f67..38784be 100644 --- a/spec/execute_spec.lua +++ b/spec/execute_spec.lua @@ -6,7 +6,7 @@ describe('cp.execute', function() before_each(function() spec_helper.setup() - execute = require('cp.execute') + execute = require('cp.runner.execute') mock_system_calls = {} temp_files = {} @@ -416,7 +416,7 @@ describe('cp.execute', function() } end - local execute_command = require('cp.execute').execute_command + local execute_command = require('cp.runner.execute').execute_command or function(command, stdin_data, timeout) local redirected_cmd = vim.deepcopy(command) if #redirected_cmd > 0 then diff --git a/spec/extmark_spec.lua b/spec/extmark_spec.lua index 6383d55..2b4b25a 100644 --- a/spec/extmark_spec.lua +++ b/spec/extmark_spec.lua @@ -4,7 +4,7 @@ describe('extmarks', function() before_each(function() spec_helper.setup() - highlight = require('cp.highlight') + highlight = require('cp.ui.highlight') end) after_each(function() diff --git a/spec/highlight_spec.lua b/spec/highlight_spec.lua index 74f1c91..9afd773 100644 --- a/spec/highlight_spec.lua +++ b/spec/highlight_spec.lua @@ -4,7 +4,7 @@ describe('cp.highlight', function() before_each(function() spec_helper.setup() - highlight = require('cp.highlight') + highlight = require('cp.ui.highlight') end) after_each(function() diff --git a/spec/picker_spec.lua b/spec/picker_spec.lua index ab4d36e..92b32a2 100644 --- a/spec/picker_spec.lua +++ b/spec/picker_spec.lua @@ -194,6 +194,10 @@ describe('cp.picker', function() picker.setup_problem('codeforces', '1951', 'a') + vim.wait(100, function() + return called_with ~= nil + end) + assert.is_table(called_with) assert.is_table(called_with.fargs) assert.equals('codeforces', called_with.fargs[1]) diff --git a/spec/run_render_spec.lua b/spec/run_render_spec.lua index bcb0c78..a647331 100644 --- a/spec/run_render_spec.lua +++ b/spec/run_render_spec.lua @@ -1,5 +1,5 @@ describe('cp.run_render', function() - local run_render = require('cp.run_render') + local run_render = require('cp.runner.run_render') local spec_helper = require('spec.spec_helper') before_each(function() diff --git a/spec/run_spec.lua b/spec/run_spec.lua index e70f4f7..f7eb772 100644 --- a/spec/run_spec.lua +++ b/spec/run_spec.lua @@ -1,5 +1,5 @@ describe('run module', function() - local run = require('cp.run') + local run = require('cp.runner.run') describe('basic functionality', function() it('has required functions', function() diff --git a/tests/scrapers/test_atcoder.py b/tests/scrapers/test_atcoder.py index a2a88e5..dcde406 100644 --- a/tests/scrapers/test_atcoder.py +++ b/tests/scrapers/test_atcoder.py @@ -101,12 +101,12 @@ def test_scrape_contests_success(mocker): assert result[0] == ContestSummary( id="abc350", name="AtCoder Beginner Contest 350", - display_name="Beginner Contest 350 (ABC)", + display_name="AtCoder Beginner Contest 350", ) assert result[1] == ContestSummary( id="arc170", name="AtCoder Regular Contest 170", - display_name="Regular Contest 170 (ARC)", + display_name="AtCoder Regular Contest 170", ) diff --git a/tests/scrapers/test_codeforces.py b/tests/scrapers/test_codeforces.py index b95a489..14b263c 100644 --- a/tests/scrapers/test_codeforces.py +++ b/tests/scrapers/test_codeforces.py @@ -77,15 +77,17 @@ def test_scrape_contests_success(mocker): assert result[0] == ContestSummary( id="1951", name="Educational Codeforces Round 168 (Rated for Div. 2)", - display_name="Educational Round 168", + display_name="Educational Codeforces Round 168 (Rated for Div. 2)", ) assert result[1] == ContestSummary( id="1950", name="Codeforces Round 936 (Div. 2)", - display_name="Round 936 (Div. 2)", + display_name="Codeforces Round 936 (Div. 2)", ) assert result[2] == ContestSummary( - id="1949", name="Codeforces Global Round 26", display_name="Global Round 26" + id="1949", + name="Codeforces Global Round 26", + display_name="Codeforces Global Round 26", )