diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index ba977a8..11471db 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -92,9 +92,7 @@ end ---@param platform string ---@param contest_id string ---@param problems Problem[] ----@param contest_name? string ----@param display_name? string -function M.set_contest_data(platform, contest_id, problems, contest_name, display_name) +function M.set_contest_data(platform, contest_id, problems) vim.validate({ platform = { platform, 'string' }, contest_id = { contest_id, 'string' }, @@ -102,9 +100,11 @@ function M.set_contest_data(platform, contest_id, problems, contest_name, displa }) cache_data[platform] = cache_data[platform] or {} + local prev = cache_data[platform][contest_id] or {} + local out = { - name = contest_name, - display_name = display_name, + name = prev.name, + display_name = prev.display_name, problems = vim.deepcopy(problems), index_map = {}, } @@ -151,7 +151,7 @@ function M.get_test_cases(platform, contest_id, problem_id) end local index = cache_data[platform][contest_id].index_map[problem_id] - return cache_data[platform][contest_id].problems[index].test_cases + return cache_data[platform][contest_id].problems[index].test_cases or {} end ---@param platform string diff --git a/lua/cp/log.lua b/lua/cp/log.lua index 9c702b4..02bc5f4 100644 --- a/lua/cp/log.lua +++ b/lua/cp/log.lua @@ -1,8 +1,9 @@ local M = {} function M.log(msg, level, override) + local debug = require('cp.config').get_config().debug or false level = level or vim.log.levels.INFO - if level >= vim.log.levels.WARN or override then + if level >= vim.log.levels.WARN or override or debug then vim.schedule(function() vim.notify(('[cp.nvim]: %s'):format(msg), level) end) diff --git a/lua/cp/pickers/init.lua b/lua/cp/pickers/init.lua index 143bb73..2e7598b 100644 --- a/lua/cp/pickers/init.lua +++ b/lua/cp/pickers/init.lua @@ -40,27 +40,29 @@ end ---@param refresh? boolean ---@return cp.ContestItem[] function M.get_platform_contests(platform, refresh) - logger.log( - ('Loading %s contests...'):format(constants.PLATFORM_DISPLAY_NAMES[platform]), - vim.log.levels.INFO, - true - ) - cache.load() local picker_contests = cache.get_contest_summaries(platform) if refresh or vim.tbl_isempty(picker_contests) then - logger.log(('Cache miss on %s contests'):format(platform)) - local contests = scraper.scrape_contest_list(platform) -- sync - cache.set_contest_summaries(platform, contests) - picker_contests = cache.get_contest_summaries(platform) -- <-- reload after write - end + logger.log( + ('Loading %s contests...'):format(constants.PLATFORM_DISPLAY_NAMES[platform]), + vim.log.levels.INFO, + true + ) - logger.log( - ('Loaded %d %s contests.'):format(#picker_contests, constants.PLATFORM_DISPLAY_NAMES[platform]), - vim.log.levels.INFO, - true - ) + local contests = scraper.scrape_contest_list(platform) + cache.set_contest_summaries(platform, contests) + picker_contests = cache.get_contest_summaries(platform) + + logger.log( + ('Loaded %d %s contests.'):format( + #picker_contests, + constants.PLATFORM_DISPLAY_NAMES[platform] + ), + vim.log.levels.INFO, + true + ) + end return picker_contests end diff --git a/lua/cp/runner/execute.lua b/lua/cp/runner/execute.lua index d400d1e..9a70343 100644 --- a/lua/cp/runner/execute.lua +++ b/lua/cp/runner/execute.lua @@ -98,7 +98,6 @@ local function parse_and_strip_time_v(output) end local peak_mb = peak_kb / 1024.0 - head = head:gsub('\n+$', '') return head, peak_mb end diff --git a/lua/cp/runner/run.lua b/lua/cp/runner/run.lua index b4454b3..dbfb52b 100644 --- a/lua/cp/runner/run.lua +++ b/lua/cp/runner/run.lua @@ -58,6 +58,22 @@ local function load_constraints_from_cache(platform, contest_id, problem_id) return nil end +--- Normalize raw problem output to a "canonical" version +--- Usually, most contests ignore leading/trailing whitespace and empty lines +---@param lines string +local function normalize_lines(lines) + local normalized = {} + for _, line in + ipairs(vim.tbl_values(vim.split(((lines or ''):gsub('\r', '')), '\n', { plain = true }))) + do + local trimmed_line = vim.trim(line) + if trimmed_line ~= '' then + table.insert(normalized, trimmed_line) + end + end + return table.concat(normalized, '\n') +end + ---@param test_cases TestCase[] ---@return RanTestCase[] local function create_sentinal_panel_data(test_cases) @@ -106,8 +122,7 @@ local function run_single_test_case(contest_config, cp_config, test_case) local r = exec.run(cmd, stdin_content, timeout_ms, memory_mb) local ansi = require('cp.ui.ansi') - local out = (r.stdout or ''):gsub('\n$', '') - + local out = r.stdout or '' local highlights = {} if out ~= '' then if cp_config.run_panel.ansi then @@ -130,8 +145,8 @@ local function run_single_test_case(contest_config, cp_config, test_case) out = table.concat(trimmed, '\n') end - local expected = (test_case.expected or ''):gsub('\n$', '') - local ok = out == expected + local expected = test_case.expected or '' + local ok = normalize_lines(out) == normalize_lines(expected) local signal = r.signal if not signal and r.code and r.code >= 128 then diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index c821df2..0ee7725 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -63,10 +63,11 @@ function M.setup_contest(platform, contest_id, language, problem_id) M.setup_problem(pid, language) local cached_len = #vim.tbl_filter(function(p) - return cache.get_test_cases(platform, contest_id, p.id) ~= nil + return not vim.tbl_isempty(cache.get_test_cases(platform, contest_id, p.id)) end, problems) if cached_len ~= #problems then + logger.log(('Found %s problems, expected %s; re-fetching'):format(cached_len, #problems)) scraper.scrape_all_tests(platform, contest_id, function(ev) local cached_tests = {} for i, t in ipairs(ev.tests) do @@ -89,7 +90,7 @@ function M.setup_contest(platform, contest_id, language, problem_id) logger.log('Fetching contests problems...', vim.log.levels.INFO, true) scraper.scrape_contest_metadata(platform, contest_id, function(result) local problems = result.problems or {} - cache.set_contest_data(platform, contest_id, problems, result.name, result.display_name) + cache.set_contest_data(platform, contest_id, problems) logger.log(('Found %d problems for %s contest %s.'):format(#problems, platform, contest_id)) proceed(cache.get_contest_data(platform, contest_id)) end) diff --git a/lua/cp/ui/ansi.lua b/lua/cp/ui/ansi.lua index 0d31bdd..facfd1f 100644 --- a/lua/cp/ui/ansi.lua +++ b/lua/cp/ui/ansi.lua @@ -25,7 +25,7 @@ end ---@return AnsiParseResult function M.parse_ansi_text(text) local clean_text = text:gsub('\027%[[%d;]*[a-zA-Z]', '') - local lines = vim.split(clean_text, '\n', { plain = true, trimempty = false }) + local lines = vim.split(clean_text, '\n', { plain = true }) local highlights = {} local line_num = 0 diff --git a/lua/cp/ui/diff.lua b/lua/cp/ui/diff.lua index c3fe7cd..9f4604f 100644 --- a/lua/cp/ui/diff.lua +++ b/lua/cp/ui/diff.lua @@ -13,7 +13,7 @@ local M = {} local vim_backend = { name = 'vim', render = function(_, actual) - local actual_lines = vim.split(actual, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual, '\n', { plain = true }) return { content = actual_lines, @@ -27,7 +27,7 @@ local none_backend = { name = 'none', render = function(expected, actual) local expected_lines = vim.split(expected, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual, '\n', { plain = true }) return { content = { expected = expected_lines, actual = actual_lines }, @@ -64,7 +64,7 @@ local git_backend = { if result.code == 0 then return { - content = vim.split(actual, '\n', { plain = true, trimempty = true }), + content = vim.split(actual, '\n', { plain = true }), highlights = {}, } else diff --git a/lua/cp/ui/layouts.lua b/lua/cp/ui/layouts.lua index 84d667a..3d12a21 100644 --- a/lua/cp/ui/layouts.lua +++ b/lua/cp/ui/layouts.lua @@ -22,7 +22,7 @@ local function create_none_diff_layout(parent_win, expected_content, actual_cont vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual_content, '\n', { plain = true }) utils.update_buffer_content(expected_buf, expected_lines, {}) utils.update_buffer_content(actual_buf, actual_lines, {}) @@ -59,7 +59,7 @@ local function create_vim_diff_layout(parent_win, expected_content, actual_conte vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual_content, '\n', { plain = true }) utils.update_buffer_content(expected_buf, expected_lines, {}) utils.update_buffer_content(actual_buf, actual_lines, {}) @@ -108,7 +108,7 @@ local function create_git_diff_layout(parent_win, expected_content, actual_conte if diff_result.raw_diff and diff_result.raw_diff ~= '' then highlight.parse_and_apply_diff(diff_buf, diff_result.raw_diff, diff_namespace) else - local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + local lines = vim.split(actual_content, '\n', { plain = true }) utils.update_buffer_content(diff_buf, lines, {}) end @@ -124,7 +124,7 @@ end local function create_single_layout(parent_win, content) local buf = utils.create_buffer_with_options() - local lines = vim.split(content, '\n', { plain = true, trimempty = true }) + local lines = vim.split(content, '\n', { plain = true }) utils.update_buffer_content(buf, lines, {}) vim.api.nvim_set_current_win(parent_win) @@ -218,7 +218,7 @@ function M.update_diff_panes( end else if desired_mode == 'single' then - local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + local lines = vim.split(actual_content, '\n', { plain = true }) utils.update_buffer_content( current_diff_layout.buffers[1], lines, @@ -237,7 +237,7 @@ function M.update_diff_panes( diff_namespace ) else - local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + local lines = vim.split(actual_content, '\n', { plain = true }) utils.update_buffer_content( current_diff_layout.buffers[1], lines, @@ -247,7 +247,7 @@ function M.update_diff_panes( end elseif desired_mode == 'none' then local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual_content, '\n', { plain = true }) utils.update_buffer_content(current_diff_layout.buffers[1], expected_lines, {}) utils.update_buffer_content( current_diff_layout.buffers[2], @@ -257,7 +257,7 @@ function M.update_diff_panes( ) else local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) - local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true }) + local actual_lines = vim.split(actual_content, '\n', { plain = true }) utils.update_buffer_content(current_diff_layout.buffers[1], expected_lines, {}) utils.update_buffer_content( current_diff_layout.buffers[2], diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 4ad8b99..2aab23c 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -197,9 +197,9 @@ def _extract_samples(html: str) -> list[TestCase]: mi = re.search(r"Sample\s*Input\s*(\d+)", title, flags=re.I) mo = re.search(r"Sample\s*Output\s*(\d+)", title, flags=re.I) if mi: - inputs[mi.group(1)] = t + inputs[mi.group(1)] = t.strip() elif mo: - outputs[mo.group(1)] = t + outputs[mo.group(1)] = t.strip() cases: list[TestCase] = [] for k in sorted(set(inputs) & set(outputs), key=lambda s: int(s)): cases.append(TestCase(input=inputs[k], expected=outputs[k])) diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index d76168d..47c08c9 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -39,7 +39,7 @@ def _text_from_pre(pre: Tag) -> str: pre.get_text(separator="\n", strip=False) .replace("\r", "") .replace("\xa0", " ") - .rstrip("\n") + .strip() ) @@ -61,6 +61,20 @@ def _extract_limits(block: Tag) -> tuple[int, float]: return timeout_ms, memory_mb +def _group_lines_by_id(pre: Tag) -> dict[int, list[str]]: + groups: dict[int, list[str]] = {} + if not isinstance(pre, Tag): + return groups + for div in pre.find_all("div", class_="test-example-line"): + cls = " ".join(div.get("class", [])) + m = re.search(r"\btest-example-line-(\d+)\b", cls) + if not m: + continue + gid = int(m.group(1)) + groups.setdefault(gid, []).append(div.get_text("", strip=False)) + return groups + + def _extract_title(block: Tag) -> tuple[str, str]: t = block.find("div", class_="title") if not t: @@ -77,19 +91,47 @@ def _extract_samples(block: Tag) -> list[TestCase]: if not st: return [] - inputs = [ - _text_from_pre(pre) + input_pres: list[Tag] = [ # type: ignore[misc] + inp.find("pre") # type: ignore[misc] for inp in st.find_all("div", class_="input") # type: ignore[union-attr] - for pre in [inp.find("pre")] - if isinstance(pre, Tag) + if isinstance(inp, Tag) and inp.find("pre") ] - outputs = [ - _text_from_pre(pre) + output_pres: list[Tag] = [ + out.find("pre") # type: ignore[misc] for out in st.find_all("div", class_="output") # type: ignore[union-attr] - for pre in [out.find("pre")] - if isinstance(pre, Tag) + if isinstance(out, Tag) and out.find("pre") ] + input_pres = [p for p in input_pres if isinstance(p, Tag)] + output_pres = [p for p in output_pres if isinstance(p, Tag)] + has_grouped = any( + p.find("div", class_="test-example-line") for p in input_pres + output_pres + ) + if has_grouped: + inputs_by_gid: dict[int, list[str]] = {} + outputs_by_gid: dict[int, list[str]] = {} + for p in input_pres: + g = _group_lines_by_id(p) + for k, v in g.items(): + inputs_by_gid.setdefault(k, []).extend(v) + for p in output_pres: + g = _group_lines_by_id(p) + for k, v in g.items(): + outputs_by_gid.setdefault(k, []).extend(v) + inputs_by_gid.pop(0, None) + outputs_by_gid.pop(0, None) + keys = sorted(set(inputs_by_gid.keys()) & set(outputs_by_gid.keys())) + if keys: + return [ + TestCase( + input="\n".join(inputs_by_gid[k]).strip(), + expected="\n".join(outputs_by_gid[k]).strip(), + ) + for k in keys + ] + + inputs = [_text_from_pre(p) for p in input_pres] + outputs = [_text_from_pre(p) for p in output_pres] n = min(len(inputs), len(outputs)) return [TestCase(input=inputs[i], expected=outputs[i]) for i in range(n)]