diff --git a/lua/cp/scraper.lua b/lua/cp/scraper.lua index 7f774c3..194e671 100644 --- a/lua/cp/scraper.lua +++ b/lua/cp/scraper.lua @@ -5,19 +5,19 @@ local logger = require('cp.log') local utils = require('cp.utils') local function syshandle(result) - local ok, data = pcall(vim.json.decode, result.stdout or '') - if ok then - return { success = true, data = data } - end - if result.code ~= 0 then local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error') return { success = false, error = msg } end - local msg = 'Failed to parse scraper output: ' .. tostring(data) - logger.log(msg, vim.log.levels.ERROR) - return { success = false, error = msg } + local ok, data = pcall(vim.json.decode, result.stdout) + if not ok then + local msg = 'Failed to parse scraper output: ' .. tostring(data) + logger.log(msg, vim.log.levels.ERROR) + return { success = false, error = msg } + end + + return { success = true, data = data } end ---@param env_map table diff --git a/lua/cp/ui/edit.lua b/lua/cp/ui/edit.lua index 20d4e83..886c50a 100644 --- a/lua/cp/ui/edit.lua +++ b/lua/cp/ui/edit.lua @@ -144,7 +144,7 @@ local function add_new_test() vim.api.nvim_win_set_buf(input_win, input_buf) vim.bo[input_buf].modifiable = true vim.bo[input_buf].readonly = false - vim.bo[input_buf].buftype = 'acwrite' + vim.bo[input_buf].buftype = 'nofile' vim.bo[input_buf].buflisted = false helpers.clearcol(input_buf) @@ -155,7 +155,7 @@ local function add_new_test() vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].readonly = false - vim.bo[expected_buf].buftype = 'acwrite' + vim.bo[expected_buf].buftype = 'nofile' vim.bo[expected_buf].buflisted = false helpers.clearcol(expected_buf) @@ -177,80 +177,6 @@ local function add_new_test() logger.log(('Added test %d'):format(new_index)) end -local function save_all_tests() - if not edit_state then - return - end - - local platform = state.get_platform() - local contest_id = state.get_contest_id() - local problem_id = state.get_problem_id() - - if not platform or not contest_id or not problem_id then - return - end - - for i, pair in ipairs(edit_state.test_buffers) do - if - vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf) - then - local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false) - local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false) - - edit_state.test_cases[i].input = table.concat(input_lines, '\n') - edit_state.test_cases[i].expected = table.concat(expected_lines, '\n') - end - end - - local contest_data = cache.get_contest_data(platform, contest_id) - local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test - or false - - local combined_input = table.concat( - vim.tbl_map(function(tc) - return tc.input - end, edit_state.test_cases), - '\n' - ) - local combined_expected = table.concat( - vim.tbl_map(function(tc) - return tc.expected - end, edit_state.test_cases), - '\n' - ) - - cache.set_test_cases( - platform, - contest_id, - problem_id, - { input = combined_input, expected = combined_expected }, - edit_state.test_cases, - edit_state.constraints and edit_state.constraints.timeout_ms or 0, - edit_state.constraints and edit_state.constraints.memory_mb or 0, - false, - is_multi_test - ) - - local config = config_module.get_config() - local base_name = config.filename and config.filename(platform, contest_id, problem_id, config) - or config_module.default_filename(contest_id, problem_id) - - vim.fn.mkdir('io', 'p') - - for i, tc in ipairs(edit_state.test_cases) do - local input_file = string.format('io/%s.%d.cpin', base_name, i) - local expected_file = string.format('io/%s.%d.cpout', base_name, i) - - local input_content = (tc.input or ''):gsub('\r', '') - local expected_content = (tc.expected or ''):gsub('\r', '') - - vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file) - vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file) - end - - logger.log('Saved all test cases') -end - ---@param buf integer setup_keybindings = function(buf) local config = config_module.get_config() @@ -317,30 +243,86 @@ setup_keybindings = function(buf) end) end, }) +end - vim.api.nvim_create_autocmd('BufWriteCmd', { - group = augroup, - buffer = buf, - callback = function() - save_all_tests() - vim.bo[buf].modified = false - end, - }) +local function save_all_tests() + if not edit_state then + return + end + + local platform = state.get_platform() + local contest_id = state.get_contest_id() + local problem_id = state.get_problem_id() + + if not platform or not contest_id or not problem_id then + return + end + + for i, pair in ipairs(edit_state.test_buffers) do + if + vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf) + then + local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false) + local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false) + + edit_state.test_cases[i].input = table.concat(input_lines, '\n') + edit_state.test_cases[i].expected = table.concat(expected_lines, '\n') + end + end + + local contest_data = cache.get_contest_data(platform, contest_id) + local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test + or false + + -- Generate combined test from individual test cases + local combined_input = table.concat( + vim.tbl_map(function(tc) + return tc.input + end, edit_state.test_cases), + '\n' + ) + local combined_expected = table.concat( + vim.tbl_map(function(tc) + return tc.expected + end, edit_state.test_cases), + '\n' + ) + + cache.set_test_cases( + platform, + contest_id, + problem_id, + { input = combined_input, expected = combined_expected }, + edit_state.test_cases, + edit_state.constraints and edit_state.constraints.timeout_ms or 0, + edit_state.constraints and edit_state.constraints.memory_mb or 0, + false, + is_multi_test + ) + + local config = config_module.get_config() + local base_name = config.filename and config.filename(platform, contest_id, problem_id, config) + or config_module.default_filename(contest_id, problem_id) + + vim.fn.mkdir('io', 'p') + + for i, tc in ipairs(edit_state.test_cases) do + local input_file = string.format('io/%s.%d.cpin', base_name, i) + local expected_file = string.format('io/%s.%d.cpout', base_name, i) + + local input_content = (tc.input or ''):gsub('\r', '') + local expected_content = (tc.expected or ''):gsub('\r', '') + + vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file) + vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file) + end + + logger.log('Saved all test cases') end function M.toggle_edit(test_index) if edit_state then save_all_tests() - - for _, pair in ipairs(edit_state.test_buffers) do - if vim.api.nvim_buf_is_valid(pair.input_buf) then - vim.api.nvim_buf_delete(pair.input_buf, { force = true }) - end - if vim.api.nvim_buf_is_valid(pair.expected_buf) then - vim.api.nvim_buf_delete(pair.expected_buf, { force = true }) - end - end - edit_state = nil pcall(vim.api.nvim_clear_autocmds, { group = 'cp_edit_guard' }) @@ -429,7 +411,7 @@ function M.toggle_edit(test_index) vim.api.nvim_win_set_buf(input_win, input_buf) vim.bo[input_buf].modifiable = true vim.bo[input_buf].readonly = false - vim.bo[input_buf].buftype = 'acwrite' + vim.bo[input_buf].buftype = 'nofile' vim.bo[input_buf].buflisted = false helpers.clearcol(input_buf) @@ -439,7 +421,7 @@ function M.toggle_edit(test_index) vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].readonly = false - vim.bo[expected_buf].buftype = 'acwrite' + vim.bo[expected_buf].buftype = 'nofile' vim.bo[expected_buf].buflisted = false helpers.clearcol(expected_buf) diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 81ec211..14debdc 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -2,7 +2,6 @@ import asyncio import json -import os import re import sys import time @@ -16,10 +15,16 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry from .base import BaseScraper, extract_precision -from .language_ids import get_language_id -from .models import (CombinedTest, ContestListResult, ContestSummary, - MetadataResult, ProblemSummary, SubmitResult, TestCase, - TestsResult) +from .models import ( + CombinedTest, + ContestListResult, + ContestSummary, + MetadataResult, + ProblemSummary, + SubmitResult, + TestCase, + TestsResult, +) MIB_TO_MB = 1.048576 BASE_URL = "https://atcoder.jp" @@ -373,12 +378,10 @@ class AtcoderScraper(BaseScraper): credentials: dict[str, str], ) -> SubmitResult: def _submit_sync() -> SubmitResult: - from curl_cffi import requests as curl_requests - try: - session = curl_requests.Session(impersonate="chrome") - - login_page = session.get(f"{BASE_URL}/login", timeout=TIMEOUT_SECONDS) + login_page = _session.get( + f"{BASE_URL}/login", headers=HEADERS, timeout=TIMEOUT_SECONDS + ) login_page.raise_for_status() soup = BeautifulSoup(login_page.text, "html.parser") csrf_input = soup.find("input", {"name": "csrf_token"}) @@ -388,29 +391,21 @@ class AtcoderScraper(BaseScraper): ) csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr] - login_resp = session.post( + login_resp = _session.post( f"{BASE_URL}/login", data={ "username": credentials.get("username", ""), "password": credentials.get("password", ""), "csrf_token": csrf_token, }, + headers=HEADERS, timeout=TIMEOUT_SECONDS, - allow_redirects=False, ) - if login_resp.status_code in (301, 302): - location = login_resp.headers.get("Location", "") - if "/login" in location: - return SubmitResult( - success=False, - error="Login failed: incorrect username or password", - ) - session.get(BASE_URL + location, timeout=TIMEOUT_SECONDS) - else: - login_resp.raise_for_status() + login_resp.raise_for_status() - submit_page = session.get( + submit_page = _session.get( f"{BASE_URL}/contests/{contest_id}/submit", + headers=HEADERS, timeout=TIMEOUT_SECONDS, ) submit_page.raise_for_status() @@ -423,7 +418,7 @@ class AtcoderScraper(BaseScraper): csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr] task_screen_name = f"{contest_id}_{problem_id}" - submit_resp = session.post( + submit_resp = _session.post( f"{BASE_URL}/contests/{contest_id}/submit", data={ "data.TaskScreenName": task_screen_name, @@ -431,26 +426,13 @@ class AtcoderScraper(BaseScraper): "sourceCode": source_code, "csrf_token": csrf_token, }, + headers=HEADERS, timeout=TIMEOUT_SECONDS, - allow_redirects=False, ) - if submit_resp.status_code in (301, 302): - location = submit_resp.headers.get("Location", "") - if "/submissions/me" in location: - return SubmitResult( - success=True, - error="", - submission_id="", - verdict="submitted", - ) - return SubmitResult( - success=False, - error=f"Submit may have failed: redirected to {location}", - ) submit_resp.raise_for_status() + return SubmitResult( - success=False, - error="Unexpected response from submit (expected redirect)", + success=True, error="", submission_id="", verdict="submitted" ) except Exception as e: return SubmitResult(success=False, error=str(e)) @@ -513,31 +495,9 @@ async def main_async() -> int: print(contest_result.model_dump_json()) return 0 if contest_result.success else 1 - if mode == "submit": - if len(sys.argv) != 5: - print( - SubmitResult( - success=False, - error="Usage: atcoder.py submit ", - ).model_dump_json() - ) - return 1 - source_code = sys.stdin.read() - creds_raw = os.environ.get("CP_CREDENTIALS", "{}") - try: - credentials = json.loads(creds_raw) - except json.JSONDecodeError: - credentials = {} - language_id = get_language_id("atcoder", sys.argv[4]) or sys.argv[4] - submit_result = await scraper.submit( - sys.argv[2], sys.argv[3], source_code, language_id, credentials - ) - print(submit_result.model_dump_json()) - return 0 if submit_result.success else 1 - result = MetadataResult( success=False, - error="Unknown mode. Use 'metadata ', 'tests ', 'contests', or 'submit '", + error="Unknown mode. Use 'metadata ', 'tests ', or 'contests'", url="", ) print(result.model_dump_json()) diff --git a/scrapers/base.py b/scrapers/base.py index a1b6978..ed0636b 100644 --- a/scrapers/base.py +++ b/scrapers/base.py @@ -6,8 +6,13 @@ import sys from abc import ABC, abstractmethod from .language_ids import get_language_id -from .models import (CombinedTest, ContestListResult, MetadataResult, - SubmitResult, TestsResult) +from .models import ( + CombinedTest, + ContestListResult, + MetadataResult, + SubmitResult, + TestsResult, +) _PRECISION_ABS_REL_RE = re.compile( r"(?:absolute|relative)\s+error[^.]*?10\s*[\^{]\s*\{?\s*[-\u2212]\s*(\d+)\s*\}?", diff --git a/scrapers/codechef.py b/scrapers/codechef.py index a99c03c..57ce33e 100644 --- a/scrapers/codechef.py +++ b/scrapers/codechef.py @@ -9,8 +9,14 @@ import httpx from curl_cffi import requests as curl_requests from .base import BaseScraper, extract_precision -from .models import (ContestListResult, ContestSummary, MetadataResult, - ProblemSummary, SubmitResult, TestCase) +from .models import ( + ContestListResult, + ContestSummary, + MetadataResult, + ProblemSummary, + SubmitResult, + TestCase, +) BASE_URL = "https://www.codechef.com" API_CONTESTS_ALL = "/api/list/contests/all" diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index 53aef24..c0495d8 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -10,8 +10,14 @@ from bs4 import BeautifulSoup, Tag from curl_cffi import requests as curl_requests from .base import BaseScraper, extract_precision -from .models import (ContestListResult, ContestSummary, MetadataResult, - ProblemSummary, SubmitResult, TestCase) +from .models import ( + ContestListResult, + ContestSummary, + MetadataResult, + ProblemSummary, + SubmitResult, + TestCase, +) BASE_URL = "https://codeforces.com" API_CONTEST_LIST_URL = f"{BASE_URL}/api/contest.list" diff --git a/scrapers/cses.py b/scrapers/cses.py index 644b5a9..473558f 100644 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -8,8 +8,14 @@ from typing import Any import httpx from .base import BaseScraper, extract_precision -from .models import (ContestListResult, ContestSummary, MetadataResult, - ProblemSummary, SubmitResult, TestCase) +from .models import ( + ContestListResult, + ContestSummary, + MetadataResult, + ProblemSummary, + SubmitResult, + TestCase, +) BASE_URL = "https://cses.fi" INDEX_PATH = "/problemset" diff --git a/scrapers/kattis.py b/scrapers/kattis.py index 1079081..d1675bf 100644 --- a/scrapers/kattis.py +++ b/scrapers/kattis.py @@ -10,8 +10,14 @@ from datetime import datetime import httpx from .base import BaseScraper -from .models import (ContestListResult, ContestSummary, MetadataResult, - ProblemSummary, SubmitResult, TestCase) +from .models import ( + ContestListResult, + ContestSummary, + MetadataResult, + ProblemSummary, + SubmitResult, + TestCase, +) BASE_URL = "https://open.kattis.com" HEADERS = { diff --git a/scrapers/usaco.py b/scrapers/usaco.py index e933c47..565f1b5 100644 --- a/scrapers/usaco.py +++ b/scrapers/usaco.py @@ -8,8 +8,14 @@ from typing import Any, cast import httpx from .base import BaseScraper -from .models import (ContestListResult, ContestSummary, MetadataResult, - ProblemSummary, SubmitResult, TestCase) +from .models import ( + ContestListResult, + ContestSummary, + MetadataResult, + ProblemSummary, + SubmitResult, + TestCase, +) BASE_URL = "http://www.usaco.org" HEADERS = { @@ -31,7 +37,8 @@ DIVISION_HEADING_RE = re.compile( re.IGNORECASE, ) PROBLEM_BLOCK_RE = re.compile( - r"([^<]+)\s*.*?" r"viewproblem2&cpid=(\d+)", + r"([^<]+)\s*.*?" + r"viewproblem2&cpid=(\d+)", re.DOTALL, ) SAMPLE_IN_RE = re.compile(r"(.*?)", re.DOTALL) diff --git a/tests/test_scrapers.py b/tests/test_scrapers.py index e4f9377..8ce468f 100644 --- a/tests/test_scrapers.py +++ b/tests/test_scrapers.py @@ -1,6 +1,10 @@ import pytest -from scrapers.models import ContestListResult, MetadataResult, TestsResult +from scrapers.models import ( + ContestListResult, + MetadataResult, + TestsResult, +) MATRIX = { "cses": { @@ -57,9 +61,9 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode): assert hasattr(tr.combined, "input"), "combined missing input" assert hasattr(tr.combined, "expected"), "combined missing expected" assert isinstance(tr.combined.input, str), "combined.input not string" - assert isinstance( - tr.combined.expected, str - ), "combined.expected not string" + assert isinstance(tr.combined.expected, str), ( + "combined.expected not string" + ) assert hasattr(tr, "multi_test"), "Missing multi_test field" assert isinstance(tr.multi_test, bool), "multi_test not boolean" validated_any = True @@ -73,12 +77,12 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode): assert isinstance(obj["combined"], dict), "combined not a dict" assert "input" in obj["combined"], "combined missing input key" assert "expected" in obj["combined"], "combined missing expected key" - assert isinstance( - obj["combined"]["input"], str - ), "combined.input not string" - assert isinstance( - obj["combined"]["expected"], str - ), "combined.expected not string" + assert isinstance(obj["combined"]["input"], str), ( + "combined.input not string" + ) + assert isinstance(obj["combined"]["expected"], str), ( + "combined.expected not string" + ) assert "multi_test" in obj, "Missing multi_test field in raw JSON" assert isinstance(obj["multi_test"], bool), "multi_test not boolean" validated_any = True