diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 5c56ef8..86d806f 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -92,6 +92,26 @@ function M.get_contest_data(platform, contest_id) return cache_data[platform][contest_id] end +---Get all cached contest IDs for a platform +---@param platform string +---@return string[] +function M.get_cached_contest_ids(platform) + vim.validate({ + platform = { platform, 'string' }, + }) + + if not cache_data[platform] then + return {} + end + + local contest_ids = {} + for contest_id, _ in pairs(cache_data[platform]) do + table.insert(contest_ids, contest_id) + end + table.sort(contest_ids) + return contest_ids +end + ---@param platform string ---@param contest_id string ---@param problems Problem[] diff --git a/lua/cp/commands/init.lua b/lua/cp/commands/init.lua index fa4f658..f48727d 100644 --- a/lua/cp/commands/init.lua +++ b/lua/cp/commands/init.lua @@ -59,16 +59,15 @@ local function parse_command(args) local debug = false local test_index = nil - for i = 2, #args do - local arg = args[i] - if arg == '--debug' then + if #args == 2 then + if args[2] == '--debug' then debug = true else - local idx = tonumber(arg) + local idx = tonumber(args[2]) if not idx then return { type = 'error', - message = ("Invalid argument '%s': expected test number or --debug"):format(arg), + message = ("Invalid argument '%s': expected test number or --debug"):format(args[2]), } end if idx < 1 or idx ~= math.floor(idx) then @@ -76,6 +75,30 @@ local function parse_command(args) end test_index = idx end + elseif #args == 3 then + local idx = tonumber(args[2]) + if not idx then + return { + type = 'error', + message = ("Invalid argument '%s': expected test number"):format(args[2]), + } + end + if idx < 1 or idx ~= math.floor(idx) then + return { type = 'error', message = ("'%s' is not a valid test index"):format(idx) } + end + if args[3] ~= '--debug' then + return { + type = 'error', + message = ("Invalid argument '%s': expected --debug"):format(args[3]), + } + end + test_index = idx + debug = true + elseif #args > 3 then + return { + type = 'error', + message = 'Too many arguments. Usage: :CP ' .. first .. ' [test_num] [--debug]', + } end return { type = 'action', action = first, test_index = test_index, debug = debug } diff --git a/plugin/cp.lua b/plugin/cp.lua index 5d4df32..b4954a4 100644 --- a/plugin/cp.lua +++ b/plugin/cp.lua @@ -22,12 +22,30 @@ end, { num_args = num_args + 1 end + local function filter_candidates(candidates) + return vim.tbl_filter(function(cmd) + return cmd:find(ArgLead, 1, true) == 1 + end, candidates) + end + + local function get_enabled_languages(platform) + local config = require('cp.config').get_config() + if platform and config.platforms[platform] then + return config.platforms[platform].enabled_languages + end + return vim.tbl_keys(config.languages) + end + if num_args == 2 then local candidates = {} local state = require('cp.state') local platform = state.get_platform() local contest_id = state.get_contest_id() + vim.list_extend(candidates, platforms) + table.insert(candidates, 'cache') + table.insert(candidates, 'pick') + if platform and contest_id then vim.list_extend(candidates, actions) local cache = require('cp.cache') @@ -39,44 +57,75 @@ end, { table.sort(ids) vim.list_extend(candidates, ids) end - else - vim.list_extend(candidates, platforms) - table.insert(candidates, 'cache') - table.insert(candidates, 'pick') end - return vim.tbl_filter(function(cmd) - return cmd:find(ArgLead, 1, true) == 1 - end, candidates) + return filter_candidates(candidates) elseif num_args == 3 then - if args[2] == 'cache' then - return vim.tbl_filter(function(cmd) - return cmd:find(ArgLead, 1, true) == 1 - end, { 'clear', 'read' }) + if vim.tbl_contains(platforms, args[2]) then + local cache = require('cp.cache') + cache.load() + local contests = cache.get_cached_contest_ids(args[2]) + return filter_candidates(contests) + elseif args[2] == 'cache' then + return filter_candidates({ 'clear', 'read' }) elseif args[2] == 'interact' then - local cands = utils.cwd_executables() - return vim.tbl_filter(function(cmd) - return cmd:find(ArgLead, 1, true) == 1 - end, cands) + return filter_candidates(utils.cwd_executables()) + elseif args[2] == 'run' or args[2] == 'panel' then + local state = require('cp.state') + local platform = state.get_platform() + local contest_id = state.get_contest_id() + local problem_id = state.get_problem_id() + local candidates = { '--debug' } + if platform and contest_id and problem_id then + local cache = require('cp.cache') + cache.load() + local test_cases = cache.get_test_cases(platform, contest_id, problem_id) + if test_cases then + for i = 1, #test_cases do + table.insert(candidates, tostring(i)) + end + end + end + return filter_candidates(candidates) + elseif args[2] == 'next' or args[2] == 'prev' or args[2] == 'pick' then + return filter_candidates({ '--lang' }) + else + local state = require('cp.state') + if state.get_platform() and state.get_contest_id() then + return filter_candidates({ '--lang' }) + end end elseif num_args == 4 then if args[2] == 'cache' and args[3] == 'clear' then - return vim.tbl_filter(function(cmd) - return cmd:find(ArgLead, 1, true) == 1 - end, platforms) + return filter_candidates(platforms) + elseif args[3] == '--lang' then + local platform = require('cp.state').get_platform() + return filter_candidates(get_enabled_languages(platform)) + elseif (args[2] == 'run' or args[2] == 'panel') and tonumber(args[3]) then + return filter_candidates({ '--debug' }) elseif vim.tbl_contains(platforms, args[2]) then local cache = require('cp.cache') cache.load() local contest_data = cache.get_contest_data(args[2], args[3]) + local candidates = { '--lang' } if contest_data and contest_data.problems then - local candidates = {} for _, problem in ipairs(contest_data.problems) do table.insert(candidates, problem.id) end - return vim.tbl_filter(function(cmd) - return cmd:find(ArgLead, 1, true) == 1 - end, candidates) end + return filter_candidates(candidates) + end + elseif num_args == 5 then + if vim.tbl_contains(platforms, args[2]) then + if args[4] == '--lang' then + return filter_candidates(get_enabled_languages(args[2])) + else + return filter_candidates({ '--lang' }) + end + end + elseif num_args == 6 then + if vim.tbl_contains(platforms, args[2]) and args[5] == '--lang' then + return filter_candidates(get_enabled_languages(args[2])) end end return {}