diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 81ce011..544aaf4 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -191,12 +191,6 @@ end ---@param problem_id? string ---@return CombinedTest? function M.get_combined_test(platform, contest_id, problem_id) - vim.validate({ - platform = { platform, 'string' }, - contest_id = { contest_id, 'string' }, - problem_id = { problem_id, { 'string', 'nil' }, true }, - }) - if not cache_data[platform] or not cache_data[platform][contest_id] diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index 8939395..d05b417 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -216,7 +216,16 @@ function M.setup_problem(problem_id, language) return end + local old_problem_id = state.get_problem_id() state.set_problem_id(problem_id) + + if old_problem_id ~= problem_id then + local io_state = state.get_io_view_state() + if io_state and io_state.output_buf and vim.api.nvim_buf_is_valid(io_state.output_buf) then + local utils = require('cp.utils') + utils.update_buffer_content(io_state.output_buf, {}, nil, nil) + end + end local config = config_module.get_config() local lang = language or (config.platforms[platform] and config.platforms[platform].default_language) @@ -369,6 +378,12 @@ function M.navigate_problem(direction, language) end end + local io_state = state.get_io_view_state() + if io_state and io_state.output_buf and vim.api.nvim_buf_is_valid(io_state.output_buf) then + local utils = require('cp.utils') + utils.update_buffer_content(io_state.output_buf, {}, nil, nil) + end + M.setup_contest(platform, contest_id, problems[new_index].id, lang) end diff --git a/lua/cp/state.lua b/lua/cp/state.lua index 621b184..40eed86 100644 --- a/lua/cp/state.lua +++ b/lua/cp/state.lua @@ -9,8 +9,6 @@ ---@class cp.IoViewState ---@field output_buf integer ---@field input_buf integer ----@field output_win integer ----@field input_win integer ---@field current_test_index integer? ---@class cp.State @@ -200,19 +198,7 @@ end ---@return cp.IoViewState? function M.get_io_view_state() - if not state.io_view_state then - return nil - end - local s = state.io_view_state - if - vim.api.nvim_buf_is_valid(s.output_buf) - and vim.api.nvim_buf_is_valid(s.input_buf) - and vim.api.nvim_win_is_valid(s.output_win) - and vim.api.nvim_win_is_valid(s.input_win) - then - return s - end - return nil + return state.io_view_state end ---@param s cp.IoViewState? diff --git a/lua/cp/ui/views.lua b/lua/cp/ui/views.lua index a52779e..d541dec 100644 --- a/lua/cp/ui/views.lua +++ b/lua/cp/ui/views.lua @@ -193,6 +193,136 @@ function M.toggle_interactive(interactor_cmd) state.set_active_panel('interactive') end +---@return integer, integer +local function get_or_create_io_buffers() + local io_state = state.get_io_view_state() + + if io_state then + local output_valid = io_state.output_buf and vim.api.nvim_buf_is_valid(io_state.output_buf) + local input_valid = io_state.input_buf and vim.api.nvim_buf_is_valid(io_state.input_buf) + + if output_valid and input_valid then + return io_state.output_buf, io_state.input_buf + end + end + + local output_buf = utils.create_buffer_with_options('cpout') + local input_buf = utils.create_buffer_with_options('cpin') + + state.set_io_view_state({ + output_buf = output_buf, + input_buf = input_buf, + current_test_index = 1, + }) + + local solution_win = state.get_solution_win() + local source_buf = vim.api.nvim_win_get_buf(solution_win) + + local group_name = 'cp_io_cleanup_buf' .. source_buf + vim.api.nvim_create_augroup(group_name, { clear = true }) + vim.api.nvim_create_autocmd('BufDelete', { + group = group_name, + buffer = source_buf, + callback = function() + local io = state.get_io_view_state() + if io then + if io.output_buf and vim.api.nvim_buf_is_valid(io.output_buf) then + vim.api.nvim_buf_delete(io.output_buf, { force = true }) + end + if io.input_buf and vim.api.nvim_buf_is_valid(io.input_buf) then + vim.api.nvim_buf_delete(io.input_buf, { force = true }) + end + state.set_io_view_state(nil) + end + end, + }) + + local cfg = config_module.get_config() + local platform = state.get_platform() + local contest_id = state.get_contest_id() + local problem_id = state.get_problem_id() + + local function navigate_test(delta) + local io_view_state = state.get_io_view_state() + if not io_view_state then + return + end + if not platform or not contest_id or not problem_id then + return + end + local test_cases = cache.get_test_cases(platform, contest_id, problem_id) + if not test_cases or #test_cases == 0 then + return + end + local new_index = (io_view_state.current_test_index or 1) + delta + if new_index < 1 or new_index > #test_cases then + return + end + io_view_state.current_test_index = new_index + M.run_io_view(new_index) + end + + if cfg.ui.run.next_test_key then + vim.keymap.set('n', cfg.ui.run.next_test_key, function() + navigate_test(1) + end, { buffer = output_buf, silent = true, desc = 'Next test' }) + vim.keymap.set('n', cfg.ui.run.next_test_key, function() + navigate_test(1) + end, { buffer = input_buf, silent = true, desc = 'Next test' }) + end + + if cfg.ui.run.prev_test_key then + vim.keymap.set('n', cfg.ui.run.prev_test_key, function() + navigate_test(-1) + end, { buffer = output_buf, silent = true, desc = 'Previous test' }) + vim.keymap.set('n', cfg.ui.run.prev_test_key, function() + navigate_test(-1) + end, { buffer = input_buf, silent = true, desc = 'Previous test' }) + end + + return output_buf, input_buf +end + +---@param output_buf integer +---@param input_buf integer +---@return boolean +local function buffers_are_displayed(output_buf, input_buf) + local output_displayed = false + local input_displayed = false + + for _, win in ipairs(vim.api.nvim_list_wins()) do + local buf = vim.api.nvim_win_get_buf(win) + if buf == output_buf then + output_displayed = true + end + if buf == input_buf then + input_displayed = true + end + end + + return output_displayed and input_displayed +end + +---@param output_buf integer +---@param input_buf integer +local function create_window_layout(output_buf, input_buf) + local solution_win = state.get_solution_win() + vim.api.nvim_set_current_win(solution_win) + + vim.cmd.vsplit() + local output_win = vim.api.nvim_get_current_win() + local cfg = config_module.get_config() + local width = math.floor(vim.o.columns * (cfg.ui.run.width or 0.3)) + vim.api.nvim_win_set_width(output_win, width) + vim.api.nvim_win_set_buf(output_win, output_buf) + + vim.cmd.split() + local input_win = vim.api.nvim_get_current_win() + vim.api.nvim_win_set_buf(input_win, input_buf) + + vim.api.nvim_set_current_win(solution_win) +end + function M.ensure_io_view() local platform, contest_id, problem_id = state.get_platform(), state.get_contest_id(), state.get_problem_id() @@ -204,6 +334,21 @@ function M.ensure_io_view() return end + local source_file = state.get_source_file() + if source_file then + local source_file_abs = vim.fn.fnamemodify(source_file, ':p') + for _, win in ipairs(vim.api.nvim_list_wins()) do + local buf = vim.api.nvim_win_get_buf(win) + local buf_name = vim.api.nvim_buf_get_name(buf) + if buf_name == source_file_abs then + state.set_solution_win(win) + break + end + end + else + state.set_solution_win(vim.api.nvim_get_current_win()) + end + cache.load() local contest_data = cache.get_contest_data(platform, contest_id) if @@ -215,107 +360,39 @@ function M.ensure_io_view() return end - local solution_win = state.get_solution_win() - local io_state = state.get_io_view_state() - local output_buf, input_buf, output_win, input_win + local output_buf, input_buf = get_or_create_io_buffers() - if io_state then - output_buf = io_state.output_buf - input_buf = io_state.input_buf - output_win = io_state.output_win - input_win = io_state.input_win - else - vim.api.nvim_set_current_win(solution_win) + if not buffers_are_displayed(output_buf, input_buf) then + local solution_win = state.get_solution_win() - vim.cmd.vsplit() - output_win = vim.api.nvim_get_current_win() - local cfg = config_module.get_config() - local width = math.floor(vim.o.columns * (cfg.ui.run.width or 0.3)) - vim.api.nvim_win_set_width(output_win, width) - output_buf = utils.create_buffer_with_options('cpout') - vim.api.nvim_win_set_buf(output_win, output_buf) - - vim.cmd.split() - input_win = vim.api.nvim_get_current_win() - input_buf = utils.create_buffer_with_options('cpin') - vim.api.nvim_win_set_buf(input_win, input_buf) - - state.set_io_view_state({ - output_buf = output_buf, - input_buf = input_buf, - output_win = output_win, - input_win = input_win, - current_test_index = 1, - }) - - local source_buf = vim.api.nvim_win_get_buf(solution_win) - vim.api.nvim_create_autocmd('BufDelete', { - buffer = source_buf, - callback = function() - local io = state.get_io_view_state() - if io then - if io.output_buf and vim.api.nvim_buf_is_valid(io.output_buf) then - vim.api.nvim_buf_delete(io.output_buf, { force = true }) - end - if io.input_buf and vim.api.nvim_buf_is_valid(io.input_buf) then - vim.api.nvim_buf_delete(io.input_buf, { force = true }) - end - state.set_io_view_state(nil) - end - end, - }) - - if cfg.hooks and cfg.hooks.setup_io_output then - pcall(cfg.hooks.setup_io_output, output_buf, state) - end - - if cfg.hooks and cfg.hooks.setup_io_input then - pcall(cfg.hooks.setup_io_input, input_buf, state) - end - - local function navigate_test(delta) - local io_view_state = state.get_io_view_state() - if not io_view_state then - return + for _, win in ipairs(vim.api.nvim_list_wins()) do + if win ~= solution_win then + pcall(vim.api.nvim_win_close, win, true) end - local test_cases = cache.get_test_cases(platform, contest_id, problem_id) - if not test_cases or #test_cases == 0 then - return - end - local new_index = (io_view_state.current_test_index or 1) + delta - if new_index < 1 or new_index > #test_cases then - return - end - io_view_state.current_test_index = new_index - M.run_io_view({ new_index }, false, 'individual') end - if cfg.ui.run.next_test_key then - vim.keymap.set('n', cfg.ui.run.next_test_key, function() - navigate_test(1) - end, { buffer = output_buf, silent = true, desc = 'Next test' }) - vim.keymap.set('n', cfg.ui.run.next_test_key, function() - navigate_test(1) - end, { buffer = input_buf, silent = true, desc = 'Next test' }) - end - - if cfg.ui.run.prev_test_key then - vim.keymap.set('n', cfg.ui.run.prev_test_key, function() - navigate_test(-1) - end, { buffer = output_buf, silent = true, desc = 'Previous test' }) - vim.keymap.set('n', cfg.ui.run.prev_test_key, function() - navigate_test(-1) - end, { buffer = input_buf, silent = true, desc = 'Previous test' }) - end + create_window_layout(output_buf, input_buf) end - utils.update_buffer_content(input_buf, {}) - utils.update_buffer_content(output_buf, {}) + local cfg = config_module.get_config() + + if cfg.hooks and cfg.hooks.setup_io_output then + pcall(cfg.hooks.setup_io_output, output_buf, state) + end + + if cfg.hooks and cfg.hooks.setup_io_input then + pcall(cfg.hooks.setup_io_input, input_buf, state) + end local test_cases = cache.get_test_cases(platform, contest_id, problem_id) if test_cases and #test_cases > 0 then local input_lines = {} - local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test + local is_multi_test = contest_data + and contest_data.problems + and contest_data.index_map + and contest_data.index_map[problem_id] + and contest_data.problems[contest_data.index_map[problem_id]].multi_test + or false if is_multi_test and #test_cases > 1 then table.insert(input_lines, tostring(#test_cases)) @@ -334,8 +411,6 @@ function M.ensure_io_view() end utils.update_buffer_content(input_buf, input_lines, nil, nil) end - - vim.api.nvim_set_current_win(solution_win) end function M.run_io_view(test_indices_arg, debug, mode) @@ -675,13 +750,14 @@ function M.toggle_panel(panel_opts) local io_state = state.get_io_view_state() if io_state then - if vim.api.nvim_win_is_valid(io_state.output_win) then - vim.api.nvim_win_close(io_state.output_win, true) + for _, win in ipairs(vim.api.nvim_list_wins()) do + local buf = vim.api.nvim_win_get_buf(win) + if buf == io_state.output_buf or buf == io_state.input_buf then + if vim.api.nvim_win_is_valid(win) then + vim.api.nvim_win_close(win, true) + end + end end - if vim.api.nvim_win_is_valid(io_state.input_win) then - vim.api.nvim_win_close(io_state.input_win, true) - end - state.set_io_view_state(nil) end local session_file = vim.fn.tempname() diff --git a/lua/cp/utils.lua b/lua/cp/utils.lua index e9bba54..7ef7791 100644 --- a/lua/cp/utils.lua +++ b/lua/cp/utils.lua @@ -116,7 +116,7 @@ end ---@param filetype? string function M.create_buffer_with_options(filetype) local buf = vim.api.nvim_create_buf(false, true) - vim.api.nvim_set_option_value('bufhidden', 'wipe', { buf = buf }) + vim.api.nvim_set_option_value('bufhidden', 'hide', { buf = buf }) vim.api.nvim_set_option_value('readonly', true, { buf = buf }) vim.api.nvim_set_option_value('modifiable', false, { buf = buf }) if filetype then diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 54ec6fc..e6010bd 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -315,10 +315,8 @@ class AtcoderScraper(BaseScraper): return data = await asyncio.to_thread(_scrape_problem_page_sync, category_id, slug) tests: list[TestCase] = data.get("tests", []) - - combined_input = "\n".join(t.input for t in tests) - combined_expected = "\n".join(t.expected for t in tests) - + combined_input = "\n".join(t.input for t in tests) if tests else "" + combined_expected = "\n".join(t.expected for t in tests) if tests else "" print( json.dumps( { diff --git a/scrapers/codechef.py b/scrapers/codechef.py index 59efa01..37fd9b5 100644 --- a/scrapers/codechef.py +++ b/scrapers/codechef.py @@ -231,8 +231,10 @@ class CodeChefScraper(BaseScraper): memory_mb = 256.0 interactive = False - combined_input = "\n".join(t.input for t in tests) - combined_expected = "\n".join(t.expected for t in tests) + combined_input = "\n".join(t.input for t in tests) if tests else "" + combined_expected = ( + "\n".join(t.expected for t in tests) if tests else "" + ) return { "problem_id": problem_code, diff --git a/scrapers/cses.py b/scrapers/cses.py index ea71da8..620cb7f 100644 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -235,8 +235,10 @@ class CSESScraper(BaseScraper): tests = [] timeout_ms, memory_mb, interactive = 0, 0, False - combined_input = "\n".join(t.input for t in tests) - combined_expected = "\n".join(t.expected for t in tests) + combined_input = "\n".join(t.input for t in tests) if tests else "" + combined_expected = ( + "\n".join(t.expected for t in tests) if tests else "" + ) return { "problem_id": pid, diff --git a/tests/test_scrapers.py b/tests/test_scrapers.py index 51f9ab2..75f3cb0 100644 --- a/tests/test_scrapers.py +++ b/tests/test_scrapers.py @@ -55,6 +55,7 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode): else: assert len(model.contests) >= 1 else: + assert len(objs) >= 1, "No test objects returned" validated_any = False for obj in objs: if "success" in obj and "tests" in obj and "problem_id" in obj: