diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 696d8ae..ba977a8 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -79,29 +79,22 @@ end ---@param platform string ---@param contest_id string ----@return ContestData? +---@return ContestData function M.get_contest_data(platform, contest_id) vim.validate({ platform = { platform, 'string' }, contest_id = { contest_id, 'string' }, }) - if not cache_data[platform] then - return nil - end - - local contest_data = cache_data[platform][contest_id] - if not contest_data or vim.tbl_isempty(contest_data) then - return nil - end - - return contest_data + return cache_data[platform][contest_id] or {} end ---@param platform string ---@param contest_id string ---@param problems Problem[] -function M.set_contest_data(platform, contest_id, problems) +---@param contest_name? string +---@param display_name? string +function M.set_contest_data(platform, contest_id, problems, contest_name, display_name) vim.validate({ platform = { platform, 'string' }, contest_id = { contest_id, 'string' }, @@ -109,36 +102,17 @@ function M.set_contest_data(platform, contest_id, problems) }) cache_data[platform] = cache_data[platform] or {} - local existing = cache_data[platform][contest_id] or {} - - local existing_by_id = {} - if existing.problems then - for _, p in ipairs(existing.problems) do - existing_by_id[p.id] = p - end + local out = { + name = contest_name, + display_name = display_name, + problems = vim.deepcopy(problems), + index_map = {}, + } + for i, p in ipairs(out.problems) do + out.index_map[p.id] = i end - local merged = {} - for _, p in ipairs(problems) do - local prev = existing_by_id[p.id] or {} - local merged_p = { - id = p.id, - name = p.name or prev.name, - test_cases = prev.test_cases, - timeout_ms = prev.timeout_ms, - memory_mb = prev.memory_mb, - interactive = prev.interactive, - } - table.insert(merged, merged_p) - end - - existing.problems = merged - existing.index_map = {} - for i, p in ipairs(merged) do - existing.index_map[p.id] = i - end - - cache_data[platform][contest_id] = existing + cache_data[platform][contest_id] = out M.save() end diff --git a/lua/cp/scraper.lua b/lua/cp/scraper.lua index 2a2f168..ce414d9 100644 --- a/lua/cp/scraper.lua +++ b/lua/cp/scraper.lua @@ -1,67 +1,109 @@ local M = {} -local utils = require('cp.utils') - local logger = require('cp.log') +local utils = require('cp.utils') local function syshandle(result) if result.code ~= 0 then local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error') logger.log(msg, vim.log.levels.ERROR) - return { - success = false, - error = msg, - } + return { success = false, error = msg } end local ok, data = pcall(vim.json.decode, result.stdout) if not ok then local msg = 'Failed to parse scraper output: ' .. tostring(data) logger.log(msg, vim.log.levels.ERROR) - return { - success = false, - error = msg, - } + return { success = false, error = msg } end - return { - success = true, - data = data, - } + return { success = true, data = data } end +---@param platform string +---@param subcommand string +---@param args string[] +---@param opts { sync?: boolean, ndjson?: boolean, on_event?: fun(ev: table), on_exit?: fun(result: table) } local function run_scraper(platform, subcommand, args, opts) - if not utils.setup_python_env() then - local msg = 'Python environment setup failed' - logger.log(msg, vim.log.levels.ERROR) - return { - success = false, - message = msg, - } - end - local plugin_path = utils.get_plugin_path() - local cmd = { - 'uv', - 'run', - '--directory', - plugin_path, - '-m', - 'scrapers.' .. platform, - subcommand, - } + local cmd = { 'uv', 'run', '--directory', plugin_path, '-m', 'scrapers.' .. platform, subcommand } vim.list_extend(cmd, args) - local sysopts = { - text = true, - timeout = 30000, - } + if opts and opts.ndjson then + local uv = vim.loop + local stdout = uv.new_pipe(false) + local stderr = uv.new_pipe(false) + local buf = '' - if opts.sync then + local handle = uv.spawn( + cmd[1], + { args = vim.list_slice(cmd, 2), stdio = { nil, stdout, stderr } }, + function(code, signal) + if buf ~= '' and opts.on_event then + local ok_tail, ev_tail = pcall(vim.json.decode, buf) + if ok_tail then + opts.on_event(ev_tail) + end + buf = '' + end + if opts.on_exit then + opts.on_exit({ success = (code == 0), code = code, signal = signal }) + end + if not stdout:is_closing() then + stdout:close() + end + if not stderr:is_closing() then + stderr:close() + end + if handle and not handle:is_closing() then + handle:close() + end + end + ) + + if not handle then + logger.log('Failed to start scraper process', vim.log.levels.ERROR) + return { success = false, error = 'spawn failed' } + end + + uv.read_start(stdout, function(_, data) + if data == nil then + if buf ~= '' and opts.on_event then + local ok_tail, ev_tail = pcall(vim.json.decode, buf) + if ok_tail then + opts.on_event(ev_tail) + end + buf = '' + end + return + end + buf = buf .. data + while true do + local s, e = buf:find('\n', 1, true) + if not s then + break + end + local line = buf:sub(1, s - 1) + buf = buf:sub(e + 1) + local ok, ev = pcall(vim.json.decode, line) + if ok and opts.on_event then + opts.on_event(ev) + end + end + end) + + uv.read_start(stderr, function(_, _) end) + return + end + + local sysopts = { text = true, timeout = 30000 } + if opts and opts.sync then local result = vim.system(cmd, sysopts):wait() return syshandle(result) else vim.system(cmd, sysopts, function(result) - return opts.on_exit(syshandle(result)) + if opts and opts.on_exit then + return opts.on_exit(syshandle(result)) + end end) end end @@ -93,41 +135,48 @@ end function M.scrape_contest_list(platform) local result = run_scraper(platform, 'contests', {}, { sync = true }) - if not result.success or not result.data.contests then + if not result or not result.success or not (result.data and result.data.contests) then logger.log( - ('Could not scrape contests list for platform %s: %s'):format(platform, result.msg), + ('Could not scrape contests list for platform %s: %s'):format( + platform, + (result and result.error) or 'unknown' + ), vim.log.levels.ERROR ) return {} end - return result.data.contests end -function M.scrape_problem_tests(platform, contest_id, problem_id, callback) - run_scraper(platform, 'tests', { contest_id, problem_id }, { - on_exit = function(result) - if not result.success or not result.data.tests then - logger.log( - 'Failed to load tests: ' .. (result.msg or 'unknown error'), - vim.log.levels.ERROR - ) - - return {} +---@param platform string +---@param contest_id string +---@param callback fun(data: table)|nil +function M.scrape_all_tests(platform, contest_id, callback) + run_scraper(platform, 'tests', { contest_id }, { + ndjson = true, + on_event = function(ev) + if ev.done then + return + end + if ev.error and ev.problem_id then + logger.log( + ('Failed to load tests for %s/%s: %s'):format(contest_id, ev.problem_id, ev.error), + vim.log.levels.WARN + ) + return + end + if not ev.problem_id or not ev.tests then + return end - vim.schedule(function() vim.system({ 'mkdir', '-p', 'build', 'io' }):wait() local config = require('cp.config') - local base_name = config.default_filename(contest_id, problem_id) - - for i, test_case in ipairs(result.data.tests) do + local base_name = config.default_filename(contest_id, ev.problem_id) + for i, t in ipairs(ev.tests) do local input_file = 'io/' .. base_name .. '.' .. i .. '.cpin' local expected_file = 'io/' .. base_name .. '.' .. i .. '.cpout' - - local input_content = test_case.input:gsub('\r', '') - local expected_content = test_case.expected:gsub('\r', '') - + local input_content = t.input:gsub('\r', '') + local expected_content = t.expected:gsub('\r', '') pcall(vim.fn.writefile, vim.split(input_content, '\n', { trimempty = true }), input_file) pcall( vim.fn.writefile, @@ -136,7 +185,13 @@ function M.scrape_problem_tests(platform, contest_id, problem_id, callback) ) end if type(callback) == 'function' then - callback(result.data) + callback({ + tests = ev.tests, + timeout_ms = ev.timeout_ms or 0, + memory_mb = ev.memory_mb or 0, + interactive = ev.interactive or false, + problem_id = ev.problem_id, + }) end end) end, diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index 4d0c402..bfcd329 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -28,45 +28,26 @@ function M.set_platform(platform) return true end -local function backfill_missing_tests(platform, contest_id, problems) - cache.load() - local missing = {} - for _, prob in ipairs(problems) do - if not cache.get_test_cases(platform, contest_id, prob.id) then - table.insert(missing, prob.id) - end - end - if #missing == 0 then - logger.log(('All problems already cached for %s contest %s.'):format(platform, contest_id)) - return - end - for _, pid in ipairs(missing) do - local captured = pid - scraper.scrape_problem_tests(platform, contest_id, captured, function(result) - local cached_tests = {} - if result.tests then - for i, t in ipairs(result.tests) do - cached_tests[i] = { index = i, input = t.input, expected = t.expected } - end - end - cache.set_test_cases( - platform, - contest_id, - captured, - cached_tests, - result.timeout_ms, - result.memory_mb - ) - end) - end -end +---@class TestCaseLite +---@field input string +---@field expected string +---@class ScrapeEvent +---@field problem_id string +---@field tests TestCaseLite[]|nil +---@field timeout_ms integer|nil +---@field memory_mb integer|nil +---@field interactive boolean|nil +---@field error string|nil +---@field done boolean|nil +---@field succeeded integer|nil +---@field failed integer|nil + +---@param platform string +---@param contest_id string +---@param language string|nil +---@param problem_id string|nil function M.setup_contest(platform, contest_id, language, problem_id) - if not platform then - logger.log('No platform configured. Use :CP [--{lang=,debug} first.') - return - end - local config = config_module.get_config() if not vim.tbl_contains(config.scrapers, platform) then logger.log(('Scraping disabled for %s.'):format(platform), vim.log.levels.WARN) @@ -75,27 +56,70 @@ function M.setup_contest(platform, contest_id, language, problem_id) state.set_contest_id(contest_id) cache.load() - local contest_data = cache.get_contest_data(platform, contest_id) + local contest_data = cache.get_contest_data(platform, contest_id) if not contest_data or not contest_data.problems then logger.log('Fetching contests problems...', vim.log.levels.INFO, true) scraper.scrape_contest_metadata(platform, contest_id, function(result) local problems = result.problems or {} - cache.set_contest_data(platform, contest_id, problems) + cache.set_contest_data(platform, contest_id, problems, result.name, result.display_name) logger.log(('Found %d problems for %s contest %s.'):format(#problems, platform, contest_id)) - local pid = problem_id or (problems[1] and problems[1].id) - if pid then - M.setup_problem(pid, language) - end - backfill_missing_tests(platform, contest_id, problems) - end) - else - local problems = contest_data.problems - local pid = problem_id or (problems[1] and problems[1].id) - if pid then + + contest_data = cache.get_contest_data(platform, contest_id) + local pid = contest_data.problems[problem_id and contest_data.index_map[problem_id] or 1].id M.setup_problem(pid, language) - end - backfill_missing_tests(platform, contest_id, problems) + + local cached_len = #vim.tbl_filter(function(p) + return cache.get_test_cases(platform, contest_id, p.id) ~= nil + end, problems) + if cached_len < #problems then + scraper.scrape_all_tests(platform, contest_id, function(ev) + if not ev or not ev.tests or not ev.problem_id then + return + end + local cached_tests = {} + for i, t in ipairs(ev.tests) do + cached_tests[i] = { index = i, input = t.input, expected = t.expected } + end + cache.set_test_cases( + platform, + contest_id, + ev.problem_id, + cached_tests, + ev.timeout_ms or 0, + ev.memory_mb or 0 + ) + end) + end + end) + + return + end + + local problems = contest_data.problems + local pid = problems[(problem_id and contest_data.index_map[problem_id] or 1)].id + M.setup_problem(pid, language) + local cached_len = #vim.tbl_filter(function(p) + return cache.get_test_cases(platform, contest_id, p.id) ~= nil + end, problems) + if cached_len < #problems then + scraper.scrape_all_tests(platform, contest_id, function(ev) + if not ev or not ev.tests or not ev.problem_id then + return + end + local cached_tests = {} + for i, t in ipairs(ev.tests) do + cached_tests[i] = { index = i, input = t.input, expected = t.expected } + end + cache.set_test_cases( + platform, + contest_id, + ev.problem_id, + cached_tests, + ev.timeout_ms or 0, + ev.memory_mb or 0 + ) + end) end end @@ -195,19 +219,9 @@ function M.navigate_problem(direction, language) end local problems = contest_data.problems - local current_index - for i, prob in ipairs(problems) do - if prob.id == current_problem_id then - current_index = i - break - end - end - if not current_index then - M.setup_contest(platform, contest_id, language, problems[1].id) - return - end + local index = contest_data.index_map[current_problem_id] - local new_index = current_index + direction + local new_index = index + direction if new_index < 1 or new_index > #problems then return end diff --git a/pyproject.toml b/pyproject.toml index 54d8580..d999be0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "backoff>=2.2.1", "beautifulsoup4>=4.13.5", "curl-cffi>=0.13.0", + "ndjson>=0.3.1", "playwright>=1.55.0", "requests>=2.32.5", "scrapling[fetchers]>=0.3.5", @@ -22,6 +23,7 @@ dev = [ "pytest>=8.0.0", "pytest-mock>=3.12.0", "pre-commit>=4.3.0", + "basedpyright>=1.31.6", ] [tool.pytest.ini_options] diff --git a/uv.lock b/uv.lock index 1113a88..0cfa5f2 100644 --- a/uv.lock +++ b/uv.lock @@ -119,6 +119,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" }, ] +[[package]] +name = "basedpyright" +version = "1.31.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodejs-wheel-binaries" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/f6/c5657b1e464d04757cde2db76922a88091fe16854bd3d12e470c23b0dcf1/basedpyright-1.31.6.tar.gz", hash = "sha256:07f3602ba1582218dfd1db25b8b69cd3493e1f4367f46a44fd57bb9034b52ea9", size = 22683901, upload-time = "2025-10-01T13:11:21.317Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/2b/34f338b4c04fe965fd209ed872d9fdd893dacc1a06feb6c9fec13ff535c1/basedpyright-1.31.6-py3-none-any.whl", hash = "sha256:620968ee69c14eee6682f29ffd6f813a30966afb1083ecfa4caf155c5d24f2d5", size = 11805295, upload-time = "2025-10-01T13:11:18.308Z" }, +] + [[package]] name = "beautifulsoup4" version = "4.13.5" @@ -1030,6 +1042,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "ndjson" +version = "0.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b4/d5/209b6ca94566f9c94c0ec41cee1681c0a3b92a306a84a9b0fcd662088dc3/ndjson-0.3.1.tar.gz", hash = "sha256:bf9746cb6bb1cb53d172cda7f154c07c786d665ff28341e4e689b796b229e5d6", size = 6448, upload-time = "2020-02-25T05:01:07.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/c9/04ba0056011ba96a58163ebfd666d8385300bd12da1afe661a5a147758d7/ndjson-0.3.1-py2.py3-none-any.whl", hash = "sha256:839c22275e6baa3040077b83c005ac24199b94973309a8a1809be962c753a410", size = 5305, upload-time = "2020-02-25T05:01:06.39Z" }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -1039,6 +1060,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, ] +[[package]] +name = "nodejs-wheel-binaries" +version = "22.20.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/54/02f58c8119e2f1984e2572cc77a7b469dbaf4f8d171ad376e305749ef48e/nodejs_wheel_binaries-22.20.0.tar.gz", hash = "sha256:a62d47c9fd9c32191dff65bbe60261504f26992a0a19fe8b4d523256a84bd351", size = 8058, upload-time = "2025-09-26T09:48:00.906Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/6d/333e5458422f12318e3c3e6e7f194353aa68b0d633217c7e89833427ca01/nodejs_wheel_binaries-22.20.0-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:455add5ac4f01c9c830ab6771dbfad0fdf373f9b040d3aabe8cca9b6c56654fb", size = 53246314, upload-time = "2025-09-26T09:47:32.536Z" }, + { url = "https://files.pythonhosted.org/packages/56/30/dcd6879d286a35b3c4c8f9e5e0e1bcf4f9e25fe35310fc77ecf97f915a23/nodejs_wheel_binaries-22.20.0-py2.py3-none-macosx_11_0_x86_64.whl", hash = "sha256:5d8c12f97eea7028b34a84446eb5ca81829d0c428dfb4e647e09ac617f4e21fa", size = 53644391, upload-time = "2025-09-26T09:47:36.093Z" }, + { url = "https://files.pythonhosted.org/packages/58/be/c7b2e7aa3bb281d380a1c531f84d0ccfe225832dfc3bed1ca171753b9630/nodejs_wheel_binaries-22.20.0-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a2b0989194148f66e9295d8f11bc463bde02cbe276517f4d20a310fb84780ae", size = 60282516, upload-time = "2025-09-26T09:47:39.88Z" }, + { url = "https://files.pythonhosted.org/packages/3e/c5/8befacf4190e03babbae54cb0809fb1a76e1600ec3967ab8ee9f8fc85b65/nodejs_wheel_binaries-22.20.0-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5c500aa4dc046333ecb0a80f183e069e5c30ce637f1c1a37166b2c0b642dc21", size = 60347290, upload-time = "2025-09-26T09:47:43.712Z" }, + { url = "https://files.pythonhosted.org/packages/c0/bd/cfffd1e334277afa0714962c6ec432b5fe339340a6bca2e5fa8e678e7590/nodejs_wheel_binaries-22.20.0-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3279eb1b99521f0d20a850bbfc0159a658e0e85b843b3cf31b090d7da9f10dfc", size = 62178798, upload-time = "2025-09-26T09:47:47.752Z" }, + { url = "https://files.pythonhosted.org/packages/08/14/10b83a9c02faac985b3e9f5e65d63a34fc0f46b48d8a2c3e4caa3e1e7318/nodejs_wheel_binaries-22.20.0-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d29705797b33bade62d79d8f106c2453c8a26442a9b2a5576610c0f7e7c351ed", size = 62772957, upload-time = "2025-09-26T09:47:51.266Z" }, + { url = "https://files.pythonhosted.org/packages/b4/a9/c6a480259aa0d6b270aac2c6ba73a97444b9267adde983a5b7e34f17e45a/nodejs_wheel_binaries-22.20.0-py2.py3-none-win_amd64.whl", hash = "sha256:4bd658962f24958503541963e5a6f2cc512a8cb301e48a69dc03c879f40a28ae", size = 40120431, upload-time = "2025-09-26T09:47:54.363Z" }, + { url = "https://files.pythonhosted.org/packages/42/b1/6a4eb2c6e9efa028074b0001b61008c9d202b6b46caee9e5d1b18c088216/nodejs_wheel_binaries-22.20.0-py2.py3-none-win_arm64.whl", hash = "sha256:1fccac931faa210d22b6962bcdbc99269d16221d831b9a118bbb80fe434a60b8", size = 38844133, upload-time = "2025-09-26T09:47:57.357Z" }, +] + [[package]] name = "numpy" version = "2.3.3" @@ -1598,6 +1635,7 @@ dependencies = [ { name = "backoff" }, { name = "beautifulsoup4" }, { name = "curl-cffi" }, + { name = "ndjson" }, { name = "playwright" }, { name = "requests" }, { name = "scrapling", extra = ["fetchers"] }, @@ -1606,6 +1644,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "basedpyright" }, { name = "mypy" }, { name = "pre-commit" }, { name = "pytest" }, @@ -1619,6 +1658,7 @@ requires-dist = [ { name = "backoff", specifier = ">=2.2.1" }, { name = "beautifulsoup4", specifier = ">=4.13.5" }, { name = "curl-cffi", specifier = ">=0.13.0" }, + { name = "ndjson", specifier = ">=0.3.1" }, { name = "playwright", specifier = ">=1.55.0" }, { name = "requests", specifier = ">=2.32.5" }, { name = "scrapling", extras = ["fetchers"], specifier = ">=0.3.5" }, @@ -1627,6 +1667,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "basedpyright", specifier = ">=1.31.6" }, { name = "mypy", specifier = ">=1.18.2" }, { name = "pre-commit", specifier = ">=4.3.0" }, { name = "pytest", specifier = ">=8.0.0" },