Compare commits

...

4 commits

Author SHA1 Message Date
ff03b932b1
ci: format 2026-03-04 12:46:47 -05:00
a37e7f2e4a
misc 2026-03-04 12:42:51 -05:00
1033b5e478
ci: format 2026-03-04 00:50:10 -05:00
b5b86ffc6e
fix(edit): clean up buffers on close and support :w to save
Problem: closing the test editor left cp://test-N-* buffers alive,
causing E95 on reopen. The nofile buftype also rejected :w, which
was counterintuitive in an editable grid.

Solution: delete all test buffers in toggle_edit teardown. Switch
buftype to acwrite with a BufWriteCmd autocmd that persists test
cases and clears the modified flag. Hoist save_all_tests above
setup_keybindings so the autocmd closure can reference it.
2026-03-04 00:43:33 -05:00
10 changed files with 190 additions and 172 deletions

View file

@ -5,19 +5,19 @@ local logger = require('cp.log')
local utils = require('cp.utils') local utils = require('cp.utils')
local function syshandle(result) local function syshandle(result)
local ok, data = pcall(vim.json.decode, result.stdout or '')
if ok then
return { success = true, data = data }
end
if result.code ~= 0 then if result.code ~= 0 then
local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error') local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error')
return { success = false, error = msg } return { success = false, error = msg }
end end
local ok, data = pcall(vim.json.decode, result.stdout) local msg = 'Failed to parse scraper output: ' .. tostring(data)
if not ok then logger.log(msg, vim.log.levels.ERROR)
local msg = 'Failed to parse scraper output: ' .. tostring(data) return { success = false, error = msg }
logger.log(msg, vim.log.levels.ERROR)
return { success = false, error = msg }
end
return { success = true, data = data }
end end
---@param env_map table<string, string> ---@param env_map table<string, string>

View file

