Compare commits

..

No commits in common. "ff03b932b1233285490480b6b50fa8a5c4a915bc" and "f17eb32e8c79156638cd17021acab174ac947a79" have entirely different histories.

10 changed files with 172 additions and 190 deletions

View file

@ -5,19 +5,19 @@ local logger = require('cp.log')
local utils = require('cp.utils') local utils = require('cp.utils')
local function syshandle(result) local function syshandle(result)
local ok, data = pcall(vim.json.decode, result.stdout or '')
if ok then
return { success = true, data = data }
end
if result.code ~= 0 then if result.code ~= 0 then
local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error') local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error')
return { success = false, error = msg } return { success = false, error = msg }
end end
local ok, data = pcall(vim.json.decode, result.stdout)
if not ok then
local msg = 'Failed to parse scraper output: ' .. tostring(data) local msg = 'Failed to parse scraper output: ' .. tostring(data)
logger.log(msg, vim.log.levels.ERROR) logger.log(msg, vim.log.levels.ERROR)
return { success = false, error = msg } return { success = false, error = msg }
end
return { success = true, data = data }
end end
---@param env_map table<string, string> ---@param env_map table<string, string>

View file

@ -144,7 +144,7 @@ local function add_new_test()
vim.api.nvim_win_set_buf(input_win, input_buf) vim.api.nvim_win_set_buf(input_win, input_buf)
vim.bo[input_buf].modifiable = true vim.bo[input_buf].modifiable = true
vim.bo[input_buf].readonly = false vim.bo[input_buf].readonly = false
vim.bo[input_buf].buftype = 'acwrite' vim.bo[input_buf].buftype = 'nofile'
vim.bo[input_buf].buflisted = false vim.bo[input_buf].buflisted = false
helpers.clearcol(input_buf) helpers.clearcol(input_buf)
@ -155,7 +155,7 @@ local function add_new_test()
vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.api.nvim_win_set_buf(expected_win, expected_buf)
vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].modifiable = true
vim.bo[expected_buf].readonly = false vim.bo[expected_buf].readonly = false
vim.bo[expected_buf].buftype = 'acwrite' vim.bo[expected_buf].buftype = 'nofile'
vim.bo[expected_buf].buflisted = false vim.bo[expected_buf].buflisted = false
helpers.clearcol(expected_buf) helpers.clearcol(expected_buf)
@ -177,80 +177,6 @@ local function add_new_test()
logger.log(('Added test %d'):format(new_index)) logger.log(('Added test %d'):format(new_index))
end end
local function save_all_tests()
if not edit_state then
return
end
local platform = state.get_platform()
local contest_id = state.get_contest_id()
local problem_id = state.get_problem_id()
if not platform or not contest_id or not problem_id then
return
end
for i, pair in ipairs(edit_state.test_buffers) do
if
vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf)
then
local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false)
local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false)
edit_state.test_cases[i].input = table.concat(input_lines, '\n')
edit_state.test_cases[i].expected = table.concat(expected_lines, '\n')
end
end
local contest_data = cache.get_contest_data(platform, contest_id)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
or false
local combined_input = table.concat(
vim.tbl_map(function(tc)
return tc.input
end, edit_state.test_cases),
'\n'
)
local combined_expected = table.concat(
vim.tbl_map(function(tc)
return tc.expected
end, edit_state.test_cases),
'\n'
)
cache.set_test_cases(
platform,
contest_id,
problem_id,
{ input = combined_input, expected = combined_expected },
edit_state.test_cases,
edit_state.constraints and edit_state.constraints.timeout_ms or 0,
edit_state.constraints and edit_state.constraints.memory_mb or 0,
false,
is_multi_test
)
local config = config_module.get_config()
local base_name = config.filename and config.filename(platform, contest_id, problem_id, config)
or config_module.default_filename(contest_id, problem_id)
vim.fn.mkdir('io', 'p')
for i, tc in ipairs(edit_state.test_cases) do
local input_file = string.format('io/%s.%d.cpin', base_name, i)
local expected_file = string.format('io/%s.%d.cpout', base_name, i)
local input_content = (tc.input or ''):gsub('\r', '')
local expected_content = (tc.expected or ''):gsub('\r', '')
vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file)
vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file)
end
logger.log('Saved all test cases')
end
---@param buf integer ---@param buf integer
setup_keybindings = function(buf) setup_keybindings = function(buf)
local config = config_module.get_config() local config = config_module.get_config()
@ -317,30 +243,86 @@ setup_keybindings = function(buf)
end) end)
end, end,
}) })
end
vim.api.nvim_create_autocmd('BufWriteCmd', { local function save_all_tests()
group = augroup, if not edit_state then
buffer = buf, return
callback = function() end
save_all_tests()
vim.bo[buf].modified = false local platform = state.get_platform()
end, local contest_id = state.get_contest_id()
}) local problem_id = state.get_problem_id()
if not platform or not contest_id or not problem_id then
return
end
for i, pair in ipairs(edit_state.test_buffers) do
if
vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf)
then
local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false)
local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false)
edit_state.test_cases[i].input = table.concat(input_lines, '\n')
edit_state.test_cases[i].expected = table.concat(expected_lines, '\n')
end
end
local contest_data = cache.get_contest_data(platform, contest_id)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
or false
-- Generate combined test from individual test cases
local combined_input = table.concat(
vim.tbl_map(function(tc)
return tc.input
end, edit_state.test_cases),
'\n'
)
local combined_expected = table.concat(
vim.tbl_map(function(tc)
return tc.expected
end, edit_state.test_cases),
'\n'
)
cache.set_test_cases(
platform,
contest_id,
problem_id,
{ input = combined_input, expected = combined_expected },
edit_state.test_cases,
edit_state.constraints and edit_state.constraints.timeout_ms or 0,
edit_state.constraints and edit_state.constraints.memory_mb or 0,
false,
is_multi_test
)
local config = config_module.get_config()
local base_name = config.filename and config.filename(platform, contest_id, problem_id, config)
or config_module.default_filename(contest_id, problem_id)
vim.fn.mkdir('io', 'p')
for i, tc in ipairs(edit_state.test_cases) do
local input_file = string.format('io/%s.%d.cpin', base_name, i)
local expected_file = string.format('io/%s.%d.cpout', base_name, i)
local input_content = (tc.input or ''):gsub('\r', '')
local expected_content = (tc.expected or ''):gsub('\r', '')
vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file)
vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file)
end
logger.log('Saved all test cases')
end end
function M.toggle_edit(test_index) function M.toggle_edit(test_index)
if edit_state then if edit_state then
save_all_tests() save_all_tests()
for _, pair in ipairs(edit_state.test_buffers) do
if vim.api.nvim_buf_is_valid(pair.input_buf) then
vim.api.nvim_buf_delete(pair.input_buf, { force = true })
end
if vim.api.nvim_buf_is_valid(pair.expected_buf) then
vim.api.nvim_buf_delete(pair.expected_buf, { force = true })
end
end
edit_state = nil edit_state = nil
pcall(vim.api.nvim_clear_autocmds, { group = 'cp_edit_guard' }) pcall(vim.api.nvim_clear_autocmds, { group = 'cp_edit_guard' })
@ -429,7 +411,7 @@ function M.toggle_edit(test_index)
vim.api.nvim_win_set_buf(input_win, input_buf) vim.api.nvim_win_set_buf(input_win, input_buf)
vim.bo[input_buf].modifiable = true vim.bo[input_buf].modifiable = true
vim.bo[input_buf].readonly = false vim.bo[input_buf].readonly = false
vim.bo[input_buf].buftype = 'acwrite' vim.bo[input_buf].buftype = 'nofile'
vim.bo[input_buf].buflisted = false vim.bo[input_buf].buflisted = false
helpers.clearcol(input_buf) helpers.clearcol(input_buf)
@ -439,7 +421,7 @@ function M.toggle_edit(test_index)
vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.api.nvim_win_set_buf(expected_win, expected_buf)
vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].modifiable = true
vim.bo[expected_buf].readonly = false vim.bo[expected_buf].readonly = false
vim.bo[expected_buf].buftype = 'acwrite' vim.bo[expected_buf].buftype = 'nofile'
vim.bo[expected_buf].buflisted = false vim.bo[expected_buf].buflisted = false
helpers.clearcol(expected_buf) helpers.clearcol(expected_buf)

