diff --git a/lua/cp/init.lua b/lua/cp/init.lua index 8abfcf4..4c7d0c2 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -1,27 +1,35 @@ local M = {} -local config = {} -local config_module, snippets, execute, scrape, window, logger, problem, cache +local config_module = require("cp.config") +local snippets = require("cp.snippets") +local execute = require("cp.execute") +local scrape = require("cp.scrape") +local window = require("cp.window") +local logger = require("cp.log") +local problem = require("cp.problem") +local cache = require("cp.cache") -local function lazy_require() - if not config_module then - if not vim.fn.has("nvim-0.10.0") then - vim.notify("[cp.nvim]: requires nvim-0.10.0+", vim.log.levels.ERROR) - return false - end - - config_module = require("cp.config") - snippets = require("cp.snippets") - execute = require("cp.execute") - scrape = require("cp.scrape") - window = require("cp.window") - logger = require("cp.log") - problem = require("cp.problem") - cache = require("cp.cache") - end - return true +if not vim.fn.has("nvim-0.10.0") then + vim.notify("[cp.nvim]: requires nvim-0.10.0+", vim.log.levels.ERROR) + return {} end +vim.g.cp = vim.g.cp or {} +local user_config = {} +local config = config_module.setup(user_config) +logger.set_config(config) +local snippets_initialized = false + +local state = { + platform = nil, + contest_id = nil, + problem_id = nil, + diff_mode = false, + saved_layout = nil, + saved_session = nil, + temp_output = nil, +} + local platforms = { "atcoder", "codeforces", "cses" } local actions = { "run", "debug", "diff", "next", "prev" } @@ -31,8 +39,7 @@ local function set_platform(platform) return false end - vim.g.cp = vim.g.cp or {} - vim.g.cp.platform = platform + state.platform = platform vim.fn.mkdir("build", "p") vim.fn.mkdir("io", "p") return true @@ -41,12 +48,12 @@ end ---@param contest_id string ---@param problem_id? string local function setup_problem(contest_id, problem_id) - if not vim.g.cp or not vim.g.cp.platform then + if not state.platform then logger.log("no platform set. run :CP first", vim.log.levels.ERROR) return end - local metadata_result = scrape.scrape_contest_metadata(vim.g.cp.platform, contest_id) + local metadata_result = scrape.scrape_contest_metadata(state.platform, contest_id) if not metadata_result.success then logger.log( "failed to load contest metadata: " .. (metadata_result.error or "unknown error"), @@ -54,25 +61,25 @@ local function setup_problem(contest_id, problem_id) ) end - if vim.g.cp and vim.g.cp.diff_mode then + if state.diff_mode then vim.cmd.diffoff() - if vim.g.cp.saved_session then - vim.fn.delete(vim.g.cp.saved_session) - vim.g.cp.saved_session = nil + if state.saved_session then + vim.fn.delete(state.saved_session) + state.saved_session = nil end - if vim.g.cp.temp_output then - vim.fn.delete(vim.g.cp.temp_output) - vim.g.cp.temp_output = nil + if state.temp_output then + vim.fn.delete(state.temp_output) + state.temp_output = nil end - vim.g.cp.diff_mode = false + state.diff_mode = false end vim.cmd("silent only") - vim.g.cp.contest_id = contest_id - vim.g.cp.problem_id = problem_id + state.contest_id = contest_id + state.problem_id = problem_id - local ctx = problem.create_context(vim.g.cp.platform, contest_id, problem_id, config) + local ctx = problem.create_context(state.platform, contest_id, problem_id, config) local scrape_result = scrape.scrape_problem(ctx) @@ -89,8 +96,8 @@ local function setup_problem(contest_id, problem_id) if vim.api.nvim_buf_get_lines(0, 0, -1, true)[1] == "" then local has_luasnip, luasnip = pcall(require, "luasnip") if has_luasnip then - vim.api.nvim_buf_set_lines(0, 0, -1, false, { vim.g.cp.platform }) - vim.api.nvim_win_set_cursor(0, { 1, #vim.g.cp.platform }) + vim.api.nvim_buf_set_lines(0, 0, -1, false, { state.platform }) + vim.api.nvim_win_set_cursor(0, { 1, #state.platform }) vim.cmd.startinsert({ bang = true }) vim.schedule(function() @@ -100,7 +107,7 @@ local function setup_problem(contest_id, problem_id) vim.cmd.stopinsert() end) else - vim.api.nvim_input(("i%s"):format(vim.g.cp.platform)) + vim.api.nvim_input(("i%s"):format(state.platform)) end end @@ -132,6 +139,7 @@ local function get_current_problem() end local function run_problem() + local problem_id = get_current_problem() if not problem_id then return @@ -141,21 +149,22 @@ local function run_problem() config.hooks.before_run(problem_id) end - if not vim.g.cp or not vim.g.cp.platform then + if not state.platform then logger.log("no platform set", vim.log.levels.ERROR) return end - local contest_config = config.contests[vim.g.cp.platform] + local contest_config = config.contests[state.platform] vim.schedule(function() - local ctx = problem.create_context(vim.g.cp.platform, vim.g.cp.contest_id, vim.g.cp.problem_id, config) + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) execute.run_problem(ctx, contest_config, false) vim.cmd.checktime() end) end local function debug_problem() + local problem_id = get_current_problem() if not problem_id then return @@ -165,26 +174,27 @@ local function debug_problem() config.hooks.before_debug(problem_id) end - if not vim.g.cp or not vim.g.cp.platform then + if not state.platform then logger.log("no platform set", vim.log.levels.ERROR) return end - local contest_config = config.contests[vim.g.cp.platform] + local contest_config = config.contests[state.platform] vim.schedule(function() - local ctx = problem.create_context(vim.g.cp.platform, vim.g.cp.contest_id, vim.g.cp.problem_id, config) + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) execute.run_problem(ctx, contest_config, true) vim.cmd.checktime() end) end local function diff_problem() - if vim.g.cp and vim.g.cp.diff_mode then + + if state.diff_mode then local tile_fn = config.tile or window.default_tile - window.restore_layout(vim.g.cp.saved_layout, tile_fn) - vim.g.cp.diff_mode = false - vim.g.cp.saved_layout = nil + window.restore_layout(state.saved_layout, tile_fn) + state.diff_mode = false + state.saved_layout = nil logger.log("exited diff mode") else local problem_id = get_current_problem() @@ -192,35 +202,35 @@ local function diff_problem() return end - local ctx = problem.create_context(vim.g.cp.platform, vim.g.cp.contest_id, vim.g.cp.problem_id, config) + local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config) if vim.fn.filereadable(ctx.expected_file) == 0 then logger.log(("No expected output file found: %s"):format(ctx.expected_file), vim.log.levels.ERROR) return end - vim.g.cp = vim.g.cp or {} - vim.g.cp.saved_layout = window.save_layout() + state.saved_layout = window.save_layout() local result = vim.system({ "awk", "/^\\[[^]]*\\]:/ {exit} {print}", ctx.output_file }, { text = true }):wait() local actual_output = result.stdout window.setup_diff_layout(actual_output, ctx.expected_file, ctx.input_file) - vim.g.cp.diff_mode = true + state.diff_mode = true logger.log("entered diff mode") end end ---@param delta number 1 for next, -1 for prev local function navigate_problem(delta) - if not vim.g.cp or not vim.g.cp.platform or not vim.g.cp.contest_id then + + if not state.platform or not state.contest_id then logger.log("no contest set. run :CP first", vim.log.levels.ERROR) return end cache.load() - local contest_data = cache.get_contest_data(vim.g.cp.platform, vim.g.cp.contest_id) + local contest_data = cache.get_contest_data(state.platform, state.contest_id) if not contest_data or not contest_data.problems then logger.log("no contest metadata found. set up a problem first to cache contest data", vim.log.levels.ERROR) return @@ -229,10 +239,10 @@ local function navigate_problem(delta) local problems = contest_data.problems local current_problem_id - if vim.g.cp.platform == "cses" then - current_problem_id = vim.g.cp.contest_id + if state.platform == "cses" then + current_problem_id = state.contest_id else - current_problem_id = vim.g.cp.problem_id + current_problem_id = state.problem_id end if not current_problem_id then @@ -263,28 +273,13 @@ local function navigate_problem(delta) local new_problem = problems[new_index] - if vim.g.cp.platform == "cses" then + if state.platform == "cses" then setup_problem(new_problem.id) else - setup_problem(vim.g.cp.contest_id, new_problem.id) + setup_problem(state.contest_id, new_problem.id) end end -local function ensure_initialized() - if config then - return - end - - if not lazy_require() then - return - end - - vim.g.cp = vim.g.cp or {} - config = config_module.setup(vim.g.cp.config) - logger.set_config(config) - snippets.setup(config) -end - local function parse_command(args) if #args == 0 then return { type = "error", message = "Usage: :CP [problem] | :CP | :CP " } @@ -300,7 +295,11 @@ local function parse_command(args) if #args == 1 then return { type = "platform_only", platform = first } elseif #args == 2 then - return { type = "contest_setup", platform = first, contest = args[2] } + if first == "cses" then + return { type = "cses_problem", platform = first, problem = args[2] } + else + return { type = "contest_setup", platform = first, contest = args[2] } + end elseif #args == 3 then return { type = "full_setup", platform = first, contest = args[2], problem = args[3] } else @@ -308,7 +307,7 @@ local function parse_command(args) end end - if vim.g.cp and vim.g.cp.platform and vim.g.cp.contest_id then + if state.platform and state.contest_id then return { type = "problem_switch", problem = first } end @@ -316,7 +315,6 @@ local function parse_command(args) end function M.handle_command(opts) - ensure_initialized() local cmd = parse_command(opts.fargs) if cmd.type == "error" then @@ -346,7 +344,7 @@ function M.handle_command(opts) if cmd.type == "contest_setup" then if set_platform(cmd.platform) then - vim.g.cp.contest_id = cmd.contest + state.contest_id = cmd.contest local metadata_result = scrape.scrape_contest_metadata(cmd.platform, cmd.contest) if not metadata_result.success then logger.log( @@ -364,24 +362,62 @@ function M.handle_command(opts) if cmd.type == "full_setup" then if set_platform(cmd.platform) then - vim.g.cp.contest_id = cmd.contest + state.contest_id = cmd.contest + local metadata_result = scrape.scrape_contest_metadata(cmd.platform, cmd.contest) + if not metadata_result.success then + logger.log( + "failed to load contest metadata: " .. (metadata_result.error or "unknown error"), + vim.log.levels.WARN + ) + else + logger.log( + ("loaded %d problems for %s %s"):format(#metadata_result.problems, cmd.platform, cmd.contest) + ) + end + setup_problem(cmd.contest, cmd.problem) end return end + if cmd.type == "cses_problem" then + if set_platform(cmd.platform) then + local metadata_result = scrape.scrape_contest_metadata(cmd.platform, "") + if not metadata_result.success then + logger.log( + "failed to load contest metadata: " .. (metadata_result.error or "unknown error"), + vim.log.levels.WARN + ) + end + setup_problem(cmd.problem) + end + return + end + if cmd.type == "problem_switch" then - if vim.g.cp.platform == "cses" then + if state.platform == "cses" then setup_problem(cmd.problem) else - setup_problem(vim.g.cp.contest_id, cmd.problem) + setup_problem(state.contest_id, cmd.problem) end return end end + +function M.setup(opts) + opts = opts or {} + user_config = opts + config = config_module.setup(user_config) + logger.set_config(config) + if not snippets_initialized then + snippets.setup(config) + snippets_initialized = true + end +end + function M.is_initialized() - return config ~= nil + return true end return M diff --git a/lua/cp/scrape.lua b/lua/cp/scrape.lua index c2436e4..09ba5c2 100644 --- a/lua/cp/scrape.lua +++ b/lua/cp/scrape.lua @@ -59,7 +59,13 @@ function M.scrape_contest_metadata(platform, contest_id) local plugin_path = get_plugin_path() local scraper_path = plugin_path .. "/scrapers/" .. platform .. ".py" - local args = { "uv", "run", scraper_path, "metadata", contest_id } + + local args + if platform == "cses" then + args = { "uv", "run", scraper_path, "metadata" } + else + args = { "uv", "run", scraper_path, "metadata", contest_id } + end local result = vim.system(args, { cwd = plugin_path, diff --git a/lua/cp/snippets.lua b/lua/cp/snippets.lua index da8d91f..fbddfff 100644 --- a/lua/cp/snippets.lua +++ b/lua/cp/snippets.lua @@ -63,6 +63,24 @@ int main() {{ solve(); #endif + return 0; +}}]], + { i(1) } + ) + ), + + s( + "cses", + fmt( + [[#include + +using namespace std; + +int main() {{ + std::cin.tie(nullptr)->sync_with_stdio(false); + + {} + return 0; }}]], { i(1) }