From 6ae94887619f7722abbf4d2304eb21760f2339f1 Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Sun, 5 Oct 2025 23:55:23 -0400 Subject: [PATCH] fix: typing --- lua/cp/scraper.lua | 8 +- lua/cp/setup.lua | 211 ++++++++++++++++++++++++++++++++++----------- lua/cp/state.lua | 40 ++++++++- 3 files changed, 205 insertions(+), 54 deletions(-) diff --git a/lua/cp/scraper.lua b/lua/cp/scraper.lua index 0c3494a..f8cb817 100644 --- a/lua/cp/scraper.lua +++ b/lua/cp/scraper.lua @@ -1,6 +1,6 @@ local M = {} -local constants = require('cp.log') +local constants = require('cp.constants') local logger = require('cp.log') local utils = require('cp.utils') @@ -168,7 +168,11 @@ function M.scrape_all_tests(platform, contest_id, callback) end if ev.error and ev.problem_id then logger.log( - ("Failed to load tests for problem '%s': %s"):format(contest_id, ev.problem_id, ev.error), + ("Failed to load tests for problem '%s' in contest '%s': %s"):format( + ev.problem_id, + contest_id, + ev.error + ), vim.log.levels.WARN ) return diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index 486aad9..e05c81b 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -2,12 +2,11 @@ local M = {} local cache = require('cp.cache') local config_module = require('cp.config') +local constants = require('cp.constants') local logger = require('cp.log') local scraper = require('cp.scraper') local state = require('cp.state') -local constants = require('cp.constants') - ---@class TestCaseLite ---@field input string ---@field expected string @@ -23,6 +22,46 @@ local constants = require('cp.constants') ---@field succeeded integer|nil ---@field failed integer|nil +---@param cd table|nil +---@return boolean +local function is_metadata_ready(cd) + return cd + and type(cd.problems) == 'table' + and #cd.problems > 0 + and type(cd.index_map) == 'table' + and next(cd.index_map) ~= nil +end + +---@param platform string +---@param contest_id string +---@param problems table +local function start_tests(platform, contest_id, problems) + local cached_len = #vim.tbl_filter(function(p) + return not vim.tbl_isempty(cache.get_test_cases(platform, contest_id, p.id)) + end, problems) + if cached_len ~= #problems then + logger.log(('Fetching test cases... (%d/%d)'):format(cached_len, #problems)) + scraper.scrape_all_tests(platform, contest_id, function(ev) + local cached_tests = {} + if not ev.interactive and vim.tbl_isempty(ev.tests) then + logger.log(("No tests found for problem '%s'."):format(ev.problem_id), vim.log.levels.WARN) + end + for i, t in ipairs(ev.tests) do + cached_tests[i] = { index = i, input = t.input, expected = t.expected } + end + cache.set_test_cases( + platform, + contest_id, + ev.problem_id, + cached_tests, + ev.timeout_ms or 0, + ev.memory_mb or 0, + ev.interactive + ) + end) + end +end + ---@param platform string ---@param contest_id string ---@param problem_id? string @@ -34,49 +73,86 @@ function M.setup_contest(platform, contest_id, problem_id, language) local function proceed(contest_data) local problems = contest_data.problems - local pid = problems[(problem_id and contest_data.index_map[problem_id] or 1)].id + local pid = problem_id and problem_id or problems[1].id M.setup_problem(pid, language) - - local cached_len = #vim.tbl_filter(function(p) - return not vim.tbl_isempty(cache.get_test_cases(platform, contest_id, p.id)) - end, problems) - - if cached_len ~= #problems then - logger.log(('Fetching test cases...'):format(cached_len, #problems)) - scraper.scrape_all_tests(platform, contest_id, function(ev) - local cached_tests = {} - if not ev.interactive and vim.tbl_isempty(ev.tests) then - logger.log( - ("No tests found for problem '%s'."):format(ev.problem_id), - vim.log.levels.WARN - ) - end - for i, t in ipairs(ev.tests) do - cached_tests[i] = { index = i, input = t.input, expected = t.expected } - end - cache.set_test_cases( - platform, - contest_id, - ev.problem_id, - cached_tests, - ev.timeout_ms or 0, - ev.memory_mb or 0, - ev.interactive - ) - logger.log('Test cases loaded.') - end) - end + start_tests(platform, contest_id, problems) end local contest_data = cache.get_contest_data(platform, contest_id) - if not contest_data or not contest_data.problems then + if not is_metadata_ready(contest_data) then + local cfg = config_module.get_config() + local lang = language or (cfg.platforms[platform] and cfg.platforms[platform].default_language) + + vim.cmd.only({ mods = { silent = true } }) + local bufnr = vim.api.nvim_create_buf(true, false) + vim.api.nvim_win_set_buf(0, bufnr) + if lang then + vim.bo[bufnr].filetype = lang + end + vim.bo[bufnr].buftype = '' + + local ext = cfg.runtime + and cfg.runtime.effective[platform] + and cfg.runtime.effective[platform][lang] + and cfg.runtime.effective[platform][lang].extension + local provisional_name = nil + if ext then + provisional_name = (config_module.default_filename(contest_id) .. '.' .. ext) + vim.api.nvim_buf_set_name(bufnr, provisional_name) + end + + if cfg.hooks and cfg.hooks.setup_code and not vim.b[bufnr].cp_setup_done then + local ok = pcall(cfg.hooks.setup_code, state) + if ok then + vim.b[bufnr].cp_setup_done = true + end + end + + if provisional_name then + cache.set_file_state( + vim.fn.fnamemodify(provisional_name, ':p'), + platform, + contest_id, + '', + lang + ) + end + + state.set_provisional({ + bufnr = bufnr, + platform = platform, + contest_id = contest_id, + language = lang, + requested_problem_id = problem_id, + token = vim.loop.hrtime(), + }) + logger.log('Fetching contests problems...', vim.log.levels.INFO, true) - scraper.scrape_contest_metadata(platform, contest_id, function(result) - local problems = result.problems or {} - cache.set_contest_data(platform, contest_id, problems) - logger.log(('Found %d problems for %s contest %s.'):format(#problems, platform, contest_id)) - proceed(cache.get_contest_data(platform, contest_id)) - end) + scraper.scrape_contest_metadata( + platform, + contest_id, + vim.schedule_wrap(function(result) + local problems = result.problems or {} + cache.set_contest_data(platform, contest_id, problems) + local prov = state.get_provisional() + if not prov or prov.platform ~= platform or prov.contest_id ~= contest_id then + return + end + local cd = cache.get_contest_data(platform, contest_id) + if not is_metadata_ready(cd) then + return + end + local pid = prov.requested_problem_id + if not pid or not cd.index_map or not cd.index_map[pid] then + pid = cd.problems[1] and cd.problems[1].id or nil + end + if not pid then + return + end + M.setup_problem(pid, prov.language) + start_tests(platform, contest_id, cd.problems) + end) + ) return end @@ -88,25 +164,58 @@ end function M.setup_problem(problem_id, language) local platform = state.get_platform() if not platform then - logger.log('No platform set.', vim.log.levels.ERROR) return end state.set_problem_id(problem_id) - local config = config_module.get_config() + local lang = language + or (config.platforms[platform] and config.platforms[platform].default_language) + local source_file = state.get_source_file(lang) + if not source_file then + return + end + + local prov = state.get_provisional() + if prov and prov.platform == platform and prov.contest_id == (state.get_contest_id() or '') then + if vim.api.nvim_buf_is_valid(prov.bufnr) then + local old = vim.api.nvim_buf_get_name(prov.bufnr) + local new = source_file + if old ~= '' and old ~= new then + local st = vim.loop.fs_stat(old) + if st and st.type == 'file' then + pcall(vim.loop.fs_rename, old, new) + end + end + vim.api.nvim_buf_set_name(prov.bufnr, new) + if config.hooks and config.hooks.setup_code and not vim.b[prov.bufnr].cp_setup_done then + local ok = pcall(config.hooks.setup_code, state) + if ok then + vim.b[prov.bufnr].cp_setup_done = true + end + end + cache.set_file_state( + vim.fn.fnamemodify(new, ':p'), + platform, + state.get_contest_id() or '', + state.get_problem_id() or '', + lang + ) + end + state.set_provisional(nil) + return + end vim.schedule(function() vim.cmd.only({ mods = { silent = true } }) - - local lang = language or config.platforms[platform].default_language - local source_file = state.get_source_file(lang) vim.cmd.e(source_file) - - if config.hooks and config.hooks.setup_code then - config.hooks.setup_code(state) + local bufnr = vim.api.nvim_get_current_buf() + if config.hooks and config.hooks.setup_code and not vim.b[bufnr].cp_setup_done then + local ok = pcall(config.hooks.setup_code, state) + if ok then + vim.b[bufnr].cp_setup_done = true + end end - cache.set_file_state( vim.fn.expand('%:p'), platform, @@ -117,6 +226,7 @@ function M.setup_problem(problem_id, language) end) end +---@param direction integer function M.navigate_problem(direction) if direction == 0 then return @@ -126,7 +236,6 @@ function M.navigate_problem(direction) local platform = state.get_platform() local contest_id = state.get_contest_id() local current_problem_id = state.get_problem_id() - if not platform or not contest_id or not current_problem_id then logger.log('No platform configured.', vim.log.levels.ERROR) return @@ -134,7 +243,7 @@ function M.navigate_problem(direction) cache.load() local contest_data = cache.get_contest_data(platform, contest_id) - if not contest_data or not contest_data.problems then + if not is_metadata_ready(contest_data) then logger.log( ('No data available for %s contest %s.'):format( constants.PLATFORM_DISPLAY_NAMES[platform], diff --git a/lua/cp/state.lua b/lua/cp/state.lua index e228212..9172fe4 100644 --- a/lua/cp/state.lua +++ b/lua/cp/state.lua @@ -1,3 +1,11 @@ +---@class cp.ProvisionalState +---@field bufnr integer +---@field platform string +---@field contest_id string +---@field language string +---@field requested_problem_id string|nil +---@field token integer + ---@class cp.State ---@field get_platform fun(): string? ---@field set_platform fun(platform: string) @@ -6,16 +14,19 @@ ---@field get_problem_id fun(): string? ---@field set_problem_id fun(problem_id: string) ---@field get_active_panel fun(): string? ----@field set_active_panel fun(): string? +---@field set_active_panel fun(panel: string?) ---@field get_base_name fun(): string? ---@field get_source_file fun(language?: string): string? ---@field get_binary_file fun(): string? ---@field get_input_file fun(): string? ---@field get_output_file fun(): string? ---@field get_expected_file fun(): string? +---@field get_provisional fun(): cp.ProvisionalState|nil +---@field set_provisional fun(p: cp.ProvisionalState|nil) local M = {} +---@type table local state = { platform = nil, contest_id = nil, @@ -23,32 +34,40 @@ local state = { test_cases = nil, saved_session = nil, active_panel = nil, + provisional = nil, } +---@return string|nil function M.get_platform() return state.platform end +---@param platform string function M.set_platform(platform) state.platform = platform end +---@return string|nil function M.get_contest_id() return state.contest_id end +---@param contest_id string function M.set_contest_id(contest_id) state.contest_id = contest_id end +---@return string|nil function M.get_problem_id() return state.problem_id end +---@param problem_id string function M.set_problem_id(problem_id) state.problem_id = problem_id end +---@return string|nil function M.get_base_name() local platform, contest_id, problem_id = M.get_platform(), M.get_contest_id(), M.get_problem_id() if not platform or not contest_id or not problem_id then @@ -65,10 +84,13 @@ function M.get_base_name() end end +---@return string|nil function M.get_language() return end +---@param language? string +---@return string|nil function M.get_source_file(language) local base_name = M.get_base_name() if not base_name or not M.get_platform() then @@ -90,34 +112,50 @@ function M.get_source_file(language) return base_name .. '.' .. eff.extension end +---@return string|nil function M.get_binary_file() local base_name = M.get_base_name() return base_name and ('build/%s.run'):format(base_name) or nil end +---@return string|nil function M.get_input_file() local base_name = M.get_base_name() return base_name and ('io/%s.cpin'):format(base_name) or nil end +---@return string|nil function M.get_output_file() local base_name = M.get_base_name() return base_name and ('io/%s.cpout'):format(base_name) or nil end +---@return string|nil function M.get_expected_file() local base_name = M.get_base_name() return base_name and ('io/%s.expected'):format(base_name) or nil end +---@return string|nil function M.get_active_panel() return state.active_panel end +---@param panel string|nil function M.set_active_panel(panel) state.active_panel = panel end +---@return cp.ProvisionalState|nil +function M.get_provisional() + return state.provisional +end + +---@param p cp.ProvisionalState|nil +function M.set_provisional(p) + state.provisional = p +end + M._state = state return M