View file

@ -2,7 +2,6 @@
import asyncio import asyncio
import json import json
import os
import re import re
import sys import sys
import time import time
@ -16,10 +15,16 @@ from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry from urllib3.util.retry import Retry
from .base import BaseScraper, extract_precision from .base import BaseScraper, extract_precision
from .language_ids import get_language_id from .models import (
from .models import (CombinedTest, ContestListResult, ContestSummary, CombinedTest,
MetadataResult, ProblemSummary, SubmitResult, TestCase, ContestListResult,
TestsResult) ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
TestsResult,
)
MIB_TO_MB = 1.048576 MIB_TO_MB = 1.048576
BASE_URL = "https://atcoder.jp" BASE_URL = "https://atcoder.jp"
@ -373,12 +378,10 @@ class AtcoderScraper(BaseScraper):
credentials: dict[str, str], credentials: dict[str, str],
) -> SubmitResult: ) -> SubmitResult:
def _submit_sync() -> SubmitResult: def _submit_sync() -> SubmitResult:
from curl_cffi import requests as curl_requests
try: try:
session = curl_requests.Session(impersonate="chrome") login_page = _session.get(
f"{BASE_URL}/login", headers=HEADERS, timeout=TIMEOUT_SECONDS
login_page = session.get(f"{BASE_URL}/login", timeout=TIMEOUT_SECONDS) )
login_page.raise_for_status() login_page.raise_for_status()
soup = BeautifulSoup(login_page.text, "html.parser") soup = BeautifulSoup(login_page.text, "html.parser")
csrf_input = soup.find("input", {"name": "csrf_token"}) csrf_input = soup.find("input", {"name": "csrf_token"})
@ -388,29 +391,21 @@ class AtcoderScraper(BaseScraper):
) )
csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr] csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr]
login_resp = session.post( login_resp = _session.post(
f"{BASE_URL}/login", f"{BASE_URL}/login",
data={ data={
"username": credentials.get("username", ""), "username": credentials.get("username", ""),
"password": credentials.get("password", ""), "password": credentials.get("password", ""),
"csrf_token": csrf_token, "csrf_token": csrf_token,
}, },
headers=HEADERS,
timeout=TIMEOUT_SECONDS, timeout=TIMEOUT_SECONDS,
allow_redirects=False,
) )
if login_resp.status_code in (301, 302):
location = login_resp.headers.get("Location", "")
if "/login" in location:
return SubmitResult(
success=False,
error="Login failed: incorrect username or password",
)
session.get(BASE_URL + location, timeout=TIMEOUT_SECONDS)
else:
login_resp.raise_for_status() login_resp.raise_for_status()
submit_page = session.get( submit_page = _session.get(
f"{BASE_URL}/contests/{contest_id}/submit", f"{BASE_URL}/contests/{contest_id}/submit",
headers=HEADERS,
timeout=TIMEOUT_SECONDS, timeout=TIMEOUT_SECONDS,
) )
submit_page.raise_for_status() submit_page.raise_for_status()
@ -423,7 +418,7 @@ class AtcoderScraper(BaseScraper):
csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr] csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr]
task_screen_name = f"{contest_id}_{problem_id}" task_screen_name = f"{contest_id}_{problem_id}"
submit_resp = session.post( submit_resp = _session.post(
f"{BASE_URL}/contests/{contest_id}/submit", f"{BASE_URL}/contests/{contest_id}/submit",
data={ data={
"data.TaskScreenName": task_screen_name, "data.TaskScreenName": task_screen_name,
@ -431,26 +426,13 @@ class AtcoderScraper(BaseScraper):
"sourceCode": source_code, "sourceCode": source_code,
"csrf_token": csrf_token, "csrf_token": csrf_token,
}, },
headers=HEADERS,
timeout=TIMEOUT_SECONDS, timeout=TIMEOUT_SECONDS,
allow_redirects=False,
)
if submit_resp.status_code in (301, 302):
location = submit_resp.headers.get("Location", "")
if "/submissions/me" in location:
return SubmitResult(
success=True,
error="",
submission_id="",
verdict="submitted",
)
return SubmitResult(
success=False,
error=f"Submit may have failed: redirected to {location}",
) )
submit_resp.raise_for_status() submit_resp.raise_for_status()
return SubmitResult( return SubmitResult(
success=False, success=True, error="", submission_id="", verdict="submitted"
error="Unexpected response from submit (expected redirect)",
) )
except Exception as e: except Exception as e:
return SubmitResult(success=False, error=str(e)) return SubmitResult(success=False, error=str(e))
@ -513,31 +495,9 @@ async def main_async() -> int:
print(contest_result.model_dump_json()) print(contest_result.model_dump_json())
return 0 if contest_result.success else 1 return 0 if contest_result.success else 1
if mode == "submit":
if len(sys.argv) != 5:
print(
SubmitResult(
success=False,
error="Usage: atcoder.py submit <contest_id> <problem_id> <language>",
).model_dump_json()
)
return 1
source_code = sys.stdin.read()
creds_raw = os.environ.get("CP_CREDENTIALS", "{}")
try:
credentials = json.loads(creds_raw)
except json.JSONDecodeError:
credentials = {}
language_id = get_language_id("atcoder", sys.argv[4]) or sys.argv[4]
submit_result = await scraper.submit(
sys.argv[2], sys.argv[3], source_code, language_id, credentials
)
print(submit_result.model_dump_json())
return 0 if submit_result.success else 1
result = MetadataResult( result = MetadataResult(
success=False, success=False,
error="Unknown mode. Use 'metadata <contest_id>', 'tests <contest_id>', 'contests', or 'submit <contest_id> <problem_id> <language>'", error="Unknown mode. Use 'metadata <contest_id>', 'tests <contest_id>', or 'contests'",
url="", url="",
) )
print(result.model_dump_json()) print(result.model_dump_json())