@ -144,7 +144,7 @@ local function add_new_test()
vim.api.nvim_win_set_buf(input_win, input_buf) vim.api.nvim_win_set_buf(input_win, input_buf)
vim.bo[input_buf].modifiable = true vim.bo[input_buf].modifiable = true
vim.bo[input_buf].readonly = false vim.bo[input_buf].readonly = false
vim.bo[input_buf].buftype = 'nofile' vim.bo[input_buf].buftype = 'acwrite'
vim.bo[input_buf].buflisted = false vim.bo[input_buf].buflisted = false
helpers.clearcol(input_buf) helpers.clearcol(input_buf)
@ -155,7 +155,7 @@ local function add_new_test()
vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.api.nvim_win_set_buf(expected_win, expected_buf)
vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].modifiable = true
vim.bo[expected_buf].readonly = false vim.bo[expected_buf].readonly = false
vim.bo[expected_buf].buftype = 'nofile' vim.bo[expected_buf].buftype = 'acwrite'
vim.bo[expected_buf].buflisted = false vim.bo[expected_buf].buflisted = false
helpers.clearcol(expected_buf) helpers.clearcol(expected_buf)
@ -177,6 +177,80 @@ local function add_new_test()
logger.log(('Added test %d'):format(new_index)) logger.log(('Added test %d'):format(new_index))
end end
local function save_all_tests()
if not edit_state then
return
end
local platform = state.get_platform()
local contest_id = state.get_contest_id()
local problem_id = state.get_problem_id()
if not platform or not contest_id or not problem_id then
return
end
for i, pair in ipairs(edit_state.test_buffers) do
if
vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf)
then
local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false)
local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false)
edit_state.test_cases[i].input = table.concat(input_lines, '\n')
edit_state.test_cases[i].expected = table.concat(expected_lines, '\n')
end
end
local contest_data = cache.get_contest_data(platform, contest_id)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
or false
local combined_input = table.concat(
vim.tbl_map(function(tc)
return tc.input
end, edit_state.test_cases),
'\n'
)
local combined_expected = table.concat(
vim.tbl_map(function(tc)
return tc.expected
end, edit_state.test_cases),
'\n'
)
cache.set_test_cases(
platform,
contest_id,
problem_id,
{ input = combined_input, expected = combined_expected },
edit_state.test_cases,
edit_state.constraints and edit_state.constraints.timeout_ms or 0,
edit_state.constraints and edit_state.constraints.memory_mb or 0,
false,
is_multi_test
)
local config = config_module.get_config()
local base_name = config.filename and config.filename(platform, contest_id, problem_id, config)
or config_module.default_filename(contest_id, problem_id)
vim.fn.mkdir('io', 'p')
for i, tc in ipairs(edit_state.test_cases) do
local input_file = string.format('io/%s.%d.cpin', base_name, i)
local expected_file = string.format('io/%s.%d.cpout', base_name, i)
local input_content = (tc.input or ''):gsub('\r', '')
local expected_content = (tc.expected or ''):gsub('\r', '')
vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file)
vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file)
end
logger.log('Saved all test cases')
end
---@param buf integer ---@param buf integer
setup_keybindings = function(buf) setup_keybindings = function(buf)
local config = config_module.get_config() local config = config_module.get_config()
@ -243,86 +317,30 @@ setup_keybindings = function(buf)
end) end)
end, end,
}) })
end
local function save_all_tests() vim.api.nvim_create_autocmd('BufWriteCmd', {
if not edit_state then group = augroup,
return buffer = buf,
end callback = function()
save_all_tests()
local platform = state.get_platform() vim.bo[buf].modified = false
local contest_id = state.get_contest_id() end,
local problem_id = state.get_problem_id() })
if not platform or not contest_id or not problem_id then
return
end
for i, pair in ipairs(edit_state.test_buffers) do
if
vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf)
then
local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false)
local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false)
edit_state.test_cases[i].input = table.concat(input_lines, '\n')
edit_state.test_cases[i].expected = table.concat(expected_lines, '\n')
end
end
local contest_data = cache.get_contest_data(platform, contest_id)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
or false
-- Generate combined test from individual test cases
local combined_input = table.concat(
vim.tbl_map(function(tc)
return tc.input
end, edit_state.test_cases),
'\n'
)
local combined_expected = table.concat(
vim.tbl_map(function(tc)
return tc.expected
end, edit_state.test_cases),
'\n'
)
cache.set_test_cases(
platform,
contest_id,
problem_id,
{ input = combined_input, expected = combined_expected },
edit_state.test_cases,
edit_state.constraints and edit_state.constraints.timeout_ms or 0,
edit_state.constraints and edit_state.constraints.memory_mb or 0,
false,
is_multi_test
)
local config = config_module.get_config()
local base_name = config.filename and config.filename(platform, contest_id, problem_id, config)
or config_module.default_filename(contest_id, problem_id)
vim.fn.mkdir('io', 'p')
for i, tc in ipairs(edit_state.test_cases) do
local input_file = string.format('io/%s.%d.cpin', base_name, i)
local expected_file = string.format('io/%s.%d.cpout', base_name, i)
local input_content = (tc.input or ''):gsub('\r', '')
local expected_content = (tc.expected or ''):gsub('\r', '')
vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file)
vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file)
end
logger.log('Saved all test cases')
end end
function M.toggle_edit(test_index) function M.toggle_edit(test_index)
if edit_state then if edit_state then
save_all_tests() save_all_tests()
for _, pair in ipairs(edit_state.test_buffers) do
if vim.api.nvim_buf_is_valid(pair.input_buf) then
vim.api.nvim_buf_delete(pair.input_buf, { force = true })
end
if vim.api.nvim_buf_is_valid(pair.expected_buf) then
vim.api.nvim_buf_delete(pair.expected_buf, { force = true })
end
end
edit_state = nil edit_state = nil
pcall(vim.api.nvim_clear_autocmds, { group = 'cp_edit_guard' }) pcall(vim.api.nvim_clear_autocmds, { group = 'cp_edit_guard' })
@ -411,7 +429,7 @@ function M.toggle_edit(test_index)
vim.api.nvim_win_set_buf(input_win, input_buf) vim.api.nvim_win_set_buf(input_win, input_buf)
vim.bo[input_buf].modifiable = true vim.bo[input_buf].modifiable = true
vim.bo[input_buf].readonly = false vim.bo[input_buf].readonly = false
vim.bo[input_buf].buftype = 'nofile' vim.bo[input_buf].buftype = 'acwrite'
vim.bo[input_buf].buflisted = false vim.bo[input_buf].buflisted = false
helpers.clearcol(input_buf) helpers.clearcol(input_buf)
@ -421,7 +439,7 @@ function M.toggle_edit(test_index)
vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.api.nvim_win_set_buf(expected_win, expected_buf)
vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].modifiable = true
vim.bo[expected_buf].readonly = false vim.bo[expected_buf].readonly = false
vim.bo[expected_buf].buftype = 'nofile' vim.bo[expected_buf].buftype = 'acwrite'
vim.bo[expected_buf].buflisted = false vim.bo[expected_buf].buflisted = false
helpers.clearcol(expected_buf) helpers.clearcol(expected_buf)

