diff --git a/lua/cp/scraper.lua b/lua/cp/scraper.lua index 194e671..7f774c3 100644 --- a/lua/cp/scraper.lua +++ b/lua/cp/scraper.lua @@ -5,19 +5,19 @@ local logger = require('cp.log') local utils = require('cp.utils') local function syshandle(result) + local ok, data = pcall(vim.json.decode, result.stdout or '') + if ok then + return { success = true, data = data } + end + if result.code ~= 0 then local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error') return { success = false, error = msg } end - local ok, data = pcall(vim.json.decode, result.stdout) - if not ok then - local msg = 'Failed to parse scraper output: ' .. tostring(data) - logger.log(msg, vim.log.levels.ERROR) - return { success = false, error = msg } - end - - return { success = true, data = data } + local msg = 'Failed to parse scraper output: ' .. tostring(data) + logger.log(msg, vim.log.levels.ERROR) + return { success = false, error = msg } end ---@param env_map table diff --git a/lua/cp/ui/edit.lua b/lua/cp/ui/edit.lua index 886c50a..20d4e83 100644 --- a/lua/cp/ui/edit.lua +++ b/lua/cp/ui/edit.lua @@ -144,7 +144,7 @@ local function add_new_test() vim.api.nvim_win_set_buf(input_win, input_buf) vim.bo[input_buf].modifiable = true vim.bo[input_buf].readonly = false - vim.bo[input_buf].buftype = 'nofile' + vim.bo[input_buf].buftype = 'acwrite' vim.bo[input_buf].buflisted = false helpers.clearcol(input_buf) @@ -155,7 +155,7 @@ local function add_new_test() vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].readonly = false - vim.bo[expected_buf].buftype = 'nofile' + vim.bo[expected_buf].buftype = 'acwrite' vim.bo[expected_buf].buflisted = false helpers.clearcol(expected_buf) @@ -177,6 +177,80 @@ local function add_new_test() logger.log(('Added test %d'):format(new_index)) end +local function save_all_tests() + if not edit_state then + return + end + + local platform = state.get_platform() + local contest_id = state.get_contest_id() + local problem_id = state.get_problem_id() + + if not platform or not contest_id or not problem_id then + return + end + + for i, pair in ipairs(edit_state.test_buffers) do + if + vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf) + then + local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false) + local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false) + + edit_state.test_cases[i].input = table.concat(input_lines, '\n') + edit_state.test_cases[i].expected = table.concat(expected_lines, '\n') + end + end + + local contest_data = cache.get_contest_data(platform, contest_id) + local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test + or false + + local combined_input = table.concat( + vim.tbl_map(function(tc) + return tc.input + end, edit_state.test_cases), + '\n' + ) + local combined_expected = table.concat( + vim.tbl_map(function(tc) + return tc.expected + end, edit_state.test_cases), + '\n' + ) + + cache.set_test_cases( + platform, + contest_id, + problem_id, + { input = combined_input, expected = combined_expected }, + edit_state.test_cases, + edit_state.constraints and edit_state.constraints.timeout_ms or 0, + edit_state.constraints and edit_state.constraints.memory_mb or 0, + false, + is_multi_test + ) + + local config = config_module.get_config() + local base_name = config.filename and config.filename(platform, contest_id, problem_id, config) + or config_module.default_filename(contest_id, problem_id) + + vim.fn.mkdir('io', 'p') + + for i, tc in ipairs(edit_state.test_cases) do + local input_file = string.format('io/%s.%d.cpin', base_name, i) + local expected_file = string.format('io/%s.%d.cpout', base_name, i) + + local input_content = (tc.input or ''):gsub('\r', '') + local expected_content = (tc.expected or ''):gsub('\r', '') + + vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file) + vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file) + end + + logger.log('Saved all test cases') +end + ---@param buf integer setup_keybindings = function(buf) local config = config_module.get_config() @@ -243,86 +317,30 @@ setup_keybindings = function(buf) end) end, }) -end -local function save_all_tests() - if not edit_state then - return - end - - local platform = state.get_platform() - local contest_id = state.get_contest_id() - local problem_id = state.get_problem_id() - - if not platform or not contest_id or not problem_id then - return - end - - for i, pair in ipairs(edit_state.test_buffers) do - if - vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf) - then - local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false) - local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false) - - edit_state.test_cases[i].input = table.concat(input_lines, '\n') - edit_state.test_cases[i].expected = table.concat(expected_lines, '\n') - end - end - - local contest_data = cache.get_contest_data(platform, contest_id) - local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test - or false - - -- Generate combined test from individual test cases - local combined_input = table.concat( - vim.tbl_map(function(tc) - return tc.input - end, edit_state.test_cases), - '\n' - ) - local combined_expected = table.concat( - vim.tbl_map(function(tc) - return tc.expected - end, edit_state.test_cases), - '\n' - ) - - cache.set_test_cases( - platform, - contest_id, - problem_id, - { input = combined_input, expected = combined_expected }, - edit_state.test_cases, - edit_state.constraints and edit_state.constraints.timeout_ms or 0, - edit_state.constraints and edit_state.constraints.memory_mb or 0, - false, - is_multi_test - ) - - local config = config_module.get_config() - local base_name = config.filename and config.filename(platform, contest_id, problem_id, config) - or config_module.default_filename(contest_id, problem_id) - - vim.fn.mkdir('io', 'p') - - for i, tc in ipairs(edit_state.test_cases) do - local input_file = string.format('io/%s.%d.cpin', base_name, i) - local expected_file = string.format('io/%s.%d.cpout', base_name, i) - - local input_content = (tc.input or ''):gsub('\r', '') - local expected_content = (tc.expected or ''):gsub('\r', '') - - vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file) - vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file) - end - - logger.log('Saved all test cases') + vim.api.nvim_create_autocmd('BufWriteCmd', { + group = augroup, + buffer = buf, + callback = function() + save_all_tests() + vim.bo[buf].modified = false + end, + }) end function M.toggle_edit(test_index) if edit_state then save_all_tests() + + for _, pair in ipairs(edit_state.test_buffers) do + if vim.api.nvim_buf_is_valid(pair.input_buf) then + vim.api.nvim_buf_delete(pair.input_buf, { force = true }) + end + if vim.api.nvim_buf_is_valid(pair.expected_buf) then + vim.api.nvim_buf_delete(pair.expected_buf, { force = true }) + end + end + edit_state = nil pcall(vim.api.nvim_clear_autocmds, { group = 'cp_edit_guard' }) @@ -411,7 +429,7 @@ function M.toggle_edit(test_index) vim.api.nvim_win_set_buf(input_win, input_buf) vim.bo[input_buf].modifiable = true vim.bo[input_buf].readonly = false - vim.bo[input_buf].buftype = 'nofile' + vim.bo[input_buf].buftype = 'acwrite' vim.bo[input_buf].buflisted = false helpers.clearcol(input_buf) @@ -421,7 +439,7 @@ function M.toggle_edit(test_index) vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].readonly = false - vim.bo[expected_buf].buftype = 'nofile' + vim.bo[expected_buf].buftype = 'acwrite' vim.bo[expected_buf].buflisted = false helpers.clearcol(expected_buf) diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 14debdc..81ec211 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -2,6 +2,7 @@ import asyncio import json +import os import re import sys import time @@ -15,16 +16,10 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry from .base import BaseScraper, extract_precision -from .models import ( - CombinedTest, - ContestListResult, - ContestSummary, - MetadataResult, - ProblemSummary, - SubmitResult, - TestCase, - TestsResult, -) +from .language_ids import get_language_id +from .models import (CombinedTest, ContestListResult, ContestSummary, + MetadataResult, ProblemSummary, SubmitResult, TestCase, + TestsResult) MIB_TO_MB = 1.048576 BASE_URL = "https://atcoder.jp" @@ -378,10 +373,12 @@ class AtcoderScraper(BaseScraper): credentials: dict[str, str], ) -> SubmitResult: def _submit_sync() -> SubmitResult: + from curl_cffi import requests as curl_requests + try: - login_page = _session.get( - f"{BASE_URL}/login", headers=HEADERS, timeout=TIMEOUT_SECONDS - ) + session = curl_requests.Session(impersonate="chrome") + + login_page = session.get(f"{BASE_URL}/login", timeout=TIMEOUT_SECONDS) login_page.raise_for_status() soup = BeautifulSoup(login_page.text, "html.parser") csrf_input = soup.find("input", {"name": "csrf_token"}) @@ -391,21 +388,29 @@ class AtcoderScraper(BaseScraper): ) csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr] - login_resp = _session.post( + login_resp = session.post( f"{BASE_URL}/login", data={ "username": credentials.get("username", ""), "password": credentials.get("password", ""), "csrf_token": csrf_token, }, - headers=HEADERS, timeout=TIMEOUT_SECONDS, + allow_redirects=False, ) - login_resp.raise_for_status() + if login_resp.status_code in (301, 302): + location = login_resp.headers.get("Location", "") + if "/login" in location: + return SubmitResult( + success=False, + error="Login failed: incorrect username or password", + ) + session.get(BASE_URL + location, timeout=TIMEOUT_SECONDS) + else: + login_resp.raise_for_status() - submit_page = _session.get( + submit_page = session.get( f"{BASE_URL}/contests/{contest_id}/submit", - headers=HEADERS, timeout=TIMEOUT_SECONDS, ) submit_page.raise_for_status() @@ -418,7 +423,7 @@ class AtcoderScraper(BaseScraper): csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr] task_screen_name = f"{contest_id}_{problem_id}" - submit_resp = _session.post( + submit_resp = session.post( f"{BASE_URL}/contests/{contest_id}/submit", data={ "data.TaskScreenName": task_screen_name, @@ -426,13 +431,26 @@ class AtcoderScraper(BaseScraper): "sourceCode": source_code, "csrf_token": csrf_token, }, - headers=HEADERS, timeout=TIMEOUT_SECONDS, + allow_redirects=False, ) + if submit_resp.status_code in (301, 302): + location = submit_resp.headers.get("Location", "") + if "/submissions/me" in location: + return SubmitResult( + success=True, + error="", + submission_id="", + verdict="submitted", + ) + return SubmitResult( + success=False, + error=f"Submit may have failed: redirected to {location}", + ) submit_resp.raise_for_status() - return SubmitResult( - success=True, error="", submission_id="", verdict="submitted" + success=False, + error="Unexpected response from submit (expected redirect)", ) except Exception as e: return SubmitResult(success=False, error=str(e)) @@ -495,9 +513,31 @@ async def main_async() -> int: print(contest_result.model_dump_json()) return 0 if contest_result.success else 1 + if mode == "submit": + if len(sys.argv) != 5: + print( + SubmitResult( + success=False, + error="Usage: atcoder.py submit ", + ).model_dump_json() + ) + return 1 + source_code = sys.stdin.read() + creds_raw = os.environ.get("CP_CREDENTIALS", "{}") + try: + credentials = json.loads(creds_raw) + except json.JSONDecodeError: + credentials = {} + language_id = get_language_id("atcoder", sys.argv[4]) or sys.argv[4] + submit_result = await scraper.submit( + sys.argv[2], sys.argv[3], source_code, language_id, credentials + ) + print(submit_result.model_dump_json()) + return 0 if submit_result.success else 1 + result = MetadataResult( success=False, - error="Unknown mode. Use 'metadata ', 'tests ', or 'contests'", + error="Unknown mode. Use 'metadata ', 'tests ', 'contests', or 'submit '", url="", ) print(result.model_dump_json()) diff --git a/scrapers/base.py b/scrapers/base.py index ed0636b..a1b6978 100644 --- a/scrapers/base.py +++ b/scrapers/base.py @@ -6,13 +6,8 @@ import sys from abc import ABC, abstractmethod from .language_ids import get_language_id -from .models import ( - CombinedTest, - ContestListResult, - MetadataResult, - SubmitResult, - TestsResult, -) +from .models import (CombinedTest, ContestListResult, MetadataResult, + SubmitResult, TestsResult) _PRECISION_ABS_REL_RE = re.compile( r"(?:absolute|relative)\s+error[^.]*?10\s*[\^{]\s*\{?\s*[-\u2212]\s*(\d+)\s*\}?", diff --git a/scrapers/codechef.py b/scrapers/codechef.py index 57ce33e..a99c03c 100644 --- a/scrapers/codechef.py +++ b/scrapers/codechef.py @@ -9,14 +9,8 @@ import httpx from curl_cffi import requests as curl_requests from .base import BaseScraper, extract_precision -from .models import ( - ContestListResult, - ContestSummary, - MetadataResult, - ProblemSummary, - SubmitResult, - TestCase, -) +from .models import (ContestListResult, ContestSummary, MetadataResult, + ProblemSummary, SubmitResult, TestCase) BASE_URL = "https://www.codechef.com" API_CONTESTS_ALL = "/api/list/contests/all" diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index c0495d8..53aef24 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -10,14 +10,8 @@ from bs4 import BeautifulSoup, Tag from curl_cffi import requests as curl_requests from .base import BaseScraper, extract_precision -from .models import ( - ContestListResult, - ContestSummary, - MetadataResult, - ProblemSummary, - SubmitResult, - TestCase, -) +from .models import (ContestListResult, ContestSummary, MetadataResult, + ProblemSummary, SubmitResult, TestCase) BASE_URL = "https://codeforces.com" API_CONTEST_LIST_URL = f"{BASE_URL}/api/contest.list" diff --git a/scrapers/cses.py b/scrapers/cses.py index 473558f..644b5a9 100644 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -8,14 +8,8 @@ from typing import Any import httpx from .base import BaseScraper, extract_precision -from .models import ( - ContestListResult, - ContestSummary, - MetadataResult, - ProblemSummary, - SubmitResult, - TestCase, -) +from .models import (ContestListResult, ContestSummary, MetadataResult, + ProblemSummary, SubmitResult, TestCase) BASE_URL = "https://cses.fi" INDEX_PATH = "/problemset" diff --git a/scrapers/kattis.py b/scrapers/kattis.py index d1675bf..1079081 100644 --- a/scrapers/kattis.py +++ b/scrapers/kattis.py @@ -10,14 +10,8 @@ from datetime import datetime import httpx from .base import BaseScraper -from .models import ( - ContestListResult, - ContestSummary, - MetadataResult, - ProblemSummary, - SubmitResult, - TestCase, -) +from .models import (ContestListResult, ContestSummary, MetadataResult, + ProblemSummary, SubmitResult, TestCase) BASE_URL = "https://open.kattis.com" HEADERS = { diff --git a/scrapers/usaco.py b/scrapers/usaco.py index 565f1b5..e933c47 100644 --- a/scrapers/usaco.py +++ b/scrapers/usaco.py @@ -8,14 +8,8 @@ from typing import Any, cast import httpx from .base import BaseScraper -from .models import ( - ContestListResult, - ContestSummary, - MetadataResult, - ProblemSummary, - SubmitResult, - TestCase, -) +from .models import (ContestListResult, ContestSummary, MetadataResult, + ProblemSummary, SubmitResult, TestCase) BASE_URL = "http://www.usaco.org" HEADERS = { @@ -37,8 +31,7 @@ DIVISION_HEADING_RE = re.compile( re.IGNORECASE, ) PROBLEM_BLOCK_RE = re.compile( - r"([^<]+)\s*.*?" - r"viewproblem2&cpid=(\d+)", + r"([^<]+)\s*.*?" r"viewproblem2&cpid=(\d+)", re.DOTALL, ) SAMPLE_IN_RE = re.compile(r"(.*?)", re.DOTALL) diff --git a/tests/test_scrapers.py b/tests/test_scrapers.py index 8ce468f..e4f9377 100644 --- a/tests/test_scrapers.py +++ b/tests/test_scrapers.py @@ -1,10 +1,6 @@ import pytest -from scrapers.models import ( - ContestListResult, - MetadataResult, - TestsResult, -) +from scrapers.models import ContestListResult, MetadataResult, TestsResult MATRIX = { "cses": { @@ -61,9 +57,9 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode): assert hasattr(tr.combined, "input"), "combined missing input" assert hasattr(tr.combined, "expected"), "combined missing expected" assert isinstance(tr.combined.input, str), "combined.input not string" - assert isinstance(tr.combined.expected, str), ( - "combined.expected not string" - ) + assert isinstance( + tr.combined.expected, str + ), "combined.expected not string" assert hasattr(tr, "multi_test"), "Missing multi_test field" assert isinstance(tr.multi_test, bool), "multi_test not boolean" validated_any = True @@ -77,12 +73,12 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode): assert isinstance(obj["combined"], dict), "combined not a dict" assert "input" in obj["combined"], "combined missing input key" assert "expected" in obj["combined"], "combined missing expected key" - assert isinstance(obj["combined"]["input"], str), ( - "combined.input not string" - ) - assert isinstance(obj["combined"]["expected"], str), ( - "combined.expected not string" - ) + assert isinstance( + obj["combined"]["input"], str + ), "combined.input not string" + assert isinstance( + obj["combined"]["expected"], str + ), "combined.expected not string" assert "multi_test" in obj, "Missing multi_test field in raw JSON" assert isinstance(obj["multi_test"], bool), "multi_test not boolean" validated_any = True