View file

@ -6,8 +6,13 @@ import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from .language_ids import get_language_id from .language_ids import get_language_id
from .models import (CombinedTest, ContestListResult, MetadataResult, from .models import (
SubmitResult, TestsResult) CombinedTest,
ContestListResult,
MetadataResult,
SubmitResult,
TestsResult,
)
_PRECISION_ABS_REL_RE = re.compile( _PRECISION_ABS_REL_RE = re.compile(
r"(?:absolute|relative)\s+error[^.]*?10\s*[\^{]\s*\{?\s*[-\u2212]\s*(\d+)\s*\}?", r"(?:absolute|relative)\s+error[^.]*?10\s*[\^{]\s*\{?\s*[-\u2212]\s*(\d+)\s*\}?",

View file

@ -9,8 +9,14 @@ import httpx
from curl_cffi import requests as curl_requests from curl_cffi import requests as curl_requests
from .base import BaseScraper, extract_precision from .base import BaseScraper, extract_precision
from .models import (ContestListResult, ContestSummary, MetadataResult, from .models import (
ProblemSummary, SubmitResult, TestCase) ContestListResult,
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://www.codechef.com" BASE_URL = "https://www.codechef.com"
API_CONTESTS_ALL = "/api/list/contests/all" API_CONTESTS_ALL = "/api/list/contests/all"

View file

@ -10,8 +10,14 @@ from bs4 import BeautifulSoup, Tag
from curl_cffi import requests as curl_requests from curl_cffi import requests as curl_requests
from .base import BaseScraper, extract_precision from .base import BaseScraper, extract_precision
from .models import (ContestListResult, ContestSummary, MetadataResult, from .models import (
ProblemSummary, SubmitResult, TestCase) ContestListResult,
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://codeforces.com" BASE_URL = "https://codeforces.com"
API_CONTEST_LIST_URL = f"{BASE_URL}/api/contest.list" API_CONTEST_LIST_URL = f"{BASE_URL}/api/contest.list"

View file

@ -8,8 +8,14 @@ from typing import Any
import httpx import httpx
from .base import BaseScraper, extract_precision from .base import BaseScraper, extract_precision
from .models import (ContestListResult, ContestSummary, MetadataResult, from .models import (
ProblemSummary, SubmitResult, TestCase) ContestListResult,
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://cses.fi" BASE_URL = "https://cses.fi"
INDEX_PATH = "/problemset" INDEX_PATH = "/problemset"

View file

@ -10,8 +10,14 @@ from datetime import datetime
import httpx import httpx
from .base import BaseScraper from .base import BaseScraper
from .models import (ContestListResult, ContestSummary, MetadataResult, from .models import (
ProblemSummary, SubmitResult, TestCase) ContestListResult,
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://open.kattis.com" BASE_URL = "https://open.kattis.com"
HEADERS = { HEADERS = {

View file

@ -8,8 +8,14 @@ from typing import Any, cast
import httpx import httpx
from .base import BaseScraper from .base import BaseScraper
from .models import (ContestListResult, ContestSummary, MetadataResult, from .models import (
ProblemSummary, SubmitResult, TestCase) ContestListResult,
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "http://www.usaco.org" BASE_URL = "http://www.usaco.org"
HEADERS = { HEADERS = {
@ -31,7 +37,8 @@ DIVISION_HEADING_RE = re.compile(
re.IGNORECASE, re.IGNORECASE,
) )
PROBLEM_BLOCK_RE = re.compile( PROBLEM_BLOCK_RE = re.compile(
r"<b>([^<]+)</b>\s*<br\s*/?>.*?" r"viewproblem2&cpid=(\d+)", r"<b>([^<]+)</b>\s*<br\s*/?>.*?"
r"viewproblem2&cpid=(\d+)",
re.DOTALL, re.DOTALL,
) )
SAMPLE_IN_RE = re.compile(r"<pre\s+class=['\"]in['\"]>(.*?)</pre>", re.DOTALL) SAMPLE_IN_RE = re.compile(r"<pre\s+class=['\"]in['\"]>(.*?)</pre>", re.DOTALL)

View file

@ -1,6 +1,10 @@
import pytest import pytest
from scrapers.models import ContestListResult, MetadataResult, TestsResult from scrapers.models import (
ContestListResult,
MetadataResult,
TestsResult,
)
MATRIX = { MATRIX = {
"cses": { "cses": {
@ -57,9 +61,9 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode):
assert hasattr(tr.combined, "input"), "combined missing input" assert hasattr(tr.combined, "input"), "combined missing input"
assert hasattr(tr.combined, "expected"), "combined missing expected" assert hasattr(tr.combined, "expected"), "combined missing expected"
assert isinstance(tr.combined.input, str), "combined.input not string" assert isinstance(tr.combined.input, str), "combined.input not string"
assert isinstance( assert isinstance(tr.combined.expected, str), (
tr.combined.expected, str "combined.expected not string"
), "combined.expected not string" )
assert hasattr(tr, "multi_test"), "Missing multi_test field" assert hasattr(tr, "multi_test"), "Missing multi_test field"
assert isinstance(tr.multi_test, bool), "multi_test not boolean" assert isinstance(tr.multi_test, bool), "multi_test not boolean"
validated_any = True validated_any = True
@ -73,12 +77,12 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode):
assert isinstance(obj["combined"], dict), "combined not a dict" assert isinstance(obj["combined"], dict), "combined not a dict"
assert "input" in obj["combined"], "combined missing input key" assert "input" in obj["combined"], "combined missing input key"
assert "expected" in obj["combined"], "combined missing expected key" assert "expected" in obj["combined"], "combined missing expected key"
assert isinstance( assert isinstance(obj["combined"]["input"], str), (
obj["combined"]["input"], str "combined.input not string"
), "combined.input not string" )
assert isinstance( assert isinstance(obj["combined"]["expected"], str), (
obj["combined"]["expected"], str "combined.expected not string"
), "combined.expected not string" )
assert "multi_test" in obj, "Missing multi_test field in raw JSON" assert "multi_test" in obj, "Missing multi_test field in raw JSON"
assert isinstance(obj["multi_test"], bool), "multi_test not boolean" assert isinstance(obj["multi_test"], bool), "multi_test not boolean"
validated_any = True validated_any = True