View file

@ -2,6 +2,7 @@
import asyncio import asyncio
import json import json
import os
import re import re
import sys import sys
import time import time
@ -15,16 +16,10 @@ from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry from urllib3.util.retry import Retry
from .base import BaseScraper, extract_precision from .base import BaseScraper, extract_precision
from .models import ( from .language_ids import get_language_id
CombinedTest, from .models import (CombinedTest, ContestListResult, ContestSummary,
ContestListResult, MetadataResult, ProblemSummary, SubmitResult, TestCase,
ContestSummary, TestsResult)
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
TestsResult,
)
MIB_TO_MB = 1.048576 MIB_TO_MB = 1.048576
BASE_URL = "https://atcoder.jp" BASE_URL = "https://atcoder.jp"
@ -378,10 +373,12 @@ class AtcoderScraper(BaseScraper):
credentials: dict[str, str], credentials: dict[str, str],
) -> SubmitResult: ) -> SubmitResult:
def _submit_sync() -> SubmitResult: def _submit_sync() -> SubmitResult:
from curl_cffi import requests as curl_requests
try: try:
login_page = _session.get( session = curl_requests.Session(impersonate="chrome")
f"{BASE_URL}/login", headers=HEADERS, timeout=TIMEOUT_SECONDS
) login_page = session.get(f"{BASE_URL}/login", timeout=TIMEOUT_SECONDS)
login_page.raise_for_status() login_page.raise_for_status()
soup = BeautifulSoup(login_page.text, "html.parser") soup = BeautifulSoup(login_page.text, "html.parser")
csrf_input = soup.find("input", {"name": "csrf_token"}) csrf_input = soup.find("input", {"name": "csrf_token"})
@ -391,21 +388,29 @@ class AtcoderScraper(BaseScraper):
) )
csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr] csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr]
login_resp = _session.post( login_resp = session.post(
f"{BASE_URL}/login", f"{BASE_URL}/login",
data={ data={
"username": credentials.get("username", ""), "username": credentials.get("username", ""),
"password": credentials.get("password", ""), "password": credentials.get("password", ""),
"csrf_token": csrf_token, "csrf_token": csrf_token,
}, },
headers=HEADERS,
timeout=TIMEOUT_SECONDS, timeout=TIMEOUT_SECONDS,
allow_redirects=False,
) )
login_resp.raise_for_status() if login_resp.status_code in (301, 302):
location = login_resp.headers.get("Location", "")
if "/login" in location:
return SubmitResult(
success=False,
error="Login failed: incorrect username or password",
)
session.get(BASE_URL + location, timeout=TIMEOUT_SECONDS)
else:
login_resp.raise_for_status()
submit_page = _session.get( submit_page = session.get(
f"{BASE_URL}/contests/{contest_id}/submit", f"{BASE_URL}/contests/{contest_id}/submit",
headers=HEADERS,
timeout=TIMEOUT_SECONDS, timeout=TIMEOUT_SECONDS,
) )
submit_page.raise_for_status() submit_page.raise_for_status()
@ -418,7 +423,7 @@ class AtcoderScraper(BaseScraper):
csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr] csrf_token = csrf_input.get("value", "") or "" # type: ignore[union-attr]
task_screen_name = f"{contest_id}_{problem_id}" task_screen_name = f"{contest_id}_{problem_id}"
submit_resp = _session.post( submit_resp = session.post(
f"{BASE_URL}/contests/{contest_id}/submit", f"{BASE_URL}/contests/{contest_id}/submit",
data={ data={
"data.TaskScreenName": task_screen_name, "data.TaskScreenName": task_screen_name,
@ -426,13 +431,26 @@ class AtcoderScraper(BaseScraper):
"sourceCode": source_code, "sourceCode": source_code,
"csrf_token": csrf_token, "csrf_token": csrf_token,
}, },
headers=HEADERS,
timeout=TIMEOUT_SECONDS, timeout=TIMEOUT_SECONDS,
allow_redirects=False,
) )
if submit_resp.status_code in (301, 302):
location = submit_resp.headers.get("Location", "")
if "/submissions/me" in location:
return SubmitResult(
success=True,
error="",
submission_id="",
verdict="submitted",
)
return SubmitResult(
success=False,
error=f"Submit may have failed: redirected to {location}",
)
submit_resp.raise_for_status() submit_resp.raise_for_status()
return SubmitResult( return SubmitResult(
success=True, error="", submission_id="", verdict="submitted" success=False,
error="Unexpected response from submit (expected redirect)",
) )
except Exception as e: except Exception as e:
return SubmitResult(success=False, error=str(e)) return SubmitResult(success=False, error=str(e))
@ -495,9 +513,31 @@ async def main_async() -> int:
print(contest_result.model_dump_json()) print(contest_result.model_dump_json())
return 0 if contest_result.success else 1 return 0 if contest_result.success else 1
if mode == "submit":
if len(sys.argv) != 5:
print(
SubmitResult(
success=False,
error="Usage: atcoder.py submit <contest_id> <problem_id> <language>",
).model_dump_json()
)
return 1
source_code = sys.stdin.read()
creds_raw = os.environ.get("CP_CREDENTIALS", "{}")
try:
credentials = json.loads(creds_raw)
except json.JSONDecodeError:
credentials = {}
language_id = get_language_id("atcoder", sys.argv[4]) or sys.argv[4]
submit_result = await scraper.submit(
sys.argv[2], sys.argv[3], source_code, language_id, credentials
)
print(submit_result.model_dump_json())
return 0 if submit_result.success else 1
result = MetadataResult( result = MetadataResult(
success=False, success=False,
error="Unknown mode. Use 'metadata <contest_id>', 'tests <contest_id>', or 'contests'", error="Unknown mode. Use 'metadata <contest_id>', 'tests <contest_id>', 'contests', or 'submit <contest_id> <problem_id> <language>'",
url="", url="",
) )
print(result.model_dump_json()) print(result.model_dump_json())

