diff --git a/doc/cp.nvim.txt b/doc/cp.nvim.txt index 695a2cb..2557fb8 100644 --- a/doc/cp.nvim.txt +++ b/doc/cp.nvim.txt @@ -142,6 +142,7 @@ Here's an example configuration with lazy.nvim: default_language = 'cpp', }, }, + open_url = true, debug = false, ui = { run_panel = { @@ -210,6 +211,7 @@ run CSES problems with Rust using the single schema: Should return full filename with extension. (default: concatenates contest_id and problem_id, lowercased) {ui} (|CpUI|) UI settings: run panel, diff backend, picker. + {open_url} (boolean) Open the contest url in the browser. *CpPlatform* Fields: ~ diff --git a/lua/cp/cache.lua b/lua/cp/cache.lua index 9e96caa..5c56ef8 100644 --- a/lua/cp/cache.lua +++ b/lua/cp/cache.lua @@ -9,6 +9,7 @@ ---@field index_map table ---@field name string ---@field display_name string +---@field url string ---@class ContestSummary ---@field display_name string @@ -94,11 +95,13 @@ end ---@param platform string ---@param contest_id string ---@param problems Problem[] -function M.set_contest_data(platform, contest_id, problems) +---@param url string +function M.set_contest_data(platform, contest_id, problems, url) vim.validate({ platform = { platform, 'string' }, contest_id = { contest_id, 'string' }, problems = { problems, 'table' }, + url = { url, 'string' }, }) cache_data[platform] = cache_data[platform] or {} @@ -109,6 +112,7 @@ function M.set_contest_data(platform, contest_id, problems) display_name = prev.display_name, problems = problems, index_map = {}, + url = url, } for i, p in ipairs(out.problems) do out.index_map[p.id] = i diff --git a/lua/cp/config.lua b/lua/cp/config.lua index 31540cf..ae66586 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -43,6 +43,7 @@ ---@field platforms table ---@field hooks Hooks ---@field debug boolean +---@field open_url boolean ---@field scrapers string[] ---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string ---@field ui CpUI @@ -58,6 +59,7 @@ local utils = require('cp.utils') -- defaults per the new single schema ---@type cp.Config M.defaults = { + open_url = false, languages = { cpp = { extension = 'cc', @@ -223,9 +225,7 @@ function M.setup(user_config) vim.validate({ hooks = { cfg.hooks, { 'table' } }, ui = { cfg.ui, { 'table' } }, - }) - - vim.validate({ + open_url = { cfg.open_url, { 'boolean', 'nil' }, true }, before_run = { cfg.hooks.before_run, { 'function', 'nil' }, true }, before_debug = { cfg.hooks.before_debug, { 'function', 'nil' }, true }, setup_code = { cfg.hooks.setup_code, { 'function', 'nil' }, true }, diff --git a/lua/cp/restore.lua b/lua/cp/restore.lua index 875e733..44cdfd4 100644 --- a/lua/cp/restore.lua +++ b/lua/cp/restore.lua @@ -16,7 +16,7 @@ function M.restore_from_current_file() end local setup = require('cp.setup') - setup.set_platform(file_state.platform) + state.set_platform(file_state.platform) state.set_contest_id(file_state.contest_id) state.set_problem_id(file_state.problem_id) setup.setup_contest( diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index 1744655..4c5fcae 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -77,6 +77,11 @@ function M.setup_contest(platform, contest_id, problem_id, language) local pid = problem_id and problem_id or problems[1].id M.setup_problem(pid, language) start_tests(platform, contest_id, problems) + + if contest_data.url and config_module.get_config().open_url then + vim.print('opening') + vim.ui.open(contest_data.url) + end end local contest_data = cache.get_contest_data(platform, contest_id) @@ -134,7 +139,7 @@ function M.setup_contest(platform, contest_id, problem_id, language) contest_id, vim.schedule_wrap(function(result) local problems = result.problems or {} - cache.set_contest_data(platform, contest_id, problems) + cache.set_contest_data(platform, contest_id, problems, result.url) local prov = state.get_provisional() if not prov or prov.platform ~= platform or prov.contest_id ~= contest_id then return @@ -150,8 +155,7 @@ function M.setup_contest(platform, contest_id, problem_id, language) if not pid then return end - M.setup_problem(pid, prov.language) - start_tests(platform, contest_id, cd.problems) + proceed(cd) end) ) return diff --git a/lua/cp/utils.lua b/lua/cp/utils.lua index e78b056..6ce2311 100644 --- a/lua/cp/utils.lua +++ b/lua/cp/utils.lua @@ -12,7 +12,7 @@ local _timeout_path = nil local _timeout_reason = nil local function is_windows() - return uname and uname.sysname == 'Windows_NT' + return uname.sysname == 'Windows_NT' end local function check_time_is_gnu_time(bin) diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 7571a26..f565104 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -286,6 +286,7 @@ class AtcoderScraper(BaseScraper): error="", contest_id=cid, problems=problems, + url=f"https://atcoder.jp/contests/{contest_id}/tasks", ) return await self._safe_execute("metadata", impl, contest_id) @@ -335,6 +336,7 @@ async def main_async() -> int: result = MetadataResult( success=False, error="Usage: atcoder.py metadata OR atcoder.py tests OR atcoder.py contests", + url="", ) print(result.model_dump_json()) return 1 @@ -345,7 +347,9 @@ async def main_async() -> int: if mode == "metadata": if len(sys.argv) != 3: result = MetadataResult( - success=False, error="Usage: atcoder.py metadata " + success=False, + error="Usage: atcoder.py metadata ", + url="", ) print(result.model_dump_json()) return 1 @@ -360,7 +364,6 @@ async def main_async() -> int: success=False, error="Usage: atcoder.py tests ", problem_id="", - url="", tests=[], timeout_ms=0, memory_mb=0, @@ -385,6 +388,7 @@ async def main_async() -> int: result = MetadataResult( success=False, error="Unknown mode. Use 'metadata ', 'tests ', or 'contests'", + url="", ) print(result.model_dump_json()) return 1 diff --git a/scrapers/base.py b/scrapers/base.py index 5c602a3..315519c 100644 --- a/scrapers/base.py +++ b/scrapers/base.py @@ -28,6 +28,7 @@ class BaseScraper(ABC): error=f"{self.platform_name}: {error_msg}", contest_id=contest_id, problems=[], + url="", ) def _create_tests_error( @@ -37,7 +38,6 @@ class BaseScraper(ABC): success=False, error=f"{self.platform_name}: {error_msg}", problem_id=problem_id, - url=url, tests=[], timeout_ms=0, memory_mb=0, diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index b0eecc3..10287ae 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -198,7 +198,11 @@ class CodeforcesScraper(BaseScraper): f"No problems found for contest {cid}", cid ) return MetadataResult( - success=True, error="", contest_id=cid, problems=problems + success=True, + error="", + contest_id=cid, + problems=problems, + url=f"https://codeforces.com/contest/{contest_id}", ) return await self._safe_execute("metadata", impl, contest_id) @@ -259,6 +263,7 @@ async def main_async() -> int: result = MetadataResult( success=False, error="Usage: codeforces.py metadata OR codeforces.py tests OR codeforces.py contests", + url="", ) print(result.model_dump_json()) return 1 @@ -269,7 +274,9 @@ async def main_async() -> int: if mode == "metadata": if len(sys.argv) != 3: result = MetadataResult( - success=False, error="Usage: codeforces.py metadata " + success=False, + error="Usage: codeforces.py metadata ", + url="", ) print(result.model_dump_json()) return 1 @@ -284,7 +291,6 @@ async def main_async() -> int: success=False, error="Usage: codeforces.py tests ", problem_id="", - url="", tests=[], timeout_ms=0, memory_mb=0, @@ -309,6 +315,7 @@ async def main_async() -> int: result = MetadataResult( success=False, error="Unknown mode. Use 'metadata ', 'tests ', or 'contests'", + url="", ) print(result.model_dump_json()) return 1 diff --git a/scrapers/cses.py b/scrapers/cses.py index 2f76cc5..434e8a4 100644 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -193,9 +193,14 @@ class CSESScraper(BaseScraper): return MetadataResult( success=False, error=f"{self.platform_name}: No problems found for category: {contest_id}", + url="", ) return MetadataResult( - success=True, error="", contest_id=contest_id, problems=problems + success=True, + error="", + contest_id=contest_id, + problems=problems, + url="https://cses.fi/problemset", ) async def scrape_contest_list(self) -> ContestListResult: @@ -249,6 +254,7 @@ async def main_async() -> int: result = MetadataResult( success=False, error="Usage: cses.py metadata OR cses.py tests OR cses.py contests", + url="", ) print(result.model_dump_json()) return 1 @@ -259,7 +265,9 @@ async def main_async() -> int: if mode == "metadata": if len(sys.argv) != 3: result = MetadataResult( - success=False, error="Usage: cses.py metadata " + success=False, + error="Usage: cses.py metadata ", + url="", ) print(result.model_dump_json()) return 1 @@ -274,7 +282,6 @@ async def main_async() -> int: success=False, error="Usage: cses.py tests ", problem_id="", - url="", tests=[], timeout_ms=0, memory_mb=0, @@ -299,6 +306,7 @@ async def main_async() -> int: result = MetadataResult( success=False, error=f"Unknown mode: {mode}. Use 'metadata ', 'tests ', or 'contests'", + url="", ) print(result.model_dump_json()) return 1 diff --git a/scrapers/models.py b/scrapers/models.py index d2cf19a..2a954ef 100644 --- a/scrapers/models.py +++ b/scrapers/models.py @@ -33,6 +33,7 @@ class ScrapingResult(BaseModel): class MetadataResult(ScrapingResult): contest_id: str = "" problems: list[ProblemSummary] = Field(default_factory=list) + url: str model_config = ConfigDict(extra="forbid") @@ -45,7 +46,6 @@ class ContestListResult(ScrapingResult): class TestsResult(ScrapingResult): problem_id: str - url: str tests: list[TestCase] = Field(default_factory=list) timeout_ms: int memory_mb: float diff --git a/tests/test_scrapers.py b/tests/test_scrapers.py index 83847ca..415f0dd 100644 --- a/tests/test_scrapers.py +++ b/tests/test_scrapers.py @@ -42,15 +42,13 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode): Model = MODEL_FOR_MODE[mode] model = Model.model_validate(objs[-1]) assert model is not None + assert model.success is True if mode == "metadata": - assert model.success in (True, False) - if model.success: - assert len(model.problems) >= 1 - assert all(isinstance(p.id, str) and p.id for p in model.problems) + assert model.url + assert len(model.problems) >= 1 + assert all(isinstance(p.id, str) and p.id for p in model.problems) else: - assert model.success in (True, False) - if model.success: - assert len(model.contests) >= 1 + assert len(model.contests) >= 1 else: validated_any = False for obj in objs: