Merge pull request #141 from barrett-ruth/feat/caching

misc cachign fixes
This commit is contained in:
Barrett Ruth 2025-10-04 21:02:09 +02:00 committed by GitHub
commit 018f61af92
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 113 additions and 53 deletions

View file

@ -92,9 +92,7 @@ end
---@param platform string
---@param contest_id string
---@param problems Problem[]
---@param contest_name? string
---@param display_name? string
function M.set_contest_data(platform, contest_id, problems, contest_name, display_name)
function M.set_contest_data(platform, contest_id, problems)
vim.validate({
platform = { platform, 'string' },
contest_id = { contest_id, 'string' },
@ -102,9 +100,11 @@ function M.set_contest_data(platform, contest_id, problems, contest_name, displa
})
cache_data[platform] = cache_data[platform] or {}
local prev = cache_data[platform][contest_id] or {}
local out = {
name = contest_name,
display_name = display_name,
name = prev.name,
display_name = prev.display_name,
problems = vim.deepcopy(problems),
index_map = {},
}
@ -151,7 +151,7 @@ function M.get_test_cases(platform, contest_id, problem_id)
end
local index = cache_data[platform][contest_id].index_map[problem_id]
return cache_data[platform][contest_id].problems[index].test_cases
return cache_data[platform][contest_id].problems[index].test_cases or {}
end
---@param platform string

View file

@ -1,8 +1,9 @@
local M = {}
function M.log(msg, level, override)
local debug = require('cp.config').get_config().debug or false
level = level or vim.log.levels.INFO
if level >= vim.log.levels.WARN or override then
if level >= vim.log.levels.WARN or override or debug then
vim.schedule(function()
vim.notify(('[cp.nvim]: %s'):format(msg), level)
end)

View file

@ -40,27 +40,29 @@ end
---@param refresh? boolean
---@return cp.ContestItem[]
function M.get_platform_contests(platform, refresh)
logger.log(
('Loading %s contests...'):format(constants.PLATFORM_DISPLAY_NAMES[platform]),
vim.log.levels.INFO,
true
)
cache.load()
local picker_contests = cache.get_contest_summaries(platform)
if refresh or vim.tbl_isempty(picker_contests) then
logger.log(('Cache miss on %s contests'):format(platform))
local contests = scraper.scrape_contest_list(platform) -- sync
cache.set_contest_summaries(platform, contests)
picker_contests = cache.get_contest_summaries(platform) -- <-- reload after write
end
logger.log(
('Loading %s contests...'):format(constants.PLATFORM_DISPLAY_NAMES[platform]),
vim.log.levels.INFO,
true
)
logger.log(
('Loaded %d %s contests.'):format(#picker_contests, constants.PLATFORM_DISPLAY_NAMES[platform]),
vim.log.levels.INFO,
true
)
local contests = scraper.scrape_contest_list(platform)
cache.set_contest_summaries(platform, contests)
picker_contests = cache.get_contest_summaries(platform)
logger.log(
('Loaded %d %s contests.'):format(
#picker_contests,
constants.PLATFORM_DISPLAY_NAMES[platform]
),
vim.log.levels.INFO,
true
)
end
return picker_contests
end

View file

@ -98,7 +98,6 @@ local function parse_and_strip_time_v(output)
end
local peak_mb = peak_kb / 1024.0
head = head:gsub('\n+$', '')
return head, peak_mb
end

View file

@ -58,6 +58,22 @@ local function load_constraints_from_cache(platform, contest_id, problem_id)
return nil
end
--- Normalize raw problem output to a "canonical" version
--- Usually, most contests ignore leading/trailing whitespace and empty lines
---@param lines string
local function normalize_lines(lines)
local normalized = {}
for _, line in
ipairs(vim.tbl_values(vim.split(((lines or ''):gsub('\r', '')), '\n', { plain = true })))
do
local trimmed_line = vim.trim(line)
if trimmed_line ~= '' then
table.insert(normalized, trimmed_line)
end
end
return table.concat(normalized, '\n')
end
---@param test_cases TestCase[]
---@return RanTestCase[]
local function create_sentinal_panel_data(test_cases)
@ -106,8 +122,7 @@ local function run_single_test_case(contest_config, cp_config, test_case)
local r = exec.run(cmd, stdin_content, timeout_ms, memory_mb)
local ansi = require('cp.ui.ansi')
local out = (r.stdout or ''):gsub('\n$', '')
local out = r.stdout or ''
local highlights = {}
if out ~= '' then
if cp_config.run_panel.ansi then
@ -130,8 +145,8 @@ local function run_single_test_case(contest_config, cp_config, test_case)
out = table.concat(trimmed, '\n')
end
local expected = (test_case.expected or ''):gsub('\n$', '')
local ok = out == expected
local expected = test_case.expected or ''
local ok = normalize_lines(out) == normalize_lines(expected)
local signal = r.signal
if not signal and r.code and r.code >= 128 then

View file

@ -63,10 +63,11 @@ function M.setup_contest(platform, contest_id, language, problem_id)
M.setup_problem(pid, language)
local cached_len = #vim.tbl_filter(function(p)
return cache.get_test_cases(platform, contest_id, p.id) ~= nil
return not vim.tbl_isempty(cache.get_test_cases(platform, contest_id, p.id))
end, problems)
if cached_len ~= #problems then
logger.log(('Found %s problems, expected %s; re-fetching'):format(cached_len, #problems))
scraper.scrape_all_tests(platform, contest_id, function(ev)
local cached_tests = {}
for i, t in ipairs(ev.tests) do
@ -89,7 +90,7 @@ function M.setup_contest(platform, contest_id, language, problem_id)
logger.log('Fetching contests problems...', vim.log.levels.INFO, true)
scraper.scrape_contest_metadata(platform, contest_id, function(result)
local problems = result.problems or {}
cache.set_contest_data(platform, contest_id, problems, result.name, result.display_name)
cache.set_contest_data(platform, contest_id, problems)
logger.log(('Found %d problems for %s contest %s.'):format(#problems, platform, contest_id))
proceed(cache.get_contest_data(platform, contest_id))
end)

View file

@ -25,7 +25,7 @@ end
---@return AnsiParseResult
function M.parse_ansi_text(text)
local clean_text = text:gsub('\027%[[%d;]*[a-zA-Z]', '')
local lines = vim.split(clean_text, '\n', { plain = true, trimempty = false })
local lines = vim.split(clean_text, '\n', { plain = true })
local highlights = {}
local line_num = 0

View file

@ -13,7 +13,7 @@ local M = {}
local vim_backend = {
name = 'vim',
render = function(_, actual)
local actual_lines = vim.split(actual, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual, '\n', { plain = true })
return {
content = actual_lines,
@ -27,7 +27,7 @@ local none_backend = {
name = 'none',
render = function(expected, actual)
local expected_lines = vim.split(expected, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual, '\n', { plain = true })
return {
content = { expected = expected_lines, actual = actual_lines },
@ -64,7 +64,7 @@ local git_backend = {
if result.code == 0 then
return {
content = vim.split(actual, '\n', { plain = true, trimempty = true }),
content = vim.split(actual, '\n', { plain = true }),
highlights = {},
}
else

View file

@ -22,7 +22,7 @@ local function create_none_diff_layout(parent_win, expected_content, actual_cont
vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win })
local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual_content, '\n', { plain = true })
utils.update_buffer_content(expected_buf, expected_lines, {})
utils.update_buffer_content(actual_buf, actual_lines, {})
@ -59,7 +59,7 @@ local function create_vim_diff_layout(parent_win, expected_content, actual_conte
vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win })
local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual_content, '\n', { plain = true })
utils.update_buffer_content(expected_buf, expected_lines, {})
utils.update_buffer_content(actual_buf, actual_lines, {})
@ -108,7 +108,7 @@ local function create_git_diff_layout(parent_win, expected_content, actual_conte
if diff_result.raw_diff and diff_result.raw_diff ~= '' then
highlight.parse_and_apply_diff(diff_buf, diff_result.raw_diff, diff_namespace)
else
local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true })
local lines = vim.split(actual_content, '\n', { plain = true })
utils.update_buffer_content(diff_buf, lines, {})
end
@ -124,7 +124,7 @@ end
local function create_single_layout(parent_win, content)
local buf = utils.create_buffer_with_options()
local lines = vim.split(content, '\n', { plain = true, trimempty = true })
local lines = vim.split(content, '\n', { plain = true })
utils.update_buffer_content(buf, lines, {})
vim.api.nvim_set_current_win(parent_win)
@ -218,7 +218,7 @@ function M.update_diff_panes(
end
else
if desired_mode == 'single' then
local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true })
local lines = vim.split(actual_content, '\n', { plain = true })
utils.update_buffer_content(
current_diff_layout.buffers[1],
lines,
@ -237,7 +237,7 @@ function M.update_diff_panes(
diff_namespace
)
else
local lines = vim.split(actual_content, '\n', { plain = true, trimempty = true })
local lines = vim.split(actual_content, '\n', { plain = true })
utils.update_buffer_content(
current_diff_layout.buffers[1],
lines,
@ -247,7 +247,7 @@ function M.update_diff_panes(
end
elseif desired_mode == 'none' then
local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual_content, '\n', { plain = true })
utils.update_buffer_content(current_diff_layout.buffers[1], expected_lines, {})
utils.update_buffer_content(
current_diff_layout.buffers[2],
@ -257,7 +257,7 @@ function M.update_diff_panes(
)
else
local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual_content, '\n', { plain = true, trimempty = true })
local actual_lines = vim.split(actual_content, '\n', { plain = true })
utils.update_buffer_content(current_diff_layout.buffers[1], expected_lines, {})
utils.update_buffer_content(
current_diff_layout.buffers[2],

View file

@ -197,9 +197,9 @@ def _extract_samples(html: str) -> list[TestCase]:
mi = re.search(r"Sample\s*Input\s*(\d+)", title, flags=re.I)
mo = re.search(r"Sample\s*Output\s*(\d+)", title, flags=re.I)
if mi:
inputs[mi.group(1)] = t
inputs[mi.group(1)] = t.strip()
elif mo:
outputs[mo.group(1)] = t
outputs[mo.group(1)] = t.strip()
cases: list[TestCase] = []
for k in sorted(set(inputs) & set(outputs), key=lambda s: int(s)):
cases.append(TestCase(input=inputs[k], expected=outputs[k]))

View file

@ -39,7 +39,7 @@ def _text_from_pre(pre: Tag) -> str:
pre.get_text(separator="\n", strip=False)
.replace("\r", "")
.replace("\xa0", " ")
.rstrip("\n")
.strip()
)
@ -61,6 +61,20 @@ def _extract_limits(block: Tag) -> tuple[int, float]:
return timeout_ms, memory_mb
def _group_lines_by_id(pre: Tag) -> dict[int, list[str]]:
groups: dict[int, list[str]] = {}
if not isinstance(pre, Tag):
return groups
for div in pre.find_all("div", class_="test-example-line"):
cls = " ".join(div.get("class", []))
m = re.search(r"\btest-example-line-(\d+)\b", cls)
if not m:
continue
gid = int(m.group(1))
groups.setdefault(gid, []).append(div.get_text("", strip=False))
return groups
def _extract_title(block: Tag) -> tuple[str, str]:
t = block.find("div", class_="title")
if not t:
@ -77,19 +91,47 @@ def _extract_samples(block: Tag) -> list[TestCase]:
if not st:
return []
inputs = [
_text_from_pre(pre)
input_pres: list[Tag] = [ # type: ignore[misc]
inp.find("pre") # type: ignore[misc]
for inp in st.find_all("div", class_="input") # type: ignore[union-attr]
for pre in [inp.find("pre")]
if isinstance(pre, Tag)
if isinstance(inp, Tag) and inp.find("pre")
]
outputs = [
_text_from_pre(pre)
output_pres: list[Tag] = [
out.find("pre") # type: ignore[misc]
for out in st.find_all("div", class_="output") # type: ignore[union-attr]
for pre in [out.find("pre")]
if isinstance(pre, Tag)
if isinstance(out, Tag) and out.find("pre")
]
input_pres = [p for p in input_pres if isinstance(p, Tag)]
output_pres = [p for p in output_pres if isinstance(p, Tag)]
has_grouped = any(
p.find("div", class_="test-example-line") for p in input_pres + output_pres
)
if has_grouped:
inputs_by_gid: dict[int, list[str]] = {}
outputs_by_gid: dict[int, list[str]] = {}
for p in input_pres:
g = _group_lines_by_id(p)
for k, v in g.items():
inputs_by_gid.setdefault(k, []).extend(v)
for p in output_pres:
g = _group_lines_by_id(p)
for k, v in g.items():
outputs_by_gid.setdefault(k, []).extend(v)
inputs_by_gid.pop(0, None)
outputs_by_gid.pop(0, None)
keys = sorted(set(inputs_by_gid.keys()) & set(outputs_by_gid.keys()))
if keys:
return [
TestCase(
input="\n".join(inputs_by_gid[k]).strip(),
expected="\n".join(outputs_by_gid[k]).strip(),
)
for k in keys
]
inputs = [_text_from_pre(p) for p in input_pres]
outputs = [_text_from_pre(p) for p in output_pres]
n = min(len(inputs), len(outputs))
return [TestCase(input=inputs[i], expected=outputs[i]) for i in range(n)]