View file

@ -6,13 +6,8 @@ import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from .language_ids import get_language_id from .language_ids import get_language_id
from .models import ( from .models import (CombinedTest, ContestListResult, MetadataResult,
CombinedTest, SubmitResult, TestsResult)
ContestListResult,
MetadataResult,
SubmitResult,
TestsResult,
)
_PRECISION_ABS_REL_RE = re.compile( _PRECISION_ABS_REL_RE = re.compile(
r"(?:absolute|relative)\s+error[^.]*?10\s*[\^{]\s*\{?\s*[-\u2212]\s*(\d+)\s*\}?", r"(?:absolute|relative)\s+error[^.]*?10\s*[\^{]\s*\{?\s*[-\u2212]\s*(\d+)\s*\}?",

View file

@ -9,14 +9,8 @@ import httpx
from curl_cffi import requests as curl_requests from curl_cffi import requests as curl_requests
from .base import BaseScraper, extract_precision from .base import BaseScraper, extract_precision
from .models import ( from .models import (ContestListResult, ContestSummary, MetadataResult,
ContestListResult, ProblemSummary, SubmitResult, TestCase)
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://www.codechef.com" BASE_URL = "https://www.codechef.com"
API_CONTESTS_ALL = "/api/list/contests/all" API_CONTESTS_ALL = "/api/list/contests/all"

View file

@ -10,14 +10,8 @@ from bs4 import BeautifulSoup, Tag
from curl_cffi import requests as curl_requests from curl_cffi import requests as curl_requests
from .base import BaseScraper, extract_precision from .base import BaseScraper, extract_precision
from .models import ( from .models import (ContestListResult, ContestSummary, MetadataResult,
ContestListResult, ProblemSummary, SubmitResult, TestCase)
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://codeforces.com" BASE_URL = "https://codeforces.com"
API_CONTEST_LIST_URL = f"{BASE_URL}/api/contest.list" API_CONTEST_LIST_URL = f"{BASE_URL}/api/contest.list"

View file

@ -8,14 +8,8 @@ from typing import Any
import httpx import httpx
from .base import BaseScraper, extract_precision from .base import BaseScraper, extract_precision
from .models import ( from .models import (ContestListResult, ContestSummary, MetadataResult,
ContestListResult, ProblemSummary, SubmitResult, TestCase)
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://cses.fi" BASE_URL = "https://cses.fi"
INDEX_PATH = "/problemset" INDEX_PATH = "/problemset"

View file

@ -10,14 +10,8 @@ from datetime import datetime
import httpx import httpx
from .base import BaseScraper from .base import BaseScraper
from .models import ( from .models import (ContestListResult, ContestSummary, MetadataResult,
ContestListResult, ProblemSummary, SubmitResult, TestCase)
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://open.kattis.com" BASE_URL = "https://open.kattis.com"
HEADERS = { HEADERS = {

View file

@ -8,14 +8,8 @@ from typing import Any, cast
import httpx import httpx
from .base import BaseScraper from .base import BaseScraper
from .models import ( from .models import (ContestListResult, ContestSummary, MetadataResult,
ContestListResult, ProblemSummary, SubmitResult, TestCase)
ContestSummary,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "http://www.usaco.org" BASE_URL = "http://www.usaco.org"
HEADERS = { HEADERS = {
@ -37,8 +31,7 @@ DIVISION_HEADING_RE = re.compile(
re.IGNORECASE, re.IGNORECASE,
) )
PROBLEM_BLOCK_RE = re.compile( PROBLEM_BLOCK_RE = re.compile(
r"<b>([^<]+)</b>\s*<br\s*/?>.*?" r"<b>([^<]+)</b>\s*<br\s*/?>.*?" r"viewproblem2&cpid=(\d+)",
r"viewproblem2&cpid=(\d+)",
re.DOTALL, re.DOTALL,
) )
SAMPLE_IN_RE = re.compile(r"<pre\s+class=['\"]in['\"]>(.*?)</pre>", re.DOTALL) SAMPLE_IN_RE = re.compile(r"<pre\s+class=['\"]in['\"]>(.*?)</pre>", re.DOTALL)

View file

@ -1,10 +1,6 @@
import pytest import pytest
from scrapers.models import ( from scrapers.models import ContestListResult, MetadataResult, TestsResult
ContestListResult,
MetadataResult,
TestsResult,
)
MATRIX = { MATRIX = {
"cses": { "cses": {
@ -61,9 +57,9 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode):
assert hasattr(tr.combined, "input"), "combined missing input" assert hasattr(tr.combined, "input"), "combined missing input"
assert hasattr(tr.combined, "expected"), "combined missing expected" assert hasattr(tr.combined, "expected"), "combined missing expected"
assert isinstance(tr.combined.input, str), "combined.input not string" assert isinstance(tr.combined.input, str), "combined.input not string"
assert isinstance(tr.combined.expected, str), ( assert isinstance(
"combined.expected not string" tr.combined.expected, str
) ), "combined.expected not string"
assert hasattr(tr, "multi_test"), "Missing multi_test field" assert hasattr(tr, "multi_test"), "Missing multi_test field"
assert isinstance(tr.multi_test, bool), "multi_test not boolean" assert isinstance(tr.multi_test, bool), "multi_test not boolean"
validated_any = True validated_any = True
@ -77,12 +73,12 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode):
assert isinstance(obj["combined"], dict), "combined not a dict" assert isinstance(obj["combined"], dict), "combined not a dict"
assert "input" in obj["combined"], "combined missing input key" assert "input" in obj["combined"], "combined missing input key"
assert "expected" in obj["combined"], "combined missing expected key" assert "expected" in obj["combined"], "combined missing expected key"
assert isinstance(obj["combined"]["input"], str), ( assert isinstance(
"combined.input not string" obj["combined"]["input"], str
) ), "combined.input not string"
assert isinstance(obj["combined"]["expected"], str), ( assert isinstance(
"combined.expected not string" obj["combined"]["expected"], str
) ), "combined.expected not string"
assert "multi_test" in obj, "Missing multi_test field in raw JSON" assert "multi_test" in obj, "Missing multi_test field in raw JSON"
assert isinstance(obj["multi_test"], bool), "multi_test not boolean" assert isinstance(obj["multi_test"], bool), "multi_test not boolean"
validated_any = True validated_any = True