Compare commits

..

No commits in common. "feat/codechef" and "chore/add-issue-templates" have entirely different histories.

70 changed files with 2110 additions and 6568 deletions

3
.envrc Normal file
View file

@ -0,0 +1,3 @@
VIRTUAL_ENV="$PWD/.venv"
PATH_add "$VIRTUAL_ENV/bin"
export VIRTUAL_ENV

112
.github/workflows/ci.yaml vendored Normal file
View file

@ -0,0 +1,112 @@
name: ci
on:
workflow_call:
pull_request:
branches: [main]
push:
branches: [main]
jobs:
changes:
runs-on: ubuntu-latest
outputs:
lua: ${{ steps.changes.outputs.lua }}
python: ${{ steps.changes.outputs.python }}
steps:
- uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: changes
with:
filters: |
lua:
- 'lua/**'
- 'spec/**'
- 'plugin/**'
- 'after/**'
- 'ftdetect/**'
- '*.lua'
- '.luarc.json'
- 'stylua.toml'
- 'selene.toml'
python:
- 'scripts/**'
- 'scrapers/**'
- 'tests/**'
- 'pyproject.toml'
- 'uv.lock'
lua-format:
runs-on: ubuntu-latest
needs: changes
if: ${{ needs.changes.outputs.lua == 'true' }}
steps:
- uses: actions/checkout@v4
- uses: JohnnyMorganz/stylua-action@v4
with:
token: ${{ secrets.GITHUB_TOKEN }}
version: 2.1.0
args: --check .
lua-lint:
runs-on: ubuntu-latest
needs: changes
if: ${{ needs.changes.outputs.lua == 'true' }}
steps:
- uses: actions/checkout@v4
- uses: NTBBloodbath/selene-action@v1.0.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
args: --display-style quiet .
lua-typecheck:
runs-on: ubuntu-latest
needs: changes
if: ${{ needs.changes.outputs.lua == 'true' }}
steps:
- uses: actions/checkout@v4
- uses: mrcjkb/lua-typecheck-action@v0
with:
checklevel: Warning
directories: lua
configpath: .luarc.json
python-format:
runs-on: ubuntu-latest
needs: changes
if: ${{ needs.changes.outputs.python == 'true' }}
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- run: uv tool install ruff
- run: ruff format --check .
python-lint:
runs-on: ubuntu-latest
needs: changes
if: ${{ needs.changes.outputs.python == 'true' }}
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- run: uv tool install ruff
- run: ruff check .
python-typecheck:
runs-on: ubuntu-latest
needs: changes
if: ${{ needs.changes.outputs.python == 'true' }}
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- run: uv sync --dev
- run: uvx ty check .
python-test:
runs-on: ubuntu-latest
needs: changes
if: ${{ needs.changes.outputs.python == 'true' }}
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- run: uv sync --dev
- run: uv run camoufox fetch
- run: uv run pytest tests/ -v

View file

@ -28,7 +28,6 @@ jobs:
- '*.lua' - '*.lua'
- '.luarc.json' - '.luarc.json'
- '*.toml' - '*.toml'
- 'vim.yaml'
python: python:
- 'scripts/**/.py' - 'scripts/**/.py'
- 'scrapers/**/*.py' - 'scrapers/**/*.py'
@ -46,8 +45,11 @@ jobs:
if: ${{ needs.changes.outputs.lua == 'true' }} if: ${{ needs.changes.outputs.lua == 'true' }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: cachix/install-nix-action@v31 - uses: JohnnyMorganz/stylua-action@v4
- run: nix develop --command stylua --check . with:
token: ${{ secrets.GITHUB_TOKEN }}
version: 2.1.0
args: --check .
lua-lint: lua-lint:
name: Lua Lint Check name: Lua Lint Check
@ -56,8 +58,11 @@ jobs:
if: ${{ needs.changes.outputs.lua == 'true' }} if: ${{ needs.changes.outputs.lua == 'true' }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: cachix/install-nix-action@v31 - name: Lint with Selene
- run: nix develop --command selene --display-style quiet . uses: NTBBloodbath/selene-action@v1.0.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
args: --display-style quiet .
lua-typecheck: lua-typecheck:
name: Lua Type Check name: Lua Type Check
@ -122,5 +127,15 @@ jobs:
if: ${{ needs.changes.outputs.markdown == 'true' }} if: ${{ needs.changes.outputs.markdown == 'true' }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: cachix/install-nix-action@v31 - name: Setup pnpm
- run: nix develop --command prettier --check . uses: pnpm/action-setup@v4
with:
version: 8
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Install prettier
run: pnpm add -g prettier@3.1.0
- name: Check markdown formatting with prettier
run: prettier --check .

View file

@ -44,7 +44,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Install uv - name: Install uv
uses: astral-sh/setup-uv@v4 uses: astral-sh/setup-uv@v4
- name: Install dependencies - name: Install dependencies with pytest
run: uv sync --dev run: uv sync --dev
- name: Fetch camoufox data
run: uv run camoufox fetch
- name: Run Python tests - name: Run Python tests
run: uv run pytest tests/ -v run: uv run pytest tests/ -v

3
.gitignore vendored
View file

@ -14,6 +14,3 @@ __pycache__
.claude/ .claude/
node_modules/ node_modules/
.envrc
.direnv/

View file

@ -1,21 +1,8 @@
{ {
"runtime": { "runtime.version": "Lua 5.1",
"version": "LuaJIT", "runtime.path": ["lua/?.lua", "lua/?/init.lua"],
"path": ["lua/?.lua", "lua/?/init.lua"] "diagnostics.globals": ["vim"],
}, "workspace.library": ["$VIMRUNTIME/lua", "${3rd}/luv/library"],
"diagnostics": { "workspace.checkThirdParty": false,
"globals": ["vim"] "completion.callSnippet": "Replace"
},
"workspace": {
"library": [
"$VIMRUNTIME/lua",
"${3rd}/luv/library",
"${3rd}/busted/library"
],
"checkThirdParty": false,
"ignoreDir": [".direnv"]
},
"completion": {
"callSnippet": "Replace"
}
} }

View file

@ -1 +0,0 @@
.direnv/

View file

@ -28,12 +28,11 @@ Install using your package manager of choice or via
luarocks install cp.nvim luarocks install cp.nvim
``` ```
## Dependencies ## Optional Dependencies
- [uv](https://docs.astral.sh/uv/) for problem scraping
- GNU [time](https://www.gnu.org/software/time/) and - GNU [time](https://www.gnu.org/software/time/) and
[timeout](https://www.gnu.org/software/coreutils/manual/html_node/timeout-invocation.html) [timeout](https://www.gnu.org/software/coreutils/manual/html_node/timeout-invocation.html)
- [uv](https://docs.astral.sh/uv/) or [nix](https://nixos.org/) for problem
scraping
## Quick Start ## Quick Start

File diff suppressed because it is too large Load diff

43
flake.lock generated
View file

@ -1,43 +0,0 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1771008912,
"narHash": "sha256-gf2AmWVTs8lEq7z/3ZAsgnZDhWIckkb+ZnAo5RzSxJg=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs",
"systems": "systems"
}
},
"systems": {
"locked": {
"lastModified": 1689347949,
"narHash": "sha256-12tWmuL2zgBgZkdoB6qXZsgJEH9LR3oUgpaQq2RbI80=",
"owner": "nix-systems",
"repo": "default-linux",
"rev": "31732fcf5e8fea42e59c2488ad31a0e651500f68",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default-linux",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

137
flake.nix
View file

@ -1,137 +0,0 @@
{
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
systems.url = "github:nix-systems/default-linux";
};
outputs =
{
self,
nixpkgs,
systems,
}:
let
eachSystem = nixpkgs.lib.genAttrs (import systems);
pkgsFor = system: nixpkgs.legacyPackages.${system};
mkPythonEnv =
pkgs:
pkgs.python312.withPackages (ps: [
ps.backoff
ps.beautifulsoup4
ps.httpx
ps.ndjson
ps.pydantic
ps.requests
]);
mkDevPythonEnv =
pkgs:
pkgs.python312.withPackages (ps: [
ps.backoff
ps.beautifulsoup4
ps.httpx
ps.ndjson
ps.pydantic
ps.requests
ps.pytest
ps.pytest-mock
]);
mkSubmitEnv =
pkgs:
pkgs.buildFHSEnv {
name = "cp-nvim-submit";
targetPkgs =
pkgs: with pkgs; [
uv
alsa-lib
at-spi2-atk
cairo
cups
dbus
fontconfig
freetype
gdk-pixbuf
glib
gtk3
libdrm
libxkbcommon
mesa
libGL
nspr
nss
pango
libx11
libxcomposite
libxdamage
libxext
libxfixes
libxrandr
libxcb
at-spi2-core
expat
libgbm
systemdLibs
zlib
];
runScript = "${pkgs.uv}/bin/uv";
};
mkPlugin =
pkgs:
let
pythonEnv = mkPythonEnv pkgs;
submitEnv = mkSubmitEnv pkgs;
in
pkgs.vimUtils.buildVimPlugin {
pname = "cp-nvim";
version = "0-unstable-${self.shortRev or self.dirtyShortRev or "dev"}";
src = self;
postPatch = ''
substituteInPlace lua/cp/utils.lua \
--replace-fail "local _nix_python = nil" \
"local _nix_python = '${pythonEnv.interpreter}'"
substituteInPlace lua/cp/utils.lua \
--replace-fail "local _nix_submit_cmd = nil" \
"local _nix_submit_cmd = '${submitEnv}/bin/cp-nvim-submit'"
'';
nvimSkipModule = [
"cp.pickers.telescope"
"cp.version"
];
passthru = { inherit pythonEnv submitEnv; };
meta.description = "Competitive programming plugin for Neovim";
};
in
{
overlays.default = final: prev: {
vimPlugins = prev.vimPlugins // {
cp-nvim = mkPlugin final;
};
};
packages = eachSystem (system: {
default = mkPlugin (pkgsFor system);
pythonEnv = mkPythonEnv (pkgsFor system);
submitEnv = mkSubmitEnv (pkgsFor system);
});
formatter = eachSystem (system: (pkgsFor system).nixfmt-tree);
devShells = eachSystem (system: {
default = (pkgsFor system).mkShell {
packages = with (pkgsFor system); [
uv
(mkDevPythonEnv (pkgsFor system))
prettier
ruff
stylua
selene
lua-language-server
ty
];
};
});
};
}

View file

@ -10,14 +10,11 @@
---@field name string ---@field name string
---@field display_name string ---@field display_name string
---@field url string ---@field url string
---@field contest_url string
---@field standings_url string
---@class ContestSummary ---@class ContestSummary
---@field display_name string ---@field display_name string
---@field name string ---@field name string
---@field id string ---@field id string
---@field start_time? integer
---@class CombinedTest ---@class CombinedTest
---@field input string ---@field input string
@ -30,7 +27,6 @@
---@field multi_test? boolean ---@field multi_test? boolean
---@field memory_mb? number ---@field memory_mb? number
---@field timeout_ms? number ---@field timeout_ms? number
---@field precision? number
---@field combined_test? CombinedTest ---@field combined_test? CombinedTest
---@field test_cases TestCase[] ---@field test_cases TestCase[]
@ -42,8 +38,7 @@
local M = {} local M = {}
local CACHE_VERSION = 2 local logger = require('cp.log')
local cache_file = vim.fn.stdpath('data') .. '/cp-nvim.json' local cache_file = vim.fn.stdpath('data') .. '/cp-nvim.json'
local cache_data = {} local cache_data = {}
local loaded = false local loaded = false
@ -69,30 +64,10 @@ function M.load()
end end
local ok, decoded = pcall(vim.json.decode, table.concat(content, '\n')) local ok, decoded = pcall(vim.json.decode, table.concat(content, '\n'))
if not ok then if ok then
cache_data = {}
M.save()
loaded = true
return
end
if decoded._version == 1 then
local old_creds = decoded._credentials
decoded._credentials = nil
if old_creds then
for platform, creds in pairs(old_creds) do
decoded[platform] = decoded[platform] or {}
decoded[platform]._credentials = creds
end
end
decoded._version = CACHE_VERSION
cache_data = decoded
M.save()
elseif decoded._version == CACHE_VERSION then
cache_data = decoded cache_data = decoded
else else
cache_data = {} logger.log('Could not decode json in cache file', vim.log.levels.ERROR)
M.save()
end end
loaded = true loaded = true
end end
@ -103,7 +78,6 @@ function M.save()
vim.schedule(function() vim.schedule(function()
vim.fn.mkdir(vim.fn.fnamemodify(cache_file, ':h'), 'p') vim.fn.mkdir(vim.fn.fnamemodify(cache_file, ':h'), 'p')
cache_data._version = CACHE_VERSION
local encoded = vim.json.encode(cache_data) local encoded = vim.json.encode(cache_data)
local lines = vim.split(encoded, '\n') local lines = vim.split(encoded, '\n')
vim.fn.writefile(lines, cache_file) vim.fn.writefile(lines, cache_file)
@ -138,9 +112,7 @@ function M.get_cached_contest_ids(platform)
local contest_ids = {} local contest_ids = {}
for contest_id, _ in pairs(cache_data[platform]) do for contest_id, _ in pairs(cache_data[platform]) do
if contest_id:sub(1, 1) ~= '_' then table.insert(contest_ids, contest_id)
table.insert(contest_ids, contest_id)
end
end end
table.sort(contest_ids) table.sort(contest_ids)
return contest_ids return contest_ids
@ -150,16 +122,12 @@ end
---@param contest_id string ---@param contest_id string
---@param problems Problem[] ---@param problems Problem[]
---@param url string ---@param url string
---@param contest_url string function M.set_contest_data(platform, contest_id, problems, url)
---@param standings_url string
function M.set_contest_data(platform, contest_id, problems, url, contest_url, standings_url)
vim.validate({ vim.validate({
platform = { platform, 'string' }, platform = { platform, 'string' },
contest_id = { contest_id, 'string' }, contest_id = { contest_id, 'string' },
problems = { problems, 'table' }, problems = { problems, 'table' },
url = { url, 'string' }, url = { url, 'string' },
contest_url = { contest_url, 'string' },
standings_url = { standings_url, 'string' },
}) })
cache_data[platform] = cache_data[platform] or {} cache_data[platform] = cache_data[platform] or {}
@ -171,8 +139,6 @@ function M.set_contest_data(platform, contest_id, problems, url, contest_url, st
problems = problems, problems = problems,
index_map = {}, index_map = {},
url = url, url = url,
contest_url = contest_url,
standings_url = standings_url,
} }
for i, p in ipairs(out.problems) do for i, p in ipairs(out.problems) do
out.index_map[p.id] = i out.index_map[p.id] = i
@ -182,25 +148,6 @@ function M.set_contest_data(platform, contest_id, problems, url, contest_url, st
M.save() M.save()
end end
---@param platform string?
---@param contest_id string?
---@param problem_id string?
---@return { problem: string|nil, contest: string|nil, standings: string|nil }|nil
function M.get_open_urls(platform, contest_id, problem_id)
if not platform or not contest_id then
return nil
end
if not cache_data[platform] or not cache_data[platform][contest_id] then
return nil
end
local cd = cache_data[platform][contest_id]
return {
problem = cd.url ~= '' and problem_id and string.format(cd.url, problem_id) or nil,
contest = cd.contest_url ~= '' and cd.contest_url or nil,
standings = cd.standings_url ~= '' and cd.standings_url or nil,
}
end
---@param platform string ---@param platform string
---@param contest_id string ---@param contest_id string
function M.clear_contest_data(platform, contest_id) function M.clear_contest_data(platform, contest_id)
@ -266,7 +213,6 @@ end
---@param memory_mb number ---@param memory_mb number
---@param interactive boolean ---@param interactive boolean
---@param multi_test boolean ---@param multi_test boolean
---@param precision number?
function M.set_test_cases( function M.set_test_cases(
platform, platform,
contest_id, contest_id,
@ -276,8 +222,7 @@ function M.set_test_cases(
timeout_ms, timeout_ms,
memory_mb, memory_mb,
interactive, interactive,
multi_test, multi_test
precision
) )
vim.validate({ vim.validate({
platform = { platform, 'string' }, platform = { platform, 'string' },
@ -289,7 +234,6 @@ function M.set_test_cases(
memory_mb = { memory_mb, { 'number', 'nil' }, true }, memory_mb = { memory_mb, { 'number', 'nil' }, true },
interactive = { interactive, { 'boolean', 'nil' }, true }, interactive = { interactive, { 'boolean', 'nil' }, true },
multi_test = { multi_test, { 'boolean', 'nil' }, true }, multi_test = { multi_test, { 'boolean', 'nil' }, true },
precision = { precision, { 'number', 'nil' }, true },
}) })
local index = cache_data[platform][contest_id].index_map[problem_id] local index = cache_data[platform][contest_id].index_map[problem_id]
@ -300,7 +244,6 @@ function M.set_test_cases(
cache_data[platform][contest_id].problems[index].memory_mb = memory_mb cache_data[platform][contest_id].problems[index].memory_mb = memory_mb
cache_data[platform][contest_id].problems[index].interactive = interactive cache_data[platform][contest_id].problems[index].interactive = interactive
cache_data[platform][contest_id].problems[index].multi_test = multi_test cache_data[platform][contest_id].problems[index].multi_test = multi_test
cache_data[platform][contest_id].problems[index].precision = precision
M.save() M.save()
end end
@ -322,34 +265,6 @@ function M.get_constraints(platform, contest_id, problem_id)
return problem_data.timeout_ms, problem_data.memory_mb return problem_data.timeout_ms, problem_data.memory_mb
end end
---@param platform string
---@param contest_id string
---@param problem_id? string
---@return number?
function M.get_precision(platform, contest_id, problem_id)
vim.validate({
platform = { platform, 'string' },
contest_id = { contest_id, 'string' },
problem_id = { problem_id, { 'string', 'nil' }, true },
})
if
not cache_data[platform]
or not cache_data[platform][contest_id]
or not cache_data[platform][contest_id].index_map
then
return nil
end
local index = cache_data[platform][contest_id].index_map[problem_id]
if not index then
return nil
end
local problem_data = cache_data[platform][contest_id].problems[index]
return problem_data and problem_data.precision or nil
end
---@param file_path string ---@param file_path string
---@return FileState|nil ---@return FileState|nil
function M.get_file_state(file_path) function M.get_file_state(file_path)
@ -380,95 +295,28 @@ end
function M.get_contest_summaries(platform) function M.get_contest_summaries(platform)
local contest_list = {} local contest_list = {}
for contest_id, contest_data in pairs(cache_data[platform] or {}) do for contest_id, contest_data in pairs(cache_data[platform] or {}) do
if type(contest_data) == 'table' and contest_id:sub(1, 1) ~= '_' then table.insert(contest_list, {
table.insert(contest_list, { id = contest_id,
id = contest_id, name = contest_data.name,
name = contest_data.name, display_name = contest_data.display_name,
display_name = contest_data.display_name, })
})
end
end end
return contest_list return contest_list
end end
---@param platform string ---@param platform string
---@param contests ContestSummary[] ---@param contests ContestSummary[]
---@param opts? { supports_countdown?: boolean } function M.set_contest_summaries(platform, contests)
function M.set_contest_summaries(platform, contests, opts)
cache_data[platform] = cache_data[platform] or {} cache_data[platform] = cache_data[platform] or {}
for _, contest in ipairs(contests) do for _, contest in ipairs(contests) do
cache_data[platform][contest.id] = cache_data[platform][contest.id] or {} cache_data[platform][contest.id] = cache_data[platform][contest.id] or {}
cache_data[platform][contest.id].display_name = ( cache_data[platform][contest.id].display_name = contest.display_name
contest.display_name ~= vim.NIL and contest.display_name
) or contest.name
cache_data[platform][contest.id].name = contest.name cache_data[platform][contest.id].name = contest.name
if contest.start_time and contest.start_time ~= vim.NIL then
cache_data[platform][contest.id].start_time = contest.start_time
end
end
if opts and opts.supports_countdown ~= nil then
cache_data[platform].supports_countdown = opts.supports_countdown
end end
M.save() M.save()
end end
---@param platform string
---@return boolean?
function M.get_supports_countdown(platform)
if not cache_data[platform] then
return nil
end
return cache_data[platform].supports_countdown
end
---@param platform string
---@param contest_id string
---@return integer?
function M.get_contest_start_time(platform, contest_id)
if not cache_data[platform] or not cache_data[platform][contest_id] then
return nil
end
return cache_data[platform][contest_id].start_time
end
---@param platform string
---@param contest_id string
---@return string?
function M.get_contest_display_name(platform, contest_id)
if not cache_data[platform] or not cache_data[platform][contest_id] then
return nil
end
return cache_data[platform][contest_id].display_name
end
---@param platform string
---@return table?
function M.get_credentials(platform)
if not cache_data[platform] then
return nil
end
return cache_data[platform]._credentials
end
---@param platform string
---@param creds table
function M.set_credentials(platform, creds)
cache_data[platform] = cache_data[platform] or {}
cache_data[platform]._credentials = creds
M.save()
end
---@param platform string
function M.clear_credentials(platform)
if cache_data[platform] then
cache_data[platform]._credentials = nil
end
M.save()
end
---@return nil
function M.clear_all() function M.clear_all()
cache_data = {} cache_data = {}
M.save() M.save()
@ -490,9 +338,6 @@ function M.get_data_pretty()
return vim.inspect(cache_data) return vim.inspect(cache_data)
end end
---@return table M._cache = cache_data
function M.get_raw_cache()
return cache_data
end
return M return M

View file

@ -47,30 +47,26 @@ function M.handle_cache_command(cmd)
constants.PLATFORM_DISPLAY_NAMES[cmd.platform], constants.PLATFORM_DISPLAY_NAMES[cmd.platform],
cmd.contest cmd.contest
), ),
{ level = vim.log.levels.INFO, override = true } vim.log.levels.INFO,
true
) )
else else
logger.log( logger.log(("Unknown platform '%s'."):format(cmd.platform), vim.log.levels.ERROR)
("Unknown platform '%s'."):format(cmd.platform),
{ level = vim.log.levels.ERROR }
)
end end
elseif cmd.platform then elseif cmd.platform then
if vim.tbl_contains(platforms, cmd.platform) then if vim.tbl_contains(platforms, cmd.platform) then
cache.clear_platform(cmd.platform) cache.clear_platform(cmd.platform)
logger.log( logger.log(
("Cache cleared for platform '%s'"):format(constants.PLATFORM_DISPLAY_NAMES[cmd.platform]), ("Cache cleared for platform '%s'"):format(constants.PLATFORM_DISPLAY_NAMES[cmd.platform]),
{ level = vim.log.levels.INFO, override = true } vim.log.levels.INFO,
true
) )
else else
logger.log( logger.log(("Unknown platform '%s'."):format(cmd.platform), vim.log.levels.ERROR)
("Unknown platform '%s'."):format(cmd.platform),
{ level = vim.log.levels.ERROR }
)
end end
else else
cache.clear_all() cache.clear_all()
logger.log('Cache cleared', { level = vim.log.levels.INFO, override = true }) logger.log('Cache cleared', vim.log.levels.INFO, true)
end end
end end
end end

View file

@ -11,36 +11,18 @@ local actions = constants.ACTIONS
---@field type string ---@field type string
---@field error string? ---@field error string?
---@field action? string ---@field action? string
---@field requires_context? boolean
---@field message? string ---@field message? string
---@field contest? string ---@field contest? string
---@field platform? string ---@field platform? string
---@field problem_id? string ---@field problem_id? string
---@field interactor_cmd? string ---@field interactor_cmd? string
---@field generator_cmd? string
---@field brute_cmd? string
---@field test_index? integer ---@field test_index? integer
---@field test_indices? integer[] ---@field test_indices? integer[]
---@field mode? string ---@field mode? string
---@field debug? boolean ---@field debug? boolean
---@field language? string ---@field language? string
---@field race? boolean
---@field subcommand? string ---@field subcommand? string
---@param str string
---@return string
local function canonicalize_cf_contest(str)
local id = str:match('/contest/(%d+)') or str:match('/problemset/problem/(%d+)')
if id then
return id
end
local num = str:match('^(%d+)[A-Za-z]')
if num then
return num
end
return str
end
--- Turn raw args into normalized structure to later dispatch --- Turn raw args into normalized structure to later dispatch
---@param args string[] The raw command-line mode args ---@param args string[] The raw command-line mode args
---@return ParsedCommand ---@return ParsedCommand
@ -74,23 +56,10 @@ local function parse_command(args)
elseif first == 'interact' then elseif first == 'interact' then
local inter = args[2] local inter = args[2]
if inter and inter ~= '' then if inter and inter ~= '' then
return { return { type = 'action', action = 'interact', interactor_cmd = inter }
type = 'action',
action = 'interact',
requires_context = true,
interactor_cmd = inter,
}
else else
return { type = 'action', action = 'interact', requires_context = true } return { type = 'action', action = 'interact' }
end end
elseif first == 'stress' then
return {
type = 'action',
action = 'stress',
requires_context = true,
generator_cmd = args[2],
brute_cmd = args[3],
}
elseif first == 'edit' then elseif first == 'edit' then
local test_index = nil local test_index = nil
if #args >= 2 then if #args >= 2 then
@ -106,7 +75,7 @@ local function parse_command(args)
end end
test_index = idx test_index = idx
end end
return { type = 'action', action = 'edit', requires_context = true, test_index = test_index } return { type = 'action', action = 'edit', test_index = test_index }
elseif first == 'run' or first == 'panel' then elseif first == 'run' or first == 'panel' then
local debug = false local debug = false
local test_indices = nil local test_indices = nil
@ -219,28 +188,10 @@ local function parse_command(args)
return { return {
type = 'action', type = 'action',
action = first, action = first,
requires_context = true,
test_indices = test_indices, test_indices = test_indices,
debug = debug, debug = debug,
mode = mode, mode = mode,
} }
elseif first == 'open' then
local target = args[2] or 'problem'
if not vim.tbl_contains({ 'problem', 'contest', 'standings' }, target) then
return { type = 'error', message = 'Usage: :CP open [problem|contest|standings]' }
end
return { type = 'action', action = 'open', requires_context = true, subcommand = target }
elseif first == 'pick' then
local language = nil
if #args >= 3 and args[2] == '--lang' then
language = args[3]
elseif #args >= 2 and args[2] ~= nil and args[2]:sub(1, 2) ~= '--' then
return {
type = 'error',
message = ("Unknown argument '%s' for action '%s'"):format(args[2], first),
}
end
return { type = 'action', action = 'pick', requires_context = false, language = language }
else else
local language = nil local language = nil
if #args >= 3 and args[2] == '--lang' then if #args >= 3 and args[2] == '--lang' then
@ -251,77 +202,33 @@ local function parse_command(args)
message = ("Unknown argument '%s' for action '%s'"):format(args[2], first), message = ("Unknown argument '%s' for action '%s'"):format(args[2], first),
} }
end end
return { type = 'action', action = first, requires_context = true, language = language } return { type = 'action', action = first, language = language }
end end
end end
if vim.tbl_contains(platforms, first) then if vim.tbl_contains(platforms, first) then
if #args == 1 then if #args == 1 then
return { type = 'action', action = 'pick', requires_context = false, platform = first }
elseif #args == 2 then
if args[2] == 'login' or args[2] == 'logout' or args[2] == 'signup' then
return { type = 'action', action = args[2], requires_context = false, platform = first }
end
local contest = args[2]
if first == 'codeforces' then
contest = canonicalize_cf_contest(contest)
end
return { return {
type = 'contest_setup', type = 'error',
platform = first, message = 'Too few arguments - specify a contest.',
contest = contest,
} }
elseif #args == 3 and args[3] == '--race' then elseif #args == 2 then
local contest = args[2]
if first == 'codeforces' then
contest = canonicalize_cf_contest(contest)
end
return { return {
type = 'contest_setup', type = 'contest_setup',
platform = first, platform = first,
contest = contest, contest = args[2],
race = true,
} }
elseif #args == 4 and args[3] == '--lang' then elseif #args == 4 and args[3] == '--lang' then
local contest = args[2]
if first == 'codeforces' then
contest = canonicalize_cf_contest(contest)
end
return { return {
type = 'contest_setup', type = 'contest_setup',
platform = first, platform = first,
contest = contest, contest = args[2],
language = args[4], language = args[4],
} }
elseif #args == 5 then
local contest = args[2]
if first == 'codeforces' then
contest = canonicalize_cf_contest(contest)
end
local language, race = nil, false
if args[3] == '--race' and args[4] == '--lang' then
language = args[5]
race = true
elseif args[3] == '--lang' and args[5] == '--race' then
language = args[4]
race = true
else
return {
type = 'error',
message = 'Invalid arguments. Usage: :CP <platform> <contest> [--race] [--lang <language>]',
}
end
return {
type = 'contest_setup',
platform = first,
contest = contest,
language = language,
race = race,
}
else else
return { return {
type = 'error', type = 'error',
message = 'Invalid arguments. Usage: :CP <platform> <contest> [--race] [--lang <language>]', message = 'Invalid arguments. Usage: :CP <platform> <contest> [--lang <language>]',
} }
end end
end end
@ -342,30 +249,13 @@ local function parse_command(args)
return { type = 'error', message = 'Unknown command or no contest context.' } return { type = 'error', message = 'Unknown command or no contest context.' }
end end
---@param platform string
---@return boolean
local function check_platform_enabled(platform)
local cfg = require('cp.config').get_config()
if not cfg.platforms[platform] then
logger.log(
("Platform '%s' is not enabled. Add it to vim.g.cp.platforms to enable it."):format(
constants.PLATFORM_DISPLAY_NAMES[platform] or platform
),
{ level = vim.log.levels.ERROR }
)
return false
end
return true
end
--- Core logic for handling `:CP ...` commands --- Core logic for handling `:CP ...` commands
---@param opts { fargs: string[] }
---@return nil ---@return nil
function M.handle_command(opts) function M.handle_command(opts)
local cmd = parse_command(opts.fargs) local cmd = parse_command(opts.fargs)
if cmd.type == 'error' then if cmd.type == 'error' then
logger.log(cmd.message, { level = vim.log.levels.ERROR }) logger.log(cmd.message, vim.log.levels.ERROR)
return return
end end
@ -373,13 +263,6 @@ function M.handle_command(opts)
local restore = require('cp.restore') local restore = require('cp.restore')
restore.restore_from_current_file() restore.restore_from_current_file()
elseif cmd.type == 'action' then elseif cmd.type == 'action' then
if cmd.requires_context and not state.get_platform() then
local restore = require('cp.restore')
if not restore.restore_from_current_file() then
return
end
end
local setup = require('cp.setup') local setup = require('cp.setup')
local ui = require('cp.ui.views') local ui = require('cp.ui.views')
@ -398,48 +281,10 @@ function M.handle_command(opts)
setup.navigate_problem(-1, cmd.language) setup.navigate_problem(-1, cmd.language)
elseif cmd.action == 'pick' then elseif cmd.action == 'pick' then
local picker = require('cp.commands.picker') local picker = require('cp.commands.picker')
picker.handle_pick_action(cmd.language, cmd.platform) picker.handle_pick_action(cmd.language)
elseif cmd.action == 'edit' then elseif cmd.action == 'edit' then
local edit = require('cp.ui.edit') local edit = require('cp.ui.edit')
edit.toggle_edit(cmd.test_index) edit.toggle_edit(cmd.test_index)
elseif cmd.action == 'stress' then
require('cp.stress').toggle(cmd.generator_cmd, cmd.brute_cmd)
elseif cmd.action == 'submit' then
require('cp.submit').submit({ language = cmd.language })
elseif cmd.action == 'open' then
local cache = require('cp.cache')
cache.load()
local urls =
cache.get_open_urls(state.get_platform(), state.get_contest_id(), state.get_problem_id())
local url = urls and urls[cmd.subcommand]
if not url or url == '' then
logger.log(
("No URL available for '%s'"):format(cmd.subcommand),
{ level = vim.log.levels.WARN }
)
return
end
vim.ui.open(url)
elseif cmd.action == 'login' then
if not check_platform_enabled(cmd.platform) then
return
end
require('cp.credentials').login(cmd.platform)
elseif cmd.action == 'logout' then
if not check_platform_enabled(cmd.platform) then
return
end
require('cp.credentials').logout(cmd.platform)
elseif cmd.action == 'signup' then
local url = constants.SIGNUP_URLS[cmd.platform]
if not url then
logger.log(
("No signup URL available for '%s'"):format(cmd.platform),
{ level = vim.log.levels.WARN }
)
return
end
vim.ui.open(url)
end end
elseif cmd.type == 'problem_jump' then elseif cmd.type == 'problem_jump' then
local platform = state.get_platform() local platform = state.get_platform()
@ -447,7 +292,7 @@ function M.handle_command(opts)
local problem_id = cmd.problem_id local problem_id = cmd.problem_id
if not (platform and contest_id) then if not (platform and contest_id) then
logger.log('No contest is currently active.', { level = vim.log.levels.ERROR }) logger.log('No contest is currently active.', vim.log.levels.ERROR)
return return
end end
@ -462,7 +307,7 @@ function M.handle_command(opts)
contest_id, contest_id,
problem_id problem_id
), ),
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
@ -473,15 +318,8 @@ function M.handle_command(opts)
local cache_commands = require('cp.commands.cache') local cache_commands = require('cp.commands.cache')
cache_commands.handle_cache_command(cmd) cache_commands.handle_cache_command(cmd)
elseif cmd.type == 'contest_setup' then elseif cmd.type == 'contest_setup' then
if not check_platform_enabled(cmd.platform) then local setup = require('cp.setup')
return setup.setup_contest(cmd.platform, cmd.contest, nil, cmd.language)
end
if cmd.race then
require('cp.race').start(cmd.platform, cmd.contest, cmd.language)
else
local setup = require('cp.setup')
setup.setup_contest(cmd.platform, cmd.contest, nil, cmd.language)
end
return return
end end
end end

View file

@ -5,15 +5,14 @@ local logger = require('cp.log')
--- Dispatch `:CP pick` to appropriate picker --- Dispatch `:CP pick` to appropriate picker
---@param language? string ---@param language? string
---@param platform? string
---@return nil ---@return nil
function M.handle_pick_action(language, platform) function M.handle_pick_action(language)
local config = config_module.get_config() local config = config_module.get_config()
if not (config.ui and config.ui.picker) then if not (config.ui and config.ui.picker) then
logger.log( logger.log(
'No picker configured. Set ui.picker = "{telescope,fzf-lua}" in your config.', 'No picker configured. Set ui.picker = "{telescope,fzf-lua}" in your config.',
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
@ -26,13 +25,13 @@ function M.handle_pick_action(language, platform)
if not ok then if not ok then
logger.log( logger.log(
'telescope.nvim is not available. Install telescope.nvim xor change your picker config.', 'telescope.nvim is not available. Install telescope.nvim xor change your picker config.',
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
local ok_cp, telescope_picker = pcall(require, 'cp.pickers.telescope') local ok_cp, telescope_picker = pcall(require, 'cp.pickers.telescope')
if not ok_cp then if not ok_cp then
logger.log('Failed to load telescope integration.', { level = vim.log.levels.ERROR }) logger.log('Failed to load telescope integration.', vim.log.levels.ERROR)
return return
end end
@ -42,20 +41,20 @@ function M.handle_pick_action(language, platform)
if not ok then if not ok then
logger.log( logger.log(
'fzf-lua is not available. Install fzf-lua or change your picker config', 'fzf-lua is not available. Install fzf-lua or change your picker config',
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
local ok_cp, fzf_picker = pcall(require, 'cp.pickers.fzf_lua') local ok_cp, fzf_picker = pcall(require, 'cp.pickers.fzf_lua')
if not ok_cp then if not ok_cp then
logger.log('Failed to load fzf-lua integration.', { level = vim.log.levels.ERROR }) logger.log('Failed to load fzf-lua integration.', vim.log.levels.ERROR)
return return
end end
picker = fzf_picker picker = fzf_picker
end end
picker.pick(language, platform) picker.pick(language)
end end
return M return M

View file

@ -7,19 +7,10 @@
---@class CpLanguage ---@class CpLanguage
---@field extension string ---@field extension string
---@field commands CpLangCommands ---@field commands CpLangCommands
---@field template? string
---@field version? string
---@field submit_id? string
---@class CpTemplatesConfig
---@field cursor_marker? string
---@class CpPlatformOverrides ---@class CpPlatformOverrides
---@field extension? string ---@field extension? string
---@field commands? CpLangCommands ---@field commands? CpLangCommands
---@field template? string
---@field version? string
---@field submit_id? string
---@class CpPlatform ---@class CpPlatform
---@field enabled_languages string[] ---@field enabled_languages string[]
@ -29,7 +20,6 @@
---@class PanelConfig ---@class PanelConfig
---@field diff_modes string[] ---@field diff_modes string[]
---@field max_output_lines integer ---@field max_output_lines integer
---@field precision number?
---@class DiffGitConfig ---@class DiffGitConfig
---@field args string[] ---@field args string[]
@ -37,23 +27,12 @@
---@class DiffConfig ---@class DiffConfig
---@field git DiffGitConfig ---@field git DiffGitConfig
---@class CpSetupIOHooks
---@field input? fun(bufnr: integer, state: cp.State)
---@field output? fun(bufnr: integer, state: cp.State)
---@class CpSetupHooks
---@field contest? fun(state: cp.State)
---@field code? fun(state: cp.State)
---@field io? CpSetupIOHooks
---@class CpOnHooks
---@field enter? fun(state: cp.State)
---@field run? fun(state: cp.State)
---@field debug? fun(state: cp.State)
---@class Hooks ---@class Hooks
---@field setup? CpSetupHooks ---@field before_run? fun(state: cp.State)
---@field on? CpOnHooks ---@field before_debug? fun(state: cp.State)
---@field setup_code? fun(state: cp.State)
---@field setup_io_input? fun(bufnr: integer, state: cp.State)
---@field setup_io_output? fun(bufnr: integer, state: cp.State)
---@class VerdictFormatData ---@class VerdictFormatData
---@field index integer ---@field index integer
@ -82,6 +61,8 @@
---@class RunConfig ---@class RunConfig
---@field width number ---@field width number
---@field next_test_key string|nil
---@field prev_test_key string|nil
---@field format_verdict VerdictFormatter ---@field format_verdict VerdictFormatter
---@class EditConfig ---@class EditConfig
@ -102,16 +83,15 @@
---@class cp.Config ---@class cp.Config
---@field languages table<string, CpLanguage> ---@field languages table<string, CpLanguage>
---@field platforms table<string, CpPlatform> ---@field platforms table<string, CpPlatform>
---@field templates? CpTemplatesConfig
---@field hooks Hooks ---@field hooks Hooks
---@field debug boolean ---@field debug boolean
---@field open_url boolean
---@field scrapers string[] ---@field scrapers string[]
---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string ---@field filename? fun(contest: string, contest_id: string, problem_id?: string, config: cp.Config, language?: string): string
---@field ui CpUI ---@field ui CpUI
---@field runtime { effective: table<string, table<string, CpLanguage>> } -- computed ---@field runtime { effective: table<string, table<string, CpLanguage>> } -- computed
---@class cp.PartialConfig: cp.Config ---@class cp.PartialConfig: cp.Config
---@field platforms? table<string, CpPlatform|false>
local M = {} local M = {}
@ -122,6 +102,7 @@ local utils = require('cp.utils')
-- defaults per the new single schema -- defaults per the new single schema
---@type cp.Config ---@type cp.Config
M.defaults = { M.defaults = {
open_url = false,
languages = { languages = {
cpp = { cpp = {
extension = 'cc', extension = 'cc',
@ -166,29 +147,13 @@ M.defaults = {
enabled_languages = { 'cpp', 'python' }, enabled_languages = { 'cpp', 'python' },
default_language = 'cpp', default_language = 'cpp',
}, },
kattis = {
enabled_languages = { 'cpp', 'python' },
default_language = 'cpp',
},
usaco = {
enabled_languages = { 'cpp', 'python' },
default_language = 'cpp',
},
}, },
hooks = { hooks = {
setup = { before_run = nil,
contest = nil, before_debug = nil,
code = nil, setup_code = nil,
io = { setup_io_input = helpers.clearcol,
input = helpers.clearcol, setup_io_output = helpers.clearcol,
output = helpers.clearcol,
},
},
on = {
enter = nil,
run = nil,
debug = nil,
},
}, },
debug = false, debug = false,
scrapers = constants.PLATFORMS, scrapers = constants.PLATFORMS,
@ -197,6 +162,8 @@ M.defaults = {
ansi = true, ansi = true,
run = { run = {
width = 0.3, width = 0.3,
next_test_key = '<c-n>',
prev_test_key = '<c-p>',
format_verdict = helpers.default_verdict_formatter, format_verdict = helpers.default_verdict_formatter,
}, },
edit = { edit = {
@ -206,11 +173,7 @@ M.defaults = {
add_test_key = 'ga', add_test_key = 'ga',
save_and_exit_key = 'q', save_and_exit_key = 'q',
}, },
panel = { panel = { diff_modes = { 'side-by-side', 'git', 'vim' }, max_output_lines = 50 },
diff_modes = { 'side-by-side', 'git', 'vim' },
max_output_lines = 50,
precision = nil,
},
diff = { diff = {
git = { git = {
args = { 'diff', '--no-index', '--word-diff=plain', '--word-diff-regex=.', '--no-prefix' }, args = { 'diff', '--no-index', '--word-diff=plain', '--word-diff-regex=.', '--no-prefix' },
@ -252,18 +215,21 @@ local function validate_language(id, lang)
commands = { lang.commands, { 'table' } }, commands = { lang.commands, { 'table' } },
}) })
if lang.template ~= nil then
vim.validate({ template = { lang.template, 'string' } })
end
if not lang.commands.run then if not lang.commands.run then
error(('[cp.nvim] languages.%s.commands.run is required'):format(id)) error(('[cp.nvim] languages.%s.commands.run is required'):format(id))
end end
if lang.commands.build ~= nil then if lang.commands.build ~= nil then
vim.validate({ build = { lang.commands.build, { 'table' } } }) vim.validate({ build = { lang.commands.build, { 'table' } } })
if not has_tokens(lang.commands.build, { '{source}' }) then if not has_tokens(lang.commands.build, { '{source}', '{binary}' }) then
error(('[cp.nvim] languages.%s.commands.build must include {source}'):format(id)) error(('[cp.nvim] languages.%s.commands.build must include {source} and {binary}'):format(id))
end
for _, k in ipairs({ 'run', 'debug' }) do
if lang.commands[k] then
if not has_tokens(lang.commands[k], { '{binary}' }) then
error(('[cp.nvim] languages.%s.commands.%s must include {binary}'):format(id, k))
end
end
end end
else else
for _, k in ipairs({ 'run', 'debug' }) do for _, k in ipairs({ 'run', 'debug' }) do
@ -287,15 +253,6 @@ local function merge_lang(base, ov)
if ov.commands then if ov.commands then
out.commands = vim.tbl_deep_extend('force', out.commands or {}, ov.commands or {}) out.commands = vim.tbl_deep_extend('force', out.commands or {}, ov.commands or {})
end end
if ov.template then
out.template = ov.template
end
if ov.version then
out.version = ov.version
end
if ov.submit_id then
out.submit_id = ov.submit_id
end
return out return out
end end
@ -326,23 +283,6 @@ local function build_runtime(cfg)
validate_language(lid, base) validate_language(lid, base)
local eff = merge_lang(base, p.overrides and p.overrides[lid] or nil) local eff = merge_lang(base, p.overrides and p.overrides[lid] or nil)
validate_language(lid, eff) validate_language(lid, eff)
if eff.version then
local normalized = eff.version:lower():gsub('%s+', '')
local versions = (constants.LANGUAGE_VERSIONS[plat] or {})[lid]
if not versions or not versions[normalized] then
local avail = versions and vim.tbl_keys(versions) or {}
table.sort(avail)
error(
("[cp.nvim] Unknown version '%s' for %s on %s. Available: [%s]. See :help cp-submit-language"):format(
eff.version,
lid,
plat,
table.concat(avail, ', ')
)
)
end
eff.version = normalized
end
cfg.runtime.effective[plat][lid] = eff cfg.runtime.effective[plat][lid] = eff
end end
end end
@ -352,20 +292,7 @@ end
---@return cp.Config ---@return cp.Config
function M.setup(user_config) function M.setup(user_config)
vim.validate({ user_config = { user_config, { 'table', 'nil' }, true } }) vim.validate({ user_config = { user_config, { 'table', 'nil' }, true } })
local defaults = vim.deepcopy(M.defaults) local cfg = vim.tbl_deep_extend('force', vim.deepcopy(M.defaults), user_config or {})
if user_config and user_config.platforms then
for plat, v in pairs(user_config.platforms) do
if v == false then
defaults.platforms[plat] = nil
end
end
end
local cfg = vim.tbl_deep_extend('force', defaults, user_config or {})
for plat, v in pairs(cfg.platforms) do
if v == false then
cfg.platforms[plat] = nil
end
end
if not next(cfg.languages) then if not next(cfg.languages) then
error('[cp.nvim] At least one language must be configured') error('[cp.nvim] At least one language must be configured')
@ -375,17 +302,11 @@ function M.setup(user_config)
error('[cp.nvim] At least one platform must be configured') error('[cp.nvim] At least one platform must be configured')
end end
if cfg.templates ~= nil then
vim.validate({ templates = { cfg.templates, 'table' } })
if cfg.templates.cursor_marker ~= nil then
vim.validate({ cursor_marker = { cfg.templates.cursor_marker, 'string' } })
end
end
vim.validate({ vim.validate({
hooks = { cfg.hooks, { 'table' } }, hooks = { cfg.hooks, { 'table' } },
ui = { cfg.ui, { 'table' } }, ui = { cfg.ui, { 'table' } },
debug = { cfg.debug, { 'boolean', 'nil' }, true }, debug = { cfg.debug, { 'boolean', 'nil' }, true },
open_url = { cfg.open_url, { 'boolean', 'nil' }, true },
filename = { cfg.filename, { 'function', 'nil' }, true }, filename = { cfg.filename, { 'function', 'nil' }, true },
scrapers = { scrapers = {
cfg.scrapers, cfg.scrapers,
@ -402,29 +323,12 @@ function M.setup(user_config)
end, end,
('one of {%s}'):format(table.concat(constants.PLATFORMS, ',')), ('one of {%s}'):format(table.concat(constants.PLATFORMS, ',')),
}, },
before_run = { cfg.hooks.before_run, { 'function', 'nil' }, true },
before_debug = { cfg.hooks.before_debug, { 'function', 'nil' }, true },
setup_code = { cfg.hooks.setup_code, { 'function', 'nil' }, true },
setup_io_input = { cfg.hooks.setup_io_input, { 'function', 'nil' }, true },
setup_io_output = { cfg.hooks.setup_io_output, { 'function', 'nil' }, true },
}) })
if cfg.hooks.setup ~= nil then
vim.validate({ setup = { cfg.hooks.setup, 'table' } })
vim.validate({
contest = { cfg.hooks.setup.contest, { 'function', 'nil' }, true },
code = { cfg.hooks.setup.code, { 'function', 'nil' }, true },
})
if cfg.hooks.setup.io ~= nil then
vim.validate({ io = { cfg.hooks.setup.io, 'table' } })
vim.validate({
input = { cfg.hooks.setup.io.input, { 'function', 'nil' }, true },
output = { cfg.hooks.setup.io.output, { 'function', 'nil' }, true },
})
end
end
if cfg.hooks.on ~= nil then
vim.validate({ on = { cfg.hooks.on, 'table' } })
vim.validate({
enter = { cfg.hooks.on.enter, { 'function', 'nil' }, true },
run = { cfg.hooks.on.run, { 'function', 'nil' }, true },
debug = { cfg.hooks.on.debug, { 'function', 'nil' }, true },
})
end
local layouts = require('cp.ui.layouts') local layouts = require('cp.ui.layouts')
vim.validate({ vim.validate({
@ -451,13 +355,6 @@ function M.setup(user_config)
end, end,
'positive integer', 'positive integer',
}, },
precision = {
cfg.ui.panel.precision,
function(v)
return v == nil or (type(v) == 'number' and v >= 0)
end,
'nil or non-negative number',
},
git = { cfg.ui.diff.git, { 'table' } }, git = { cfg.ui.diff.git, { 'table' } },
git_args = { cfg.ui.diff.git.args, is_string_list, 'string[]' }, git_args = { cfg.ui.diff.git.args, is_string_list, 'string[]' },
width = { width = {
@ -467,6 +364,20 @@ function M.setup(user_config)
end, end,
'decimal between 0 and 1', 'decimal between 0 and 1',
}, },
next_test_key = {
cfg.ui.run.next_test_key,
function(v)
return v == nil or (type(v) == 'string' and #v > 0)
end,
'nil or non-empty string',
},
prev_test_key = {
cfg.ui.run.prev_test_key,
function(v)
return v == nil or (type(v) == 'string' and #v > 0)
end,
'nil or non-empty string',
},
format_verdict = { format_verdict = {
cfg.ui.run.format_verdict, cfg.ui.run.format_verdict,
'function', 'function',
@ -531,12 +442,10 @@ end
local current_config = nil local current_config = nil
---@param config cp.Config
function M.set_current_config(config) function M.set_current_config(config)
current_config = config current_config = config
end end
---@return cp.Config
function M.get_config() function M.get_config()
return current_config or M.defaults return current_config or M.defaults
end end

View file

@ -1,36 +1,13 @@
local M = {} local M = {}
M.PLATFORMS = { 'atcoder', 'codechef', 'codeforces', 'cses', 'kattis', 'usaco' } M.PLATFORMS = { 'atcoder', 'codechef', 'codeforces', 'cses' }
M.ACTIONS = { M.ACTIONS = { 'run', 'panel', 'next', 'prev', 'pick', 'cache', 'interact', 'edit' }
'run',
'panel',
'next',
'prev',
'pick',
'cache',
'interact',
'edit',
'stress',
'submit',
'open',
}
M.PLATFORM_DISPLAY_NAMES = { M.PLATFORM_DISPLAY_NAMES = {
atcoder = 'AtCoder', atcoder = 'AtCoder',
codechef = 'CodeChef', codechef = 'CodeChef',
codeforces = 'CodeForces', codeforces = 'CodeForces',
cses = 'CSES', cses = 'CSES',
kattis = 'Kattis',
usaco = 'USACO',
}
M.SIGNUP_URLS = {
atcoder = 'https://atcoder.jp/register',
codechef = 'https://www.codechef.com/register',
codeforces = 'https://codeforces.com/register',
cses = 'https://cses.fi/register',
kattis = 'https://open.kattis.com/register',
usaco = 'https://usaco.org/index.php?page=createaccount',
} }
M.CPP = 'cpp' M.CPP = 'cpp'
@ -73,150 +50,4 @@ M.signal_codes = {
[143] = 'SIGTERM', [143] = 'SIGTERM',
} }
M.LANGUAGE_VERSIONS = {
atcoder = {
cpp = { ['c++20'] = '6054', ['c++23'] = '6017', ['c++23-clang'] = '6116' },
python = { python3 = '6082', pypy3 = '6083', codon = '6115' },
java = { java = '6056' },
rust = { rust = '6088' },
c = { c23clang = '6013', c23gcc = '6014' },
go = { go = '6051', gccgo = '6050' },
haskell = { haskell = '6052' },
csharp = { csharp = '6015', ['csharp-aot'] = '6016' },
kotlin = { kotlin = '6062' },
ruby = { ruby = '6087', truffleruby = '6086' },
javascript = { bun = '6057', deno = '6058', nodejs = '6059' },
typescript = { deno = '6100', bun = '6101', nodejs = '6102' },
scala = { scala = '6090', ['scala-native'] = '6091' },
ocaml = { ocaml = '6073' },
dart = { dart = '6033' },
elixir = { elixir = '6038' },
erlang = { erlang = '6041' },
fsharp = { fsharp = '6042' },
swift = { swift = '6095' },
zig = { zig = '6111' },
nim = { nim = '6072', ['nim-old'] = '6071' },
lua = { lua = '6067', luajit = '6068' },
perl = { perl = '6076' },
php = { php = '6077' },
pascal = { pascal = '6075' },
crystal = { crystal = '6028' },
d = { dmd = '6030', gdc = '6031', ldc = '6032' },
julia = { julia = '6114' },
r = { r = '6084' },
commonlisp = { commonlisp = '6027' },
scheme = { chezscheme = '6092', gauche = '6093' },
clojure = { clojure = '6022', ['clojure-aot'] = '6023', babashka = '6021' },
ada = { ada = '6002' },
bash = { bash = '6008' },
fortran = { fortran2023 = '6047', fortran2018 = '6046', fortran77 = '6048' },
gleam = { gleam = '6049' },
lean = { lean = '6065' },
pony = { pony = '6079' },
prolog = { prolog = '6081' },
vala = { vala = '6106' },
v = { v = '6105' },
sql = { duckdb = '6118' },
},
codeforces = {
cpp = { ['c++17'] = '54', ['c++20'] = '89', ['c++23'] = '91', c11 = '43' },
python = { python3 = '31', pypy3 = '70', python2 = '7', pypy2 = '40', ['pypy3-old'] = '41' },
java = { java8 = '36', java21 = '87' },
kotlin = { ['1.7'] = '83', ['1.9'] = '88', ['2.2'] = '99' },
rust = { ['2021'] = '75', ['2024'] = '98' },
go = { go = '32' },
csharp = { mono = '9', dotnet3 = '65', dotnet6 = '79', dotnet9 = '96' },
haskell = { haskell = '12' },
javascript = { v8 = '34', nodejs = '55' },
ruby = { ruby = '67' },
scala = { scala = '20' },
ocaml = { ocaml = '19' },
d = { d = '28' },
perl = { perl = '13' },
php = { php = '6' },
pascal = { freepascal = '4', pascalabc = '51' },
fsharp = { fsharp = '97' },
},
cses = {
cpp = { ['c++17'] = 'C++17' },
python = { python3 = 'Python3', pypy3 = 'PyPy3' },
java = { java = 'Java' },
rust = { rust2021 = 'Rust2021' },
},
kattis = {
cpp = { ['c++17'] = 'C++', ['c++20'] = 'C++', ['c++23'] = 'C++' },
python = { python3 = 'Python 3', python2 = 'Python 2' },
java = { java = 'Java' },
rust = { rust = 'Rust' },
ada = { ada = 'Ada' },
algol60 = { algol60 = 'Algol 60' },
algol68 = { algol68 = 'Algol 68' },
apl = { apl = 'APL' },
bash = { bash = 'Bash' },
bcpl = { bcpl = 'BCPL' },
bqn = { bqn = 'BQN' },
c = { c = 'C' },
cobol = { cobol = 'COBOL' },
commonlisp = { commonlisp = 'Common Lisp' },
crystal = { crystal = 'Crystal' },
csharp = { csharp = 'C#' },
d = { d = 'D' },
dart = { dart = 'Dart' },
elixir = { elixir = 'Elixir' },
erlang = { erlang = 'Erlang' },
forth = { forth = 'Forth' },
fortran = { fortran = 'Fortran' },
fortran77 = { fortran77 = 'Fortran 77' },
fsharp = { fsharp = 'F#' },
gerbil = { gerbil = 'Gerbil' },
go = { go = 'Go' },
haskell = { haskell = 'Haskell' },
icon = { icon = 'Icon' },
javascript = { javascript = 'JavaScript (Node.js)', spidermonkey = 'JavaScript (SpiderMonkey)' },
julia = { julia = 'Julia' },
kotlin = { kotlin = 'Kotlin' },
lua = { lua = 'Lua' },
modula2 = { modula2 = 'Modula-2' },
nim = { nim = 'Nim' },
objectivec = { objectivec = 'Objective-C' },
ocaml = { ocaml = 'OCaml' },
octave = { octave = 'Octave' },
odin = { odin = 'Odin' },
pascal = { pascal = 'Pascal' },
perl = { perl = 'Perl' },
php = { php = 'PHP' },
pli = { pli = 'PL/I' },
prolog = { prolog = 'Prolog' },
racket = { racket = 'Racket' },
ruby = { ruby = 'Ruby' },
scala = { scala = 'Scala' },
simula = { simula = 'Simula 67' },
smalltalk = { smalltalk = 'Smalltalk' },
snobol = { snobol = 'SNOBOL' },
swift = { swift = 'Swift' },
typescript = { typescript = 'TypeScript' },
visualbasic = { visualbasic = 'Visual Basic' },
zig = { zig = 'Zig' },
},
usaco = {
cpp = { ['c++11'] = 'cpp', ['c++17'] = 'cpp' },
python = { python3 = 'python' },
java = { java = 'java' },
},
codechef = {
cpp = { ['c++20'] = 'C++' },
python = { python3 = 'PYTH 3', pypy3 = 'PYPY3' },
java = { java = 'JAVA' },
rust = { rust = 'rust' },
c = { c = 'C' },
go = { go = 'GO' },
kotlin = { kotlin = 'KTLN' },
javascript = { nodejs = 'NODEJS' },
typescript = { typescript = 'TS' },
csharp = { csharp = 'C#' },
},
}
M.DEFAULT_VERSIONS = { cpp = 'c++20', python = 'python3' }
return M return M

View file

@ -1,88 +0,0 @@
local M = {}
local cache = require('cp.cache')
local constants = require('cp.constants')
local logger = require('cp.log')
local state = require('cp.state')
local STATUS_MESSAGES = {
checking_login = 'Checking existing session...',
logging_in = 'Logging in...',
installing_browser = 'Installing browser...',
}
---@param platform string?
function M.login(platform)
platform = platform or state.get_platform()
if not platform then
logger.log(
'No platform specified. Usage: :CP <platform> login',
{ level = vim.log.levels.ERROR }
)
return
end
local display = constants.PLATFORM_DISPLAY_NAMES[platform] or platform
vim.ui.input({ prompt = display .. ' username: ' }, function(username)
if not username or username == '' then
logger.log('Cancelled', { level = vim.log.levels.WARN })
return
end
vim.fn.inputsave()
local password = vim.fn.inputsecret(display .. ' password: ')
vim.fn.inputrestore()
if not password or password == '' then
logger.log('Cancelled', { level = vim.log.levels.WARN })
return
end
cache.load()
local existing = cache.get_credentials(platform) or {}
local credentials = {
username = username,
password = password,
}
if existing.token then
credentials.token = existing.token
end
local scraper = require('cp.scraper')
scraper.login(platform, credentials, function(ev)
vim.schedule(function()
local msg = STATUS_MESSAGES[ev.status] or ev.status
logger.log(display .. ': ' .. msg, { level = vim.log.levels.INFO, override = true })
end)
end, function(result)
vim.schedule(function()
if result.success then
logger.log(
display .. ' login successful',
{ level = vim.log.levels.INFO, override = true }
)
else
local err = result.error or 'unknown error'
logger.log(display .. ' login failed: ' .. err, { level = vim.log.levels.ERROR })
end
end)
end)
end)
end
---@param platform string?
function M.logout(platform)
platform = platform or state.get_platform()
if not platform then
logger.log(
'No platform specified. Usage: :CP <platform> logout',
{ level = vim.log.levels.ERROR }
)
return
end
local display = constants.PLATFORM_DISPLAY_NAMES[platform] or platform
cache.load()
cache.clear_credentials(platform)
logger.log(display .. ' credentials cleared', { level = vim.log.levels.INFO, override = true })
end
return M

View file

@ -5,50 +5,33 @@ local utils = require('cp.utils')
local function check() local function check()
vim.health.start('cp.nvim [required] ~') vim.health.start('cp.nvim [required] ~')
utils.setup_python_env()
if vim.fn.has('nvim-0.10.0') == 1 then if vim.fn.has('nvim-0.10.0') == 1 then
vim.health.ok('Neovim 0.10.0+ detected') vim.health.ok('Neovim 0.10.0+ detected')
else else
vim.health.error('cp.nvim requires Neovim 0.10.0+') vim.health.error('cp.nvim requires Neovim 0.10.0+')
end end
local uname = vim.uv.os_uname() local uname = vim.loop.os_uname()
if uname.sysname == 'Windows_NT' then if uname.sysname == 'Windows_NT' then
vim.health.error('Windows is not supported') vim.health.error('Windows is not supported')
end end
if utils.is_nix_build() then if vim.fn.executable('uv') == 1 then
local source = utils.is_nix_discovered() and 'runtime discovery' or 'flake install' vim.health.ok('uv executable found')
vim.health.ok('Nix Python environment detected (' .. source .. ')') local r = vim.system({ 'uv', '--version' }, { text = true }):wait()
local py = utils.get_nix_python()
vim.health.info('Python: ' .. py)
local r = vim.system({ py, '--version' }, { text = true }):wait()
if r.code == 0 then if r.code == 0 then
vim.health.info('Python version: ' .. r.stdout:gsub('\n', '')) vim.health.info('uv version: ' .. r.stdout:gsub('\n', ''))
end end
else else
if vim.fn.executable('uv') == 1 then vim.health.warn('uv not found (install https://docs.astral.sh/uv/ for scraping)')
vim.health.ok('uv executable found') end
local r = vim.system({ 'uv', '--version' }, { text = true }):wait()
if r.code == 0 then
vim.health.info('uv version: ' .. r.stdout:gsub('\n', ''))
end
else
vim.health.warn('uv not found (install https://docs.astral.sh/uv/ for scraping)')
end
if vim.fn.executable('nix') == 1 then local plugin_path = utils.get_plugin_path()
vim.health.info('nix available but Python environment not resolved via nix') local venv_dir = plugin_path .. '/.venv'
end if vim.fn.isdirectory(venv_dir) == 1 then
vim.health.ok('Python virtual environment found at ' .. venv_dir)
local plugin_path = utils.get_plugin_path() else
local venv_dir = plugin_path .. '/.venv' vim.health.info('Python virtual environment not set up (created on first scrape)')
if vim.fn.isdirectory(venv_dir) == 1 then
vim.health.ok('Python virtual environment found at ' .. venv_dir)
else
vim.health.info('Python virtual environment not set up (created on first scrape)')
end
end end
local time_cap = utils.time_capability() local time_cap = utils.time_capability()
@ -58,7 +41,7 @@ local function check()
vim.health.error('GNU time not found: ' .. (time_cap.reason or '')) vim.health.error('GNU time not found: ' .. (time_cap.reason or ''))
end end
local timeout_cap = utils.timeout_capability() local timeout_cap = utils.time_capability()
if timeout_cap.ok then if timeout_cap.ok then
vim.health.ok('GNU timeout found: ' .. timeout_cap.path) vim.health.ok('GNU timeout found: ' .. timeout_cap.path)
else else
@ -66,7 +49,6 @@ local function check()
end end
end end
---@return nil
function M.check() function M.check()
local version = require('cp.version') local version = require('cp.version')
vim.health.start('cp.nvim health check ~') vim.health.start('cp.nvim health check ~')

View file

@ -7,7 +7,7 @@ local logger = require('cp.log')
M.helpers = helpers M.helpers = helpers
if vim.fn.has('nvim-0.10.0') == 0 then if vim.fn.has('nvim-0.10.0') == 0 then
logger.log('Requires nvim-0.10.0+', { level = vim.log.levels.ERROR }) logger.log('Requires nvim-0.10.0+', vim.log.levels.ERROR)
return {} return {}
end end
@ -15,42 +15,23 @@ local initialized = false
local function ensure_initialized() local function ensure_initialized()
if initialized then if initialized then
return true return
end end
local user_config = vim.g.cp or {} local user_config = vim.g.cp_config or {}
local ok, result = pcall(config_module.setup, user_config) local config = config_module.setup(user_config)
if not ok then config_module.set_current_config(config)
local msg = tostring(result):gsub('^.+:%d+: ', '')
logger.log(msg, { level = vim.log.levels.ERROR, override = true, sync = true })
return false
end
config_module.set_current_config(result)
initialized = true initialized = true
return true
end end
---@return nil ---@return nil
function M.handle_command(opts) function M.handle_command(opts)
if not ensure_initialized() then ensure_initialized()
return
end
local commands = require('cp.commands') local commands = require('cp.commands')
commands.handle_command(opts) commands.handle_command(opts)
end end
---@return boolean
function M.is_initialized() function M.is_initialized()
return initialized return initialized
end end
---@deprecated Use `vim.g.cp` instead
---@param user_config table?
function M.setup(user_config)
vim.deprecate('require("cp").setup()', 'vim.g.cp', 'v0.7.7', 'cp.nvim', false)
if user_config then
vim.g.cp = vim.tbl_deep_extend('force', vim.g.cp or {}, user_config)
end
end
return M return M

View file

@ -1,27 +1,12 @@
local M = {} local M = {}
---@class LogOpts function M.log(msg, level, override)
---@field level? integer
---@field override? boolean
---@field sync? boolean
---@param msg string
---@param opts? LogOpts
function M.log(msg, opts)
local debug = require('cp.config').get_config().debug or false local debug = require('cp.config').get_config().debug or false
opts = opts or {} level = level or vim.log.levels.INFO
local level = opts.level or vim.log.levels.INFO
local override = opts.override or false
local sync = opts.sync or false
if level >= vim.log.levels.WARN or override or debug then if level >= vim.log.levels.WARN or override or debug then
local notify = function() vim.schedule(function()
vim.notify(('[cp.nvim]: %s'):format(msg), level) vim.notify(('[cp.nvim]: %s'):format(msg), level)
end end)
if sync then
notify()
else
vim.schedule(notify)
end
end end
end end

View file

@ -1,4 +1,3 @@
local logger = require('cp.log')
local picker_utils = require('cp.pickers') local picker_utils = require('cp.pickers')
local M = {} local M = {}
@ -10,9 +9,9 @@ local function contest_picker(platform, refresh, language)
local contests = picker_utils.get_platform_contests(platform, refresh) local contests = picker_utils.get_platform_contests(platform, refresh)
if vim.tbl_isempty(contests) then if vim.tbl_isempty(contests) then
logger.log( vim.notify(
("No contests found for platform '%s'"):format(platform_display_name), ("No contests found for platform '%s'"):format(platform_display_name),
{ level = vim.log.levels.WARN } vim.log.levels.WARN
) )
return return
end end
@ -58,18 +57,11 @@ local function contest_picker(platform, refresh, language)
}) })
end end
---@param language? string function M.pick(language)
---@param platform? string
function M.pick(language, platform)
if platform then
contest_picker(platform, false, language)
return
end
local fzf = require('fzf-lua') local fzf = require('fzf-lua')
local platforms = picker_utils.get_platforms() local platforms = picker_utils.get_platforms()
local entries = vim.tbl_map(function(p) local entries = vim.tbl_map(function(platform)
return p.display_name return platform.display_name
end, platforms) end, platforms)
return fzf.fzf_exec(entries, { return fzf.fzf_exec(entries, {
@ -81,16 +73,16 @@ function M.pick(language, platform)
end end
local selected_name = selected[1] local selected_name = selected[1]
local found = nil local platform = nil
for _, p in ipairs(platforms) do for _, p in ipairs(platforms) do
if p.display_name == selected_name then if p.display_name == selected_name then
found = p platform = p
break break
end end
end end
if found then if platform then
contest_picker(found.id, false, language) contest_picker(platform.id, false, language)
end end
end, end,
}, },

View file

@ -42,21 +42,23 @@ function M.get_platform_contests(platform, refresh)
local picker_contests = cache.get_contest_summaries(platform) local picker_contests = cache.get_contest_summaries(platform)
if refresh or vim.tbl_isempty(picker_contests) then if refresh or vim.tbl_isempty(picker_contests) then
local display_name = constants.PLATFORM_DISPLAY_NAMES[platform]
logger.log( logger.log(
('Fetching %s contests...'):format(display_name), ('Loading %s contests...'):format(constants.PLATFORM_DISPLAY_NAMES[platform]),
{ level = vim.log.levels.INFO, override = true, sync = true } vim.log.levels.INFO,
true
) )
local result = scraper.scrape_contest_list(platform) local contests = scraper.scrape_contest_list(platform)
local contests = result and result.contests or {} cache.set_contest_summaries(platform, contests)
local sc = result and result.supports_countdown
cache.set_contest_summaries(platform, contests, { supports_countdown = sc })
picker_contests = cache.get_contest_summaries(platform) picker_contests = cache.get_contest_summaries(platform)
logger.log( logger.log(
('Fetched %d %s contests.'):format(#picker_contests, display_name), ('Loaded %d %s contests.'):format(
{ level = vim.log.levels.INFO, override = true } #picker_contests,
constants.PLATFORM_DISPLAY_NAMES[platform]
),
vim.log.levels.INFO,
true
) )
end end

View file

@ -4,7 +4,6 @@ local conf = require('telescope.config').values
local action_state = require('telescope.actions.state') local action_state = require('telescope.actions.state')
local actions = require('telescope.actions') local actions = require('telescope.actions')
local logger = require('cp.log')
local picker_utils = require('cp.pickers') local picker_utils = require('cp.pickers')
local M = {} local M = {}
@ -15,9 +14,9 @@ local function contest_picker(opts, platform, refresh, language)
local contests = picker_utils.get_platform_contests(platform, refresh) local contests = picker_utils.get_platform_contests(platform, refresh)
if vim.tbl_isempty(contests) then if vim.tbl_isempty(contests) then
logger.log( vim.notify(
('No contests found for platform: %s'):format(platform_display_name), ('No contests found for platform: %s'):format(platform_display_name),
{ level = vim.log.levels.WARN } vim.log.levels.WARN
) )
return return
end end
@ -64,14 +63,7 @@ local function contest_picker(opts, platform, refresh, language)
:find() :find()
end end
---@param language? string function M.pick(language)
---@param platform? string
function M.pick(language, platform)
if platform then
contest_picker({}, platform, false, language)
return
end
local opts = {} local opts = {}
local platforms = picker_utils.get_platforms() local platforms = picker_utils.get_platforms()

View file

@ -1,296 +0,0 @@
local M = {}
local cache = require('cp.cache')
local constants = require('cp.constants')
local logger = require('cp.log')
local scraper = require('cp.scraper')
local REFETCH_INTERVAL_S = 600
local RETRY_DELAY_MS = 3000
local MAX_RETRY_ATTEMPTS = 15
local race_state = {
timer = nil,
token = nil,
platform = nil,
contest_id = nil,
contest_name = nil,
language = nil,
start_time = nil,
last_refetch = nil,
}
local function format_countdown(seconds)
local d = math.floor(seconds / 86400)
local h = math.floor((seconds % 86400) / 3600)
local m = math.floor((seconds % 3600) / 60)
local s = seconds % 60
if d > 0 then
return string.format('%dd%dh%dm%ds', d, h, m, s)
elseif h > 0 then
return string.format('%dh%dm%ds', h, m, s)
elseif m > 0 then
return string.format('%dm%ds', m, s)
end
return string.format('%ds', s)
end
local function should_notify(remaining)
if remaining > 3600 then
return remaining % 900 == 0
end
if remaining > 300 then
return remaining % 60 == 0
end
if remaining > 60 then
return remaining % 10 == 0
end
return true
end
local function refetch_start_time()
local result = scraper.scrape_contest_list(race_state.platform)
if not result or not result.contests or #result.contests == 0 then
return
end
cache.set_contest_summaries(
race_state.platform,
result.contests,
{ supports_countdown = result.supports_countdown }
)
local new_time = cache.get_contest_start_time(race_state.platform, race_state.contest_id)
if new_time and new_time ~= race_state.start_time then
race_state.start_time = new_time
race_state.contest_name = cache.get_contest_display_name(
race_state.platform,
race_state.contest_id
) or race_state.contest_id
end
end
local function race_try_setup(platform, contest_id, language, attempt, token)
if race_state.token ~= token then
return
end
cache.load()
local cd = cache.get_contest_data(platform, contest_id)
if
cd
and type(cd.problems) == 'table'
and #cd.problems > 0
and type(cd.index_map) == 'table'
and next(cd.index_map) ~= nil
then
require('cp.setup').setup_contest(platform, contest_id, nil, language)
return
end
local display = constants.PLATFORM_DISPLAY_NAMES[platform] or platform
if attempt > 1 then
logger.log(
('Retrying %s "%s" setup (attempt %d/%d)...'):format(
display,
contest_id,
attempt,
MAX_RETRY_ATTEMPTS
),
{ level = vim.log.levels.WARN }
)
end
scraper.scrape_contest_metadata(
platform,
contest_id,
vim.schedule_wrap(function(data)
if race_state.token ~= token then
return
end
cache.set_contest_data(
platform,
contest_id,
data.problems or {},
data.url or '',
data.contest_url or '',
data.standings_url or ''
)
require('cp.setup').setup_contest(platform, contest_id, nil, language)
end),
vim.schedule_wrap(function()
if race_state.token ~= token then
return
end
if attempt >= MAX_RETRY_ATTEMPTS then
logger.log(
('Failed to load %s contest "%s" after %d attempts'):format(display, contest_id, attempt),
{ level = vim.log.levels.ERROR }
)
return
end
vim.defer_fn(function()
race_try_setup(platform, contest_id, language, attempt + 1, token)
end, RETRY_DELAY_MS)
end)
)
end
---@param platform string
---@param contest_id string
---@param language? string
function M.start(platform, contest_id, language)
if not platform or not vim.tbl_contains(constants.PLATFORMS, platform) then
logger.log('Invalid platform', { level = vim.log.levels.ERROR })
return
end
if not contest_id or contest_id == '' then
logger.log('Contest ID required', { level = vim.log.levels.ERROR })
return
end
if race_state.timer then
M.stop()
end
cache.load()
local display = constants.PLATFORM_DISPLAY_NAMES[platform] or platform
local cached_countdown = cache.get_supports_countdown(platform)
if cached_countdown == false then
logger.log(('%s does not support --race'):format(display), { level = vim.log.levels.ERROR })
return
end
local start_time = cache.get_contest_start_time(platform, contest_id)
if not start_time then
logger.log(
'Fetching contest list...',
{ level = vim.log.levels.INFO, override = true, sync = true }
)
local result = scraper.scrape_contest_list(platform)
if result then
local sc = result.supports_countdown
if sc == false then
cache.set_contest_summaries(platform, result.contests or {}, { supports_countdown = false })
logger.log(('%s does not support --race'):format(display), { level = vim.log.levels.ERROR })
return
end
if result.contests and #result.contests > 0 then
cache.set_contest_summaries(platform, result.contests, { supports_countdown = sc })
start_time = cache.get_contest_start_time(platform, contest_id)
end
end
end
if not start_time then
logger.log(
('No start time found for %s contest "%s"'):format(display, contest_id),
{ level = vim.log.levels.ERROR }
)
return
end
local token = vim.uv.hrtime()
local remaining = start_time - os.time()
if remaining <= 0 then
logger.log(
'Contest has already started, setting up...',
{ level = vim.log.levels.INFO, override = true }
)
race_state.token = token
race_try_setup(platform, contest_id, language, 1, token)
return
end
race_state.platform = platform
race_state.contest_id = contest_id
race_state.contest_name = cache.get_contest_display_name(platform, contest_id) or contest_id
race_state.language = language
race_state.start_time = start_time
race_state.last_refetch = os.time()
race_state.token = token
local timer = vim.uv.new_timer()
race_state.timer = timer
timer:start(
0,
1000,
vim.schedule_wrap(function()
if race_state.token ~= token then
return
end
local now = os.time()
if now - race_state.last_refetch >= REFETCH_INTERVAL_S then
race_state.last_refetch = now
refetch_start_time()
end
local r = race_state.start_time - now
if r <= 0 then
timer:stop()
timer:close()
race_state.timer = nil
local p = race_state.platform
local c = race_state.contest_id
local l = race_state.language
race_state.platform = nil
race_state.contest_id = nil
race_state.contest_name = nil
race_state.language = nil
race_state.start_time = nil
race_state.last_refetch = nil
logger.log('Contest started!', { level = vim.log.levels.INFO, override = true })
race_try_setup(p, c, l, 1, token)
elseif should_notify(r) then
logger.log(
('%s race "%s" starts in %s'):format(
constants.PLATFORM_DISPLAY_NAMES[race_state.platform] or race_state.platform,
race_state.contest_name,
format_countdown(r)
),
{ level = vim.log.levels.INFO, override = true }
)
end
end)
)
end
---@return nil
function M.stop()
local timer = race_state.timer
if not timer then
logger.log('No active race', { level = vim.log.levels.WARN })
return
end
local display = constants.PLATFORM_DISPLAY_NAMES[race_state.platform] or race_state.platform
local name = race_state.contest_name or race_state.contest_id
timer:stop()
timer:close()
race_state.timer = nil
race_state.token = nil
race_state.platform = nil
race_state.contest_id = nil
race_state.contest_name = nil
race_state.language = nil
race_state.start_time = nil
race_state.last_refetch = nil
logger.log(
('Cancelled %s race "%s"'):format(display, name),
{ level = vim.log.levels.INFO, override = true }
)
end
---@return { active: boolean, platform?: string, contest_id?: string, remaining_seconds?: integer }
function M.status()
if not race_state.timer or not race_state.start_time then
return { active = false }
end
return {
active = true,
platform = race_state.platform,
contest_id = race_state.contest_id,
remaining_seconds = math.max(0, race_state.start_time - os.time()),
}
end
return M

View file

@ -11,7 +11,7 @@ function M.restore_from_current_file()
local current_file = (vim.uv.fs_realpath(vim.fn.expand('%:p')) or vim.fn.expand('%:p')) local current_file = (vim.uv.fs_realpath(vim.fn.expand('%:p')) or vim.fn.expand('%:p'))
local file_state = cache.get_file_state(current_file) local file_state = cache.get_file_state(current_file)
if not file_state then if not file_state then
logger.log('No cached state found for current file.', { level = vim.log.levels.ERROR }) logger.log('No cached state found for current file.', vim.log.levels.ERROR)
return false return false
end end

View file

@ -33,9 +33,6 @@ local function substitute_template(cmd_template, substitutions)
return out return out
end end
---@param cmd_template string[]
---@param substitutions SubstitutableCommand
---@return string[]
function M.build_command(cmd_template, substitutions) function M.build_command(cmd_template, substitutions)
return substitute_template(cmd_template, substitutions) return substitute_template(cmd_template, substitutions)
end end
@ -46,7 +43,6 @@ end
function M.compile(compile_cmd, substitutions, on_complete) function M.compile(compile_cmd, substitutions, on_complete)
local cmd = substitute_template(compile_cmd, substitutions) local cmd = substitute_template(compile_cmd, substitutions)
local sh = table.concat(cmd, ' ') .. ' 2>&1' local sh = table.concat(cmd, ' ') .. ' 2>&1'
logger.log('compile: ' .. sh)
local t0 = vim.uv.hrtime() local t0 = vim.uv.hrtime()
vim.system({ 'sh', '-c', sh }, { text = false }, function(r) vim.system({ 'sh', '-c', sh }, { text = false }, function(r)
@ -55,7 +51,7 @@ function M.compile(compile_cmd, substitutions, on_complete)
r.stdout = ansi.bytes_to_string(r.stdout or '') r.stdout = ansi.bytes_to_string(r.stdout or '')
if r.code == 0 then if r.code == 0 then
logger.log(('Compilation successful in %.1fms.'):format(dt), { level = vim.log.levels.INFO }) logger.log(('Compilation successful in %.1fms.'):format(dt), vim.log.levels.INFO)
else else
logger.log(('Compilation failed in %.1fms.'):format(dt)) logger.log(('Compilation failed in %.1fms.'):format(dt))
end end
@ -123,7 +119,6 @@ function M.run(cmd, stdin, timeout_ms, memory_mb, on_complete)
local sec = math.ceil(timeout_ms / 1000) local sec = math.ceil(timeout_ms / 1000)
local timeout_prefix = ('%s -k 1s %ds '):format(timeout_bin, sec) local timeout_prefix = ('%s -k 1s %ds '):format(timeout_bin, sec)
local sh = prefix .. timeout_prefix .. ('%s -v sh -c %q 2>&1'):format(time_bin, prog) local sh = prefix .. timeout_prefix .. ('%s -v sh -c %q 2>&1'):format(time_bin, prog)
logger.log('run: ' .. sh)
local t0 = vim.uv.hrtime() local t0 = vim.uv.hrtime()
vim.system({ 'sh', '-c', sh }, { stdin = stdin, text = true }, function(r) vim.system({ 'sh', '-c', sh }, { stdin = stdin, text = true }, function(r)

View file

@ -19,7 +19,6 @@
---@class ProblemConstraints ---@class ProblemConstraints
---@field timeout_ms number ---@field timeout_ms number
---@field memory_mb number ---@field memory_mb number
---@field precision number?
---@class PanelState ---@class PanelState
---@field test_cases RanTestCase[] ---@field test_cases RanTestCase[]
@ -57,8 +56,7 @@ local function load_constraints_from_cache(platform, contest_id, problem_id)
cache.load() cache.load()
local timeout_ms, memory_mb = cache.get_constraints(platform, contest_id, problem_id) local timeout_ms, memory_mb = cache.get_constraints(platform, contest_id, problem_id)
if timeout_ms and memory_mb then if timeout_ms and memory_mb then
local precision = cache.get_precision(platform, contest_id, problem_id) return { timeout_ms = timeout_ms, memory_mb = memory_mb }
return { timeout_ms = timeout_ms, memory_mb = memory_mb, precision = precision }
end end
return nil return nil
end end
@ -101,53 +99,6 @@ local function build_command(cmd, substitutions)
return execute.build_command(cmd, substitutions) return execute.build_command(cmd, substitutions)
end end
---@param actual string
---@param expected string
---@param precision number?
---@return boolean
local function compare_outputs(actual, expected, precision)
local norm_actual = normalize_lines(actual)
local norm_expected = normalize_lines(expected)
if precision == nil or precision == 0 then
return norm_actual == norm_expected
end
local actual_lines = vim.split(norm_actual, '\n', { plain = true })
local expected_lines = vim.split(norm_expected, '\n', { plain = true })
if #actual_lines ~= #expected_lines then
return false
end
for i = 1, #actual_lines do
local a_tokens = vim.split(actual_lines[i], '%s+', { plain = false, trimempty = true })
local e_tokens = vim.split(expected_lines[i], '%s+', { plain = false, trimempty = true })
if #a_tokens ~= #e_tokens then
return false
end
for j = 1, #a_tokens do
local a_tok, e_tok = a_tokens[j], e_tokens[j]
local a_num = tonumber(a_tok)
local e_num = tonumber(e_tok)
if a_num ~= nil and e_num ~= nil then
if math.abs(a_num - e_num) > precision then
return false
end
else
if a_tok ~= e_tok then
return false
end
end
end
end
return true
end
---@param test_case RanTestCase ---@param test_case RanTestCase
---@param debug boolean? ---@param debug boolean?
---@param on_complete fun(result: { status: "pass"|"fail"|"tle"|"mle", actual: string, actual_highlights: Highlight[], error: string, stderr: string, time_ms: number, code: integer, ok: boolean, signal: string?, tled: boolean, mled: boolean, rss_mb: number }) ---@param on_complete fun(result: { status: "pass"|"fail"|"tle"|"mle", actual: string, actual_highlights: Highlight[], error: string, stderr: string, time_ms: number, code: integer, ok: boolean, signal: string?, tled: boolean, mled: boolean, rss_mb: number })
@ -192,9 +143,7 @@ local function run_single_test_case(test_case, debug, on_complete)
end end
local expected = test_case.expected or '' local expected = test_case.expected or ''
local precision = (panel_state.constraints and panel_state.constraints.precision) local ok = normalize_lines(out) == normalize_lines(expected)
or config.ui.panel.precision
local ok = compare_outputs(out, expected, precision)
local signal = r.signal local signal = r.signal
if not signal and r.code and r.code >= 128 then if not signal and r.code and r.code >= 128 then
@ -245,7 +194,7 @@ function M.load_test_cases()
state.get_problem_id() state.get_problem_id()
) )
logger.log(('Loaded %d test case(s)'):format(#tcs), { level = vim.log.levels.INFO }) logger.log(('Loaded %d test case(s)'):format(#tcs), vim.log.levels.INFO)
return #tcs > 0 return #tcs > 0
end end
@ -259,7 +208,7 @@ function M.run_combined_test(debug, on_complete)
) )
if not combined then if not combined then
logger.log('No combined test found', { level = vim.log.levels.ERROR }) logger.log('No combined test found', vim.log.levels.ERROR)
on_complete(nil) on_complete(nil)
return return
end end
@ -327,33 +276,26 @@ function M.run_all_test_cases(indices, debug, on_each, on_done)
end end
end end
if #to_run == 0 then local function run_next(pos)
logger.log( if pos > #to_run then
('Finished %s %d test cases.'):format(debug and 'debugging' or 'running', 0), logger.log(
{ level = vim.log.levels.INFO, override = true } ('Finished %s %d test cases.'):format(debug and 'debugging' or 'running', #to_run),
) vim.log.levels.INFO,
on_done(panel_state.test_cases) true
return )
end on_done(panel_state.test_cases)
return
end
local total = #to_run M.run_test_case(to_run[pos], debug, function()
local remaining = total
for _, idx in ipairs(to_run) do
M.run_test_case(idx, debug, function()
if on_each then if on_each then
on_each(idx, total) on_each(pos, #to_run)
end
remaining = remaining - 1
if remaining == 0 then
logger.log(
('Finished %s %d test cases.'):format(debug and 'debugging' or 'running', total),
{ level = vim.log.levels.INFO, override = true }
)
on_done(panel_state.test_cases)
end end
run_next(pos + 1)
end) end)
end end
run_next(1)
end end
---@return PanelState ---@return PanelState

View file

@ -376,7 +376,6 @@ function M.get_highlight_groups()
} }
end end
---@return nil
function M.setup_highlights() function M.setup_highlights()
local groups = M.get_highlight_groups() local groups = M.get_highlight_groups()
for name, opts in pairs(groups) do for name, opts in pairs(groups) do

View file

@ -5,141 +5,55 @@ local logger = require('cp.log')
local utils = require('cp.utils') local utils = require('cp.utils')
local function syshandle(result) local function syshandle(result)
local ok, data = pcall(vim.json.decode, result.stdout or '')
if ok then
return { success = true, data = data }
end
if result.code ~= 0 then if result.code ~= 0 then
local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error') local msg = 'Scraper failed: ' .. (result.stderr or 'Unknown error')
return { success = false, error = msg } return { success = false, error = msg }
end end
local msg = 'Failed to parse scraper output: ' .. tostring(data) local ok, data = pcall(vim.json.decode, result.stdout)
logger.log(msg, { level = vim.log.levels.ERROR }) if not ok then
return { success = false, error = msg } local msg = 'Failed to parse scraper output: ' .. tostring(data)
end logger.log(msg, vim.log.levels.ERROR)
return { success = false, error = msg }
---@param env_map table<string, string>
---@return string[]
local function spawn_env_list(env_map)
local out = {}
for key, value in pairs(env_map) do
out[#out + 1] = tostring(key) .. '=' .. tostring(value)
end end
return out
return { success = true, data = data }
end end
---@param platform string ---@param platform string
---@param subcommand string ---@param subcommand string
---@param args string[] ---@param args string[]
---@param opts { sync?: boolean, ndjson?: boolean, on_event?: fun(ev: table), on_exit?: fun(result: table), env_extra?: table<string, string>, stdin?: string } ---@param opts { sync?: boolean, ndjson?: boolean, on_event?: fun(ev: table), on_exit?: fun(result: table) }
local function run_scraper(platform, subcommand, args, opts) local function run_scraper(platform, subcommand, args, opts)
if not utils.setup_python_env() then
local msg = 'no Python environment available (install uv or nix)'
logger.log(msg, { level = vim.log.levels.ERROR })
if opts and opts.on_exit then
opts.on_exit({ success = false, error = msg })
end
return { success = false, error = msg }
end
local needs_browser = subcommand == 'submit'
or subcommand == 'login'
or (platform == 'codeforces' and (subcommand == 'metadata' or subcommand == 'tests'))
if needs_browser then
utils.setup_nix_submit_env()
end
local plugin_path = utils.get_plugin_path() local plugin_path = utils.get_plugin_path()
local cmd local cmd = { 'uv', 'run', '--directory', plugin_path, '-m', 'scrapers.' .. platform, subcommand }
if needs_browser then
cmd = utils.get_python_submit_cmd(platform, plugin_path)
else
cmd = utils.get_python_cmd(platform, plugin_path)
end
vim.list_extend(cmd, { subcommand })
vim.list_extend(cmd, args) vim.list_extend(cmd, args)
logger.log('scraper cmd: ' .. table.concat(cmd, ' '))
local env = vim.fn.environ() local env = vim.fn.environ()
env.VIRTUAL_ENV = '' env.VIRTUAL_ENV = ''
env.PYTHONPATH = '' env.PYTHONPATH = ''
env.CONDA_PREFIX = '' env.CONDA_PREFIX = ''
if opts and opts.env_extra then
for k, v in pairs(opts.env_extra) do
env[k] = v
end
end
if needs_browser and utils.is_nix_build() then
env.UV_PROJECT_ENVIRONMENT = vim.fn.stdpath('cache') .. '/cp-nvim/submit-env'
end
if opts and opts.ndjson then if opts and opts.ndjson then
local uv = vim.uv local uv = vim.loop
local stdin_pipe = nil
if opts.stdin then
stdin_pipe = uv.new_pipe(false)
end
local stdout = uv.new_pipe(false) local stdout = uv.new_pipe(false)
local stderr = uv.new_pipe(false) local stderr = uv.new_pipe(false)
local buf = '' local buf = ''
local timer = nil
local handle local handle
handle = uv.spawn(cmd[1], { handle = uv.spawn(
args = vim.list_slice(cmd, 2), cmd[1],
stdio = { stdin_pipe, stdout, stderr }, { args = vim.list_slice(cmd, 2), stdio = { nil, stdout, stderr }, env = env },
env = spawn_env_list(env), function(code, signal)
cwd = plugin_path, if buf ~= '' and opts.on_event then
}, function(code, signal) local ok_tail, ev_tail = pcall(vim.json.decode, buf)
if timer and not timer:is_closing() then if ok_tail then
timer:stop() opts.on_event(ev_tail)
timer:close() end
end buf = ''
if buf ~= '' and opts.on_event then
local ok_tail, ev_tail = pcall(vim.json.decode, buf)
if ok_tail then
opts.on_event(ev_tail)
end end
buf = '' if opts.on_exit then
end opts.on_exit({ success = (code == 0), code = code, signal = signal })
if opts.on_exit then
opts.on_exit({ success = (code == 0), code = code, signal = signal })
end
if stdin_pipe and not stdin_pipe:is_closing() then
stdin_pipe:close()
end
if not stdout:is_closing() then
stdout:close()
end
if not stderr:is_closing() then
stderr:close()
end
if handle and not handle:is_closing() then
handle:close()
end
end)
if not handle then
if stdin_pipe and not stdin_pipe:is_closing() then
stdin_pipe:close()
end
logger.log('Failed to start scraper process', { level = vim.log.levels.ERROR })
return { success = false, error = 'spawn failed' }
end
if needs_browser then
timer = uv.new_timer()
timer:start(120000, 0, function()
timer:stop()
timer:close()
if stdin_pipe and not stdin_pipe:is_closing() then
stdin_pipe:close()
end end
if not stdout:is_closing() then if not stdout:is_closing() then
stdout:close() stdout:close()
@ -148,21 +62,14 @@ local function run_scraper(platform, subcommand, args, opts)
stderr:close() stderr:close()
end end
if handle and not handle:is_closing() then if handle and not handle:is_closing() then
handle:kill(15)
handle:close() handle:close()
end end
if opts.on_exit then end
opts.on_exit({ success = false, error = 'submit timed out' }) )
end
end)
end
if stdin_pipe then if not handle then
uv.write(stdin_pipe, opts.stdin, function() logger.log('Failed to start scraper process', vim.log.levels.ERROR)
uv.shutdown(stdin_pipe, function() return { success = false, error = 'spawn failed' }
stdin_pipe:close()
end)
end)
end end
uv.read_start(stdout, function(_, data) uv.read_start(stdout, function(_, data)
@ -195,15 +102,7 @@ local function run_scraper(platform, subcommand, args, opts)
return return
end end
local sysopts = { local sysopts = { text = true, timeout = 30000, env = env }
text = true,
timeout = needs_browser and 120000 or 30000,
env = env,
cwd = plugin_path,
}
if opts and opts.stdin then
sysopts.stdin = opts.stdin
end
if opts and opts.sync then if opts and opts.sync then
local result = vim.system(cmd, sysopts):wait() local result = vim.system(cmd, sysopts):wait()
return syshandle(result) return syshandle(result)
@ -216,11 +115,7 @@ local function run_scraper(platform, subcommand, args, opts)
end end
end end
---@param platform string function M.scrape_contest_metadata(platform, contest_id, callback)
---@param contest_id string
---@param callback fun(data: table)?
---@param on_error fun()?
function M.scrape_contest_metadata(platform, contest_id, callback, on_error)
run_scraper(platform, 'metadata', { contest_id }, { run_scraper(platform, 'metadata', { contest_id }, {
on_exit = function(result) on_exit = function(result)
if not result or not result.success then if not result or not result.success then
@ -229,11 +124,8 @@ function M.scrape_contest_metadata(platform, contest_id, callback, on_error)
constants.PLATFORM_DISPLAY_NAMES[platform], constants.PLATFORM_DISPLAY_NAMES[platform],
contest_id contest_id
), ),
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
if type(on_error) == 'function' then
on_error()
end
return return
end end
local data = result.data or {} local data = result.data or {}
@ -243,11 +135,8 @@ function M.scrape_contest_metadata(platform, contest_id, callback, on_error)
constants.PLATFORM_DISPLAY_NAMES[platform], constants.PLATFORM_DISPLAY_NAMES[platform],
contest_id contest_id
), ),
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
if type(on_error) == 'function' then
on_error()
end
return return
end end
if type(callback) == 'function' then if type(callback) == 'function' then
@ -257,8 +146,6 @@ function M.scrape_contest_metadata(platform, contest_id, callback, on_error)
}) })
end end
---@param platform string
---@return { contests: ContestSummary[], supports_countdown: boolean }?
function M.scrape_contest_list(platform) function M.scrape_contest_list(platform)
local result = run_scraper(platform, 'contests', {}, { sync = true }) local result = run_scraper(platform, 'contests', {}, { sync = true })
if not result or not result.success or not (result.data and result.data.contests) then if not result or not result.success or not (result.data and result.data.contests) then
@ -267,28 +154,19 @@ function M.scrape_contest_list(platform)
platform, platform,
(result and result.error) or 'unknown' (result and result.error) or 'unknown'
), ),
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return nil return {}
end end
return { return result.data.contests
contests = result.data.contests,
supports_countdown = result.data.supports_countdown ~= false,
}
end end
---@param platform string ---@param platform string
---@param contest_id string ---@param contest_id string
---@param callback fun(data: table)|nil ---@param callback fun(data: table)|nil
---@param on_done fun()|nil function M.scrape_all_tests(platform, contest_id, callback)
function M.scrape_all_tests(platform, contest_id, callback, on_done)
run_scraper(platform, 'tests', { contest_id }, { run_scraper(platform, 'tests', { contest_id }, {
ndjson = true, ndjson = true,
on_exit = function()
if type(on_done) == 'function' then
vim.schedule(on_done)
end
end,
on_event = function(ev) on_event = function(ev)
if ev.done then if ev.done then
return return
@ -300,7 +178,7 @@ function M.scrape_all_tests(platform, contest_id, callback, on_done)
contest_id, contest_id,
ev.error ev.error
), ),
{ level = vim.log.levels.WARN } vim.log.levels.WARN
) )
return return
end end
@ -327,7 +205,6 @@ function M.scrape_all_tests(platform, contest_id, callback, on_done)
memory_mb = ev.memory_mb or 0, memory_mb = ev.memory_mb or 0,
interactive = ev.interactive or false, interactive = ev.interactive or false,
multi_test = ev.multi_test or false, multi_test = ev.multi_test or false,
precision = ev.precision ~= vim.NIL and ev.precision or nil,
problem_id = ev.problem_id, problem_id = ev.problem_id,
}) })
end end
@ -336,87 +213,4 @@ function M.scrape_all_tests(platform, contest_id, callback, on_done)
}) })
end end
---@param platform string
---@param credentials table
---@param on_status fun(ev: table)?
---@param callback fun(result: table)?
function M.login(platform, credentials, on_status, callback)
local done = false
run_scraper(platform, 'login', {}, {
ndjson = true,
env_extra = { CP_CREDENTIALS = vim.json.encode(credentials) },
on_event = function(ev)
if ev.credentials ~= nil and next(ev.credentials) ~= nil then
require('cp.cache').set_credentials(platform, ev.credentials)
end
if ev.status ~= nil then
if type(on_status) == 'function' then
on_status(ev)
end
elseif ev.success ~= nil then
done = true
if type(callback) == 'function' then
callback(ev)
end
end
end,
on_exit = function(proc)
if not done and type(callback) == 'function' then
callback({
success = false,
error = 'login process exited (code=' .. tostring(proc.code) .. ')',
})
end
end,
})
end
---@param platform string
---@param contest_id string
---@param problem_id string
---@param language string
---@param source_file string
---@param credentials table
---@param on_status fun(ev: table)?
---@param callback fun(result: table)?
function M.submit(
platform,
contest_id,
problem_id,
language,
source_file,
credentials,
on_status,
callback
)
local done = false
run_scraper(platform, 'submit', { contest_id, problem_id, language, source_file }, {
ndjson = true,
env_extra = { CP_CREDENTIALS = vim.json.encode(credentials) },
on_event = function(ev)
if ev.credentials ~= nil then
require('cp.cache').set_credentials(platform, ev.credentials)
end
if ev.status ~= nil then
if type(on_status) == 'function' then
on_status(ev)
end
elseif ev.success ~= nil then
done = true
if type(callback) == 'function' then
callback(ev)
end
end
end,
on_exit = function(proc)
if not done and type(callback) == 'function' then
callback({
success = false,
error = 'submit process exited (code=' .. tostring(proc.code) .. ')',
})
end
end,
})
end
return M return M

View file

@ -8,39 +8,6 @@ local logger = require('cp.log')
local scraper = require('cp.scraper') local scraper = require('cp.scraper')
local state = require('cp.state') local state = require('cp.state')
local function apply_template(bufnr, lang_id, platform)
local config = config_module.get_config()
local eff = config.runtime.effective[platform] and config.runtime.effective[platform][lang_id]
if not eff or not eff.template then
return
end
local path = vim.fn.expand(eff.template)
if vim.fn.filereadable(path) ~= 1 then
logger.log(
('[cp.nvim] template not readable: %s'):format(path),
{ level = vim.log.levels.WARN }
)
return
end
local lines = vim.fn.readfile(path)
vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines)
local marker = config.templates and config.templates.cursor_marker
if marker then
for lnum, line in ipairs(lines) do
local col = line:find(marker, 1, true)
if col then
local new_line = line:sub(1, col - 1) .. line:sub(col + #marker)
vim.api.nvim_buf_set_lines(bufnr, lnum - 1, lnum, false, { new_line })
local winid = vim.fn.bufwinid(bufnr)
if winid ~= -1 then
vim.api.nvim_win_set_cursor(winid, { lnum, col - 1 })
end
break
end
end
end
end
---Get the language of the current file from cache ---Get the language of the current file from cache
---@return string? ---@return string?
local function get_current_file_language() local function get_current_file_language()
@ -115,15 +82,11 @@ local function start_tests(platform, contest_id, problems)
return not vim.tbl_isempty(cache.get_test_cases(platform, contest_id, p.id)) return not vim.tbl_isempty(cache.get_test_cases(platform, contest_id, p.id))
end, problems) end, problems)
if cached_len ~= #problems then if cached_len ~= #problems then
local to_fetch = #problems - cached_len
logger.log(('Fetching %s/%s problem tests...'):format(cached_len, #problems)) logger.log(('Fetching %s/%s problem tests...'):format(cached_len, #problems))
scraper.scrape_all_tests(platform, contest_id, function(ev) scraper.scrape_all_tests(platform, contest_id, function(ev)
local cached_tests = {} local cached_tests = {}
if not ev.interactive and vim.tbl_isempty(ev.tests) then if not ev.interactive and vim.tbl_isempty(ev.tests) then
logger.log( logger.log(("No tests found for problem '%s'."):format(ev.problem_id), vim.log.levels.WARN)
("No tests found for problem '%s'."):format(ev.problem_id),
{ level = vim.log.levels.WARN }
)
end end
for i, t in ipairs(ev.tests) do for i, t in ipairs(ev.tests) do
cached_tests[i] = { index = i, input = t.input, expected = t.expected } cached_tests[i] = { index = i, input = t.input, expected = t.expected }
@ -137,8 +100,7 @@ local function start_tests(platform, contest_id, problems)
ev.timeout_ms or 0, ev.timeout_ms or 0,
ev.memory_mb or 0, ev.memory_mb or 0,
ev.interactive, ev.interactive,
ev.multi_test, ev.multi_test
ev.precision
) )
local io_state = state.get_io_view_state() local io_state = state.get_io_view_state()
@ -149,11 +111,6 @@ local function start_tests(platform, contest_id, problems)
require('cp.utils').update_buffer_content(io_state.input_buf, input_lines, nil, nil) require('cp.utils').update_buffer_content(io_state.input_buf, input_lines, nil, nil)
end end
end end
end, function()
logger.log(
('Loaded %d test%s.'):format(to_fetch, to_fetch == 1 and '' or 's'),
{ level = vim.log.levels.INFO, override = true }
)
end) end)
end end
end end
@ -171,39 +128,24 @@ function M.setup_contest(platform, contest_id, problem_id, language)
if language then if language then
local lang_result = config_module.get_language_for_platform(platform, language) local lang_result = config_module.get_language_for_platform(platform, language)
if not lang_result.valid then if not lang_result.valid then
logger.log(lang_result.error, { level = vim.log.levels.ERROR }) logger.log(lang_result.error, vim.log.levels.ERROR)
return return
end end
end end
local is_new_contest = old_platform ~= platform or old_contest_id ~= contest_id local is_new_contest = old_platform ~= platform and old_contest_id ~= contest_id
if is_new_contest then
local views = require('cp.ui.views')
views.cancel_io_view()
local active = state.get_active_panel()
if active == 'interactive' then
views.cancel_interactive()
elseif active == 'stress' then
require('cp.stress').cancel()
elseif active == 'run' then
views.disable()
end
end
cache.load() cache.load()
local function proceed(contest_data) local function proceed(contest_data)
if is_new_contest then
local io_state = state.get_io_view_state()
if io_state and io_state.output_buf and vim.api.nvim_buf_is_valid(io_state.output_buf) then
require('cp.utils').update_buffer_content(io_state.output_buf, {}, nil, nil)
end
end
local problems = contest_data.problems local problems = contest_data.problems
local pid = problem_id and problem_id or problems[1].id local pid = problem_id and problem_id or problems[1].id
M.setup_problem(pid, language) M.setup_problem(pid, language)
start_tests(platform, contest_id, problems) start_tests(platform, contest_id, problems)
if config_module.get_config().open_url and is_new_contest and contest_data.url then
vim.ui.open(contest_data.url:format(pid))
end
end end
local contest_data = cache.get_contest_data(platform, contest_id) local contest_data = cache.get_contest_data(platform, contest_id)
@ -218,7 +160,12 @@ function M.setup_contest(platform, contest_id, problem_id, language)
vim.bo[bufnr].buftype = '' vim.bo[bufnr].buftype = ''
vim.bo[bufnr].swapfile = false vim.bo[bufnr].swapfile = false
state.set_language(lang) if cfg.hooks and cfg.hooks.setup_code and not vim.b[bufnr].cp_setup_done then
local ok = pcall(cfg.hooks.setup_code, state)
if ok then
vim.b[bufnr].cp_setup_done = true
end
end
state.set_provisional({ state.set_provisional({
bufnr = bufnr, bufnr = bufnr,
@ -226,23 +173,16 @@ function M.setup_contest(platform, contest_id, problem_id, language)
contest_id = contest_id, contest_id = contest_id,
language = lang, language = lang,
requested_problem_id = problem_id, requested_problem_id = problem_id,
token = vim.uv.hrtime(), token = vim.loop.hrtime(),
}) })
logger.log('Fetching contests problems...', { level = vim.log.levels.INFO, override = true }) logger.log('Fetching contests problems...', vim.log.levels.INFO, true)
scraper.scrape_contest_metadata( scraper.scrape_contest_metadata(
platform, platform,
contest_id, contest_id,
vim.schedule_wrap(function(result) vim.schedule_wrap(function(result)
local problems = result.problems or {} local problems = result.problems or {}
cache.set_contest_data( cache.set_contest_data(platform, contest_id, problems, result.url)
platform,
contest_id,
problems,
result.url,
result.contest_url or '',
result.standings_url or ''
)
local prov = state.get_provisional() local prov = state.get_provisional()
if not prov or prov.platform ~= platform or prov.contest_id ~= contest_id then if not prov or prov.platform ~= platform or prov.contest_id ~= contest_id then
return return
@ -272,7 +212,7 @@ end
function M.setup_problem(problem_id, language) function M.setup_problem(problem_id, language)
local platform = state.get_platform() local platform = state.get_platform()
if not platform then if not platform then
logger.log('No platform/contest/problem configured.', { level = vim.log.levels.ERROR }) logger.log('No platform/contest/problem configured.', vim.log.levels.ERROR)
return return
end end
@ -293,7 +233,7 @@ function M.setup_problem(problem_id, language)
if language then if language then
local lang_result = config_module.get_language_for_platform(platform, language) local lang_result = config_module.get_language_for_platform(platform, language)
if not lang_result.valid then if not lang_result.valid then
logger.log(lang_result.error, { level = vim.log.levels.ERROR }) logger.log(lang_result.error, vim.log.levels.ERROR)
return return
end end
end end
@ -305,38 +245,7 @@ function M.setup_problem(problem_id, language)
return return
end end
if vim.fn.filereadable(source_file) == 1 then vim.fn.mkdir(vim.fn.fnamemodify(source_file, ':h'), 'p')
local existing = cache.get_file_state(vim.fn.fnamemodify(source_file, ':p'))
if
existing
and (
existing.platform ~= platform
or existing.contest_id ~= (state.get_contest_id() or '')
or existing.problem_id ~= problem_id
)
then
logger.log(
('File %q already exists for %s/%s %s.'):format(
source_file,
existing.platform,
existing.contest_id,
existing.problem_id
),
{ level = vim.log.levels.ERROR }
)
return
end
end
local contest_dir = vim.fn.fnamemodify(source_file, ':h')
local is_new_dir = vim.fn.isdirectory(contest_dir) == 0
vim.fn.mkdir(contest_dir, 'p')
if is_new_dir then
local s = config.hooks and config.hooks.setup
if s and s.contest then
pcall(s.contest, state)
end
end
local prov = state.get_provisional() local prov = state.get_provisional()
if prov and prov.platform == platform and prov.contest_id == (state.get_contest_id() or '') then if prov and prov.platform == platform and prov.contest_id == (state.get_contest_id() or '') then
@ -347,6 +256,7 @@ function M.setup_problem(problem_id, language)
state.set_provisional(nil) state.set_provisional(nil)
else else
vim.api.nvim_buf_set_name(prov.bufnr, source_file) vim.api.nvim_buf_set_name(prov.bufnr, source_file)
vim.bo[prov.bufnr].swapfile = true
-- selene: allow(mixed_table) -- selene: allow(mixed_table)
vim.cmd.write({ vim.cmd.write({
vim.fn.fnameescape(source_file), vim.fn.fnameescape(source_file),
@ -354,29 +264,14 @@ function M.setup_problem(problem_id, language)
mods = { silent = true, noautocmd = true, keepalt = true }, mods = { silent = true, noautocmd = true, keepalt = true },
}) })
state.set_solution_win(vim.api.nvim_get_current_win()) state.set_solution_win(vim.api.nvim_get_current_win())
if not vim.b[prov.bufnr].cp_setup_done then if config.hooks and config.hooks.setup_code and not vim.b[prov.bufnr].cp_setup_done then
apply_template(prov.bufnr, lang, platform) local ok = pcall(config.hooks.setup_code, state)
local s = config.hooks and config.hooks.setup if ok then
if s and s.code then
local ok = pcall(s.code, state)
if ok then
vim.b[prov.bufnr].cp_setup_done = true
end
else
helpers.clearcol(prov.bufnr)
vim.b[prov.bufnr].cp_setup_done = true vim.b[prov.bufnr].cp_setup_done = true
end end
local o = config.hooks and config.hooks.on elseif not vim.b[prov.bufnr].cp_setup_done then
if o and o.enter then helpers.clearcol(prov.bufnr)
local bufnr = prov.bufnr vim.b[prov.bufnr].cp_setup_done = true
vim.api.nvim_create_autocmd('BufEnter', {
buffer = bufnr,
callback = function()
pcall(o.enter, state)
end,
})
pcall(o.enter, state)
end
end end
cache.set_file_state( cache.set_file_state(
vim.fn.fnamemodify(source_file, ':p'), vim.fn.fnamemodify(source_file, ':p'),
@ -395,39 +290,18 @@ function M.setup_problem(problem_id, language)
end end
vim.cmd.only({ mods = { silent = true } }) vim.cmd.only({ mods = { silent = true } })
local current_file = vim.fn.expand('%:p') vim.cmd.e(source_file)
if current_file ~= vim.fn.fnamemodify(source_file, ':p') then
vim.cmd.e(source_file)
end
local bufnr = vim.api.nvim_get_current_buf() local bufnr = vim.api.nvim_get_current_buf()
state.set_solution_win(vim.api.nvim_get_current_win()) state.set_solution_win(vim.api.nvim_get_current_win())
require('cp.ui.views').ensure_io_view() require('cp.ui.views').ensure_io_view()
if not vim.b[bufnr].cp_setup_done then if config.hooks and config.hooks.setup_code and not vim.b[bufnr].cp_setup_done then
local is_new = vim.api.nvim_buf_line_count(bufnr) == 1 local ok = pcall(config.hooks.setup_code, state)
and vim.api.nvim_buf_get_lines(bufnr, 0, 1, false)[1] == '' if ok then
if is_new then
apply_template(bufnr, lang, platform)
end
local s = config.hooks and config.hooks.setup
if s and s.code then
local ok = pcall(s.code, state)
if ok then
vim.b[bufnr].cp_setup_done = true
end
else
helpers.clearcol(bufnr)
vim.b[bufnr].cp_setup_done = true vim.b[bufnr].cp_setup_done = true
end end
local o = config.hooks and config.hooks.on elseif not vim.b[bufnr].cp_setup_done then
if o and o.enter then helpers.clearcol(bufnr)
vim.api.nvim_create_autocmd('BufEnter', { vim.b[bufnr].cp_setup_done = true
buffer = bufnr,
callback = function()
pcall(o.enter, state)
end,
})
pcall(o.enter, state)
end
end end
cache.set_file_state( cache.set_file_state(
vim.fn.expand('%:p'), vim.fn.expand('%:p'),
@ -450,7 +324,7 @@ function M.navigate_problem(direction, language)
local contest_id = state.get_contest_id() local contest_id = state.get_contest_id()
local current_problem_id = state.get_problem_id() local current_problem_id = state.get_problem_id()
if not platform or not contest_id or not current_problem_id then if not platform or not contest_id or not current_problem_id then
logger.log('No platform configured.', { level = vim.log.levels.ERROR }) logger.log('No platform configured.', vim.log.levels.ERROR)
return return
end end
@ -462,7 +336,7 @@ function M.navigate_problem(direction, language)
constants.PLATFORM_DISPLAY_NAMES[platform], constants.PLATFORM_DISPLAY_NAMES[platform],
contest_id contest_id
), ),
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
@ -476,15 +350,9 @@ function M.navigate_problem(direction, language)
logger.log(('navigate_problem: %s -> %s'):format(current_problem_id, problems[new_index].id)) logger.log(('navigate_problem: %s -> %s'):format(current_problem_id, problems[new_index].id))
local views = require('cp.ui.views')
views.cancel_io_view()
local active_panel = state.get_active_panel() local active_panel = state.get_active_panel()
if active_panel == 'run' then if active_panel == 'run' then
views.disable() require('cp.ui.views').disable()
elseif active_panel == 'interactive' then
views.cancel_interactive()
elseif active_panel == 'stress' then
require('cp.stress').cancel()
end end
local lang = nil local lang = nil
@ -492,7 +360,7 @@ function M.navigate_problem(direction, language)
if language then if language then
local lang_result = config_module.get_language_for_platform(platform, language) local lang_result = config_module.get_language_for_platform(platform, language)
if not lang_result.valid then if not lang_result.valid then
logger.log(lang_result.error, { level = vim.log.levels.ERROR }) logger.log(lang_result.error, vim.log.levels.ERROR)
return return
end end
lang = language lang = language

View file

@ -9,6 +9,7 @@
---@class cp.IoViewState ---@class cp.IoViewState
---@field output_buf integer ---@field output_buf integer
---@field input_buf integer ---@field input_buf integer
---@field current_test_index integer?
---@field source_buf integer? ---@field source_buf integer?
---@class cp.State ---@class cp.State

View file

@ -1,259 +0,0 @@
local M = {}
local logger = require('cp.log')
local state = require('cp.state')
local utils = require('cp.utils')
local GENERATOR_PATTERNS = {
'gen.py',
'gen.cc',
'gen.cpp',
'generator.py',
'generator.cc',
'generator.cpp',
}
local BRUTE_PATTERNS = {
'brute.py',
'brute.cc',
'brute.cpp',
'slow.py',
'slow.cc',
'slow.cpp',
}
local function find_file(patterns)
for _, pattern in ipairs(patterns) do
if vim.fn.filereadable(pattern) == 1 then
return pattern
end
end
return nil
end
local function compile_cpp(source, output)
local result = vim.system({ 'sh', '-c', 'g++ -O2 -o ' .. output .. ' ' .. source }):wait()
if result.code ~= 0 then
logger.log(
('Failed to compile %s: %s'):format(source, result.stderr or ''),
{ level = vim.log.levels.ERROR }
)
return false
end
return true
end
local function build_run_cmd(file)
local ext = file:match('%.([^%.]+)$')
if ext == 'cc' or ext == 'cpp' then
local base = file:gsub('%.[^%.]+$', '')
local bin = base .. '_bin'
if not compile_cpp(file, bin) then
return nil
end
return './' .. bin
elseif ext == 'py' then
return 'python ' .. file
end
return './' .. file
end
---@param generator_cmd? string
---@param brute_cmd? string
function M.toggle(generator_cmd, brute_cmd)
if state.get_active_panel() == 'stress' then
if state.stress_buf and vim.api.nvim_buf_is_valid(state.stress_buf) then
local job = vim.b[state.stress_buf].terminal_job_id
if job then
vim.fn.jobstop(job)
end
end
if state.saved_stress_session then
vim.cmd.source(state.saved_stress_session)
vim.fn.delete(state.saved_stress_session)
state.saved_stress_session = nil
end
state.set_active_panel(nil)
require('cp.ui.views').ensure_io_view()
return
end
if state.get_active_panel() then
logger.log('Another panel is already active.', { level = vim.log.levels.WARN })
return
end
local gen_file = generator_cmd
local brute_file = brute_cmd
if not gen_file then
gen_file = find_file(GENERATOR_PATTERNS)
end
if not brute_file then
brute_file = find_file(BRUTE_PATTERNS)
end
if not gen_file then
logger.log(
'No generator found. Pass generator as first arg or add gen.{py,cc,cpp}.',
{ level = vim.log.levels.ERROR }
)
return
end
if not brute_file then
logger.log(
'No brute solution found. Pass brute as second arg or add brute.{py,cc,cpp}.',
{ level = vim.log.levels.ERROR }
)
return
end
local gen_cmd = build_run_cmd(gen_file)
if not gen_cmd then
return
end
local brute_run_cmd = build_run_cmd(brute_file)
if not brute_run_cmd then
return
end
state.saved_stress_session = vim.fn.tempname()
-- selene: allow(mixed_table)
vim.cmd.mksession({ state.saved_stress_session, bang = true })
vim.cmd.only({ mods = { silent = true } })
local execute = require('cp.runner.execute')
local function restore_session()
if state.saved_stress_session then
vim.cmd.source(state.saved_stress_session)
vim.fn.delete(state.saved_stress_session)
state.saved_stress_session = nil
end
require('cp.ui.views').ensure_io_view()
end
execute.compile_problem(false, function(compile_result)
if not compile_result.success then
local run = require('cp.runner.run')
run.handle_compilation_failure(compile_result.output)
restore_session()
return
end
local binary = state.get_binary_file()
if not binary or binary == '' then
logger.log('No binary produced.', { level = vim.log.levels.ERROR })
restore_session()
return
end
local script = vim.fn.fnamemodify(utils.get_plugin_path() .. '/scripts/stress.py', ':p')
local cmdline
if utils.is_nix_build() then
cmdline = table.concat({
vim.fn.shellescape(utils.get_nix_python()),
vim.fn.shellescape(script),
vim.fn.shellescape(gen_cmd),
vim.fn.shellescape(brute_run_cmd),
vim.fn.shellescape(binary),
}, ' ')
else
cmdline = table.concat({
'uv',
'run',
vim.fn.shellescape(script),
vim.fn.shellescape(gen_cmd),
vim.fn.shellescape(brute_run_cmd),
vim.fn.shellescape(binary),
}, ' ')
end
vim.cmd.terminal(cmdline)
local term_buf = vim.api.nvim_get_current_buf()
pcall(
vim.api.nvim_buf_set_name,
term_buf,
("term://stress.py '%s' '%s' '%s'"):format(gen_cmd, brute_run_cmd, binary)
)
local term_win = vim.api.nvim_get_current_win()
local cleaned = false
local function cleanup()
if cleaned then
return
end
cleaned = true
if term_buf and vim.api.nvim_buf_is_valid(term_buf) then
local job = vim.b[term_buf] and vim.b[term_buf].terminal_job_id or nil
if job then
pcall(vim.fn.jobstop, job)
end
end
restore_session()
state.stress_buf = nil
state.stress_win = nil
state.set_active_panel(nil)
end
vim.api.nvim_create_autocmd({ 'BufWipeout', 'BufUnload' }, {
buffer = term_buf,
callback = cleanup,
})
vim.api.nvim_create_autocmd('WinClosed', {
callback = function()
if cleaned then
return
end
local any = false
for _, win in ipairs(vim.api.nvim_list_wins()) do
if vim.api.nvim_win_is_valid(win) and vim.api.nvim_win_get_buf(win) == term_buf then
any = true
break
end
end
if not any then
cleanup()
end
end,
})
vim.api.nvim_create_autocmd('TermClose', {
buffer = term_buf,
callback = function()
vim.b[term_buf].cp_stress_exited = true
end,
})
vim.keymap.set('t', '<c-q>', function()
cleanup()
end, { buffer = term_buf, silent = true })
vim.keymap.set('n', '<c-q>', function()
cleanup()
end, { buffer = term_buf, silent = true })
state.stress_buf = term_buf
state.stress_win = term_win
state.set_active_panel('stress')
end)
end
---@return nil
function M.cancel()
if state.stress_buf and vim.api.nvim_buf_is_valid(state.stress_buf) then
local job = vim.b[state.stress_buf].terminal_job_id
if job then
vim.fn.jobstop(job)
end
end
if state.saved_stress_session then
vim.fn.delete(state.saved_stress_session)
state.saved_stress_session = nil
end
state.set_active_panel(nil)
end
return M

View file

@ -1,116 +0,0 @@
local M = {}
local cache = require('cp.cache')
local config = require('cp.config')
local constants = require('cp.constants')
local logger = require('cp.log')
local state = require('cp.state')
local STATUS_MSGS = {
installing_browser = 'Installing browser (first time setup)...',
checking_login = 'Checking login...',
logging_in = 'Logging in...',
submitting = 'Submitting...',
}
local function prompt_credentials(platform, callback)
local saved = cache.get_credentials(platform)
if saved and saved.username and saved.password then
callback(saved)
return
end
vim.ui.input({ prompt = platform .. ' username: ' }, function(username)
if not username or username == '' then
logger.log('Submit cancelled', { level = vim.log.levels.WARN })
return
end
vim.fn.inputsave()
local password = vim.fn.inputsecret(platform .. ' password: ')
vim.fn.inputrestore()
vim.cmd.redraw()
if not password or password == '' then
logger.log('Submit cancelled', { level = vim.log.levels.WARN })
return
end
local creds = { username = username, password = password }
cache.set_credentials(platform, creds)
callback(creds)
end)
end
---@param opts { language?: string }?
function M.submit(opts)
local platform = state.get_platform()
local contest_id = state.get_contest_id()
local problem_id = state.get_problem_id()
local language = (opts and opts.language) or state.get_language()
if not platform or not contest_id or not problem_id or not language then
logger.log(
'No active problem. Use :CP <platform> <contest> first.',
{ level = vim.log.levels.ERROR }
)
return
end
local source_file = state.get_source_file()
if not source_file or vim.fn.filereadable(source_file) ~= 1 then
logger.log('Source file not found', { level = vim.log.levels.ERROR })
return
end
source_file = vim.fn.fnamemodify(source_file, ':p')
local submit_language = language
local cfg = config.get_config()
local plat_effective = cfg.runtime and cfg.runtime.effective and cfg.runtime.effective[platform]
local eff = plat_effective and plat_effective[language]
if eff then
if eff.submit_id then
submit_language = eff.submit_id or submit_language
else
local ver = eff.version or constants.DEFAULT_VERSIONS[language]
if ver then
local versions = (constants.LANGUAGE_VERSIONS[platform] or {})[language]
if versions and versions[ver] then
submit_language = versions[ver] or submit_language
end
end
end
end
prompt_credentials(platform, function(creds)
vim.cmd.update()
logger.log('Submitting...', { level = vim.log.levels.INFO, override = true })
require('cp.scraper').submit(
platform,
contest_id,
problem_id,
submit_language,
source_file,
creds,
function(ev)
vim.schedule(function()
logger.log(
STATUS_MSGS[ev.status] or ev.status,
{ level = vim.log.levels.INFO, override = true }
)
end)
end,
function(result)
vim.schedule(function()
if result and result.success then
logger.log('Submitted successfully', { level = vim.log.levels.INFO, override = true })
else
local err = result and result.error or 'unknown error'
if err:match('^Login failed') then
cache.clear_credentials(platform)
end
logger.log('Submit failed: ' .. err, { level = vim.log.levels.ERROR })
end
end)
end
)
end)
end
return M

View file

@ -90,7 +90,7 @@ local function delete_current_test()
return return
end end
if #edit_state.test_buffers == 1 then if #edit_state.test_buffers == 1 then
logger.log('Problems must have at least one test case.', { level = vim.log.levels.ERROR }) logger.log('Problems must have at least one test case.', vim.log.levels.ERROR)
return return
end end
@ -144,7 +144,7 @@ local function add_new_test()
vim.api.nvim_win_set_buf(input_win, input_buf) vim.api.nvim_win_set_buf(input_win, input_buf)
vim.bo[input_buf].modifiable = true vim.bo[input_buf].modifiable = true
vim.bo[input_buf].readonly = false vim.bo[input_buf].readonly = false
vim.bo[input_buf].buftype = 'acwrite' vim.bo[input_buf].buftype = 'nofile'
vim.bo[input_buf].buflisted = false vim.bo[input_buf].buflisted = false
helpers.clearcol(input_buf) helpers.clearcol(input_buf)
@ -155,7 +155,7 @@ local function add_new_test()
vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.api.nvim_win_set_buf(expected_win, expected_buf)
vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].modifiable = true
vim.bo[expected_buf].readonly = false vim.bo[expected_buf].readonly = false
vim.bo[expected_buf].buftype = 'acwrite' vim.bo[expected_buf].buftype = 'nofile'
vim.bo[expected_buf].buflisted = false vim.bo[expected_buf].buflisted = false
helpers.clearcol(expected_buf) helpers.clearcol(expected_buf)
@ -177,80 +177,6 @@ local function add_new_test()
logger.log(('Added test %d'):format(new_index)) logger.log(('Added test %d'):format(new_index))
end end
local function save_all_tests()
if not edit_state then
return
end
local platform = state.get_platform()
local contest_id = state.get_contest_id()
local problem_id = state.get_problem_id()
if not platform or not contest_id or not problem_id then
return
end
for i, pair in ipairs(edit_state.test_buffers) do
if
vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf)
then
local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false)
local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false)
edit_state.test_cases[i].input = table.concat(input_lines, '\n')
edit_state.test_cases[i].expected = table.concat(expected_lines, '\n')
end
end
local contest_data = cache.get_contest_data(platform, contest_id)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
or false
local combined_input = table.concat(
vim.tbl_map(function(tc)
return tc.input
end, edit_state.test_cases),
'\n'
)
local combined_expected = table.concat(
vim.tbl_map(function(tc)
return tc.expected
end, edit_state.test_cases),
'\n'
)
cache.set_test_cases(
platform,
contest_id,
problem_id,
{ input = combined_input, expected = combined_expected },
edit_state.test_cases,
edit_state.constraints and edit_state.constraints.timeout_ms or 0,
edit_state.constraints and edit_state.constraints.memory_mb or 0,
false,
is_multi_test
)
local config = config_module.get_config()
local base_name = config.filename and config.filename(platform, contest_id, problem_id, config)
or config_module.default_filename(contest_id, problem_id)
vim.fn.mkdir('io', 'p')
for i, tc in ipairs(edit_state.test_cases) do
local input_file = string.format('io/%s.%d.cpin', base_name, i)
local expected_file = string.format('io/%s.%d.cpout', base_name, i)
local input_content = (tc.input or ''):gsub('\r', '')
local expected_content = (tc.expected or ''):gsub('\r', '')
vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file)
vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file)
end
logger.log('Saved all test cases')
end
---@param buf integer ---@param buf integer
setup_keybindings = function(buf) setup_keybindings = function(buf)
local config = config_module.get_config() local config = config_module.get_config()
@ -311,40 +237,92 @@ setup_keybindings = function(buf)
end end
if is_tracked then if is_tracked then
logger.log( logger.log('Test buffer closed unexpectedly. Exiting editor.', vim.log.levels.WARN)
'Test buffer closed unexpectedly. Exiting editor.',
{ level = vim.log.levels.WARN }
)
M.toggle_edit() M.toggle_edit()
end end
end) end)
end, end,
}) })
vim.api.nvim_create_autocmd('BufWriteCmd', {
group = augroup,
buffer = buf,
callback = function()
save_all_tests()
vim.bo[buf].modified = false
end,
})
end end
---@param test_index? integer local function save_all_tests()
if not edit_state then
return
end
local platform = state.get_platform()
local contest_id = state.get_contest_id()
local problem_id = state.get_problem_id()
if not platform or not contest_id or not problem_id then
return
end
for i, pair in ipairs(edit_state.test_buffers) do
if
vim.api.nvim_buf_is_valid(pair.input_buf) and vim.api.nvim_buf_is_valid(pair.expected_buf)
then
local input_lines = vim.api.nvim_buf_get_lines(pair.input_buf, 0, -1, false)
local expected_lines = vim.api.nvim_buf_get_lines(pair.expected_buf, 0, -1, false)
edit_state.test_cases[i].input = table.concat(input_lines, '\n')
edit_state.test_cases[i].expected = table.concat(expected_lines, '\n')
end
end
local contest_data = cache.get_contest_data(platform, contest_id)
local is_multi_test = contest_data.problems[contest_data.index_map[problem_id]].multi_test
or false
-- Generate combined test from individual test cases
local combined_input = table.concat(
vim.tbl_map(function(tc)
return tc.input
end, edit_state.test_cases),
'\n'
)
local combined_expected = table.concat(
vim.tbl_map(function(tc)
return tc.expected
end, edit_state.test_cases),
'\n'
)
cache.set_test_cases(
platform,
contest_id,
problem_id,
{ input = combined_input, expected = combined_expected },
edit_state.test_cases,
edit_state.constraints and edit_state.constraints.timeout_ms or 0,
edit_state.constraints and edit_state.constraints.memory_mb or 0,
false,
is_multi_test
)
local config = config_module.get_config()
local base_name = config.filename and config.filename(platform, contest_id, problem_id, config)
or config_module.default_filename(contest_id, problem_id)
vim.fn.mkdir('io', 'p')
for i, tc in ipairs(edit_state.test_cases) do
local input_file = string.format('io/%s.%d.cpin', base_name, i)
local expected_file = string.format('io/%s.%d.cpout', base_name, i)
local input_content = (tc.input or ''):gsub('\r', '')
local expected_content = (tc.expected or ''):gsub('\r', '')
vim.fn.writefile(vim.split(input_content, '\n', { trimempty = true }), input_file)
vim.fn.writefile(vim.split(expected_content, '\n', { trimempty = true }), expected_file)
end
logger.log('Saved all test cases')
end
function M.toggle_edit(test_index) function M.toggle_edit(test_index)
if edit_state then if edit_state then
save_all_tests() save_all_tests()
for _, pair in ipairs(edit_state.test_buffers) do
if vim.api.nvim_buf_is_valid(pair.input_buf) then
vim.api.nvim_buf_delete(pair.input_buf, { force = true })
end
if vim.api.nvim_buf_is_valid(pair.expected_buf) then
vim.api.nvim_buf_delete(pair.expected_buf, { force = true })
end
end
edit_state = nil edit_state = nil
pcall(vim.api.nvim_clear_autocmds, { group = 'cp_edit_guard' }) pcall(vim.api.nvim_clear_autocmds, { group = 'cp_edit_guard' })
@ -372,10 +350,7 @@ function M.toggle_edit(test_index)
state.get_platform(), state.get_contest_id(), state.get_problem_id() state.get_platform(), state.get_contest_id(), state.get_problem_id()
if not platform or not contest_id or not problem_id then if not platform or not contest_id or not problem_id then
logger.log( logger.log('No problem context. Run :CP <platform> <contest> first.', vim.log.levels.ERROR)
'No problem context. Run :CP <platform> <contest> first.',
{ level = vim.log.levels.ERROR }
)
return return
end end
@ -383,7 +358,7 @@ function M.toggle_edit(test_index)
local test_cases = cache.get_test_cases(platform, contest_id, problem_id) local test_cases = cache.get_test_cases(platform, contest_id, problem_id)
if not test_cases or #test_cases == 0 then if not test_cases or #test_cases == 0 then
logger.log('No test cases available for editing.', { level = vim.log.levels.ERROR }) logger.log('No test cases available for editing.', vim.log.levels.ERROR)
return return
end end
@ -396,7 +371,7 @@ function M.toggle_edit(test_index)
if target_index < 1 or target_index > #test_cases then if target_index < 1 or target_index > #test_cases then
logger.log( logger.log(
('Test %d does not exist (only %d tests available)'):format(target_index, #test_cases), ('Test %d does not exist (only %d tests available)'):format(target_index, #test_cases),
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
@ -436,7 +411,7 @@ function M.toggle_edit(test_index)
vim.api.nvim_win_set_buf(input_win, input_buf) vim.api.nvim_win_set_buf(input_win, input_buf)
vim.bo[input_buf].modifiable = true vim.bo[input_buf].modifiable = true
vim.bo[input_buf].readonly = false vim.bo[input_buf].readonly = false
vim.bo[input_buf].buftype = 'acwrite' vim.bo[input_buf].buftype = 'nofile'
vim.bo[input_buf].buflisted = false vim.bo[input_buf].buflisted = false
helpers.clearcol(input_buf) helpers.clearcol(input_buf)
@ -446,7 +421,7 @@ function M.toggle_edit(test_index)
vim.api.nvim_win_set_buf(expected_win, expected_buf) vim.api.nvim_win_set_buf(expected_win, expected_buf)
vim.bo[expected_buf].modifiable = true vim.bo[expected_buf].modifiable = true
vim.bo[expected_buf].readonly = false vim.bo[expected_buf].readonly = false
vim.bo[expected_buf].buftype = 'acwrite' vim.bo[expected_buf].buftype = 'nofile'
vim.bo[expected_buf].buflisted = false vim.bo[expected_buf].buflisted = false
helpers.clearcol(expected_buf) helpers.clearcol(expected_buf)

View file

@ -1,9 +1,3 @@
---@class DiffLayout
---@field buffers integer[]
---@field windows integer[]
---@field mode string
---@field cleanup fun()
local M = {} local M = {}
local helpers = require('cp.helpers') local helpers = require('cp.helpers')
@ -177,11 +171,6 @@ local function create_single_layout(parent_win, content)
} }
end end
---@param mode string
---@param parent_win integer
---@param expected_content string
---@param actual_content string
---@return DiffLayout
function M.create_diff_layout(mode, parent_win, expected_content, actual_content) function M.create_diff_layout(mode, parent_win, expected_content, actual_content)
if mode == 'single' then if mode == 'single' then
return create_single_layout(parent_win, actual_content) return create_single_layout(parent_win, actual_content)
@ -196,13 +185,6 @@ function M.create_diff_layout(mode, parent_win, expected_content, actual_content
end end
end end
---@param current_diff_layout DiffLayout?
---@param current_mode string?
---@param main_win integer
---@param run table
---@param config cp.Config
---@param setup_keybindings_for_buffer fun(buf: integer)
---@return DiffLayout?, string?
function M.update_diff_panes( function M.update_diff_panes(
current_diff_layout, current_diff_layout,
current_mode, current_mode,

View file

@ -2,7 +2,6 @@ local M = {}
---@class PanelOpts ---@class PanelOpts
---@field debug? boolean ---@field debug? boolean
---@field test_index? integer
local cache = require('cp.cache') local cache = require('cp.cache')
local config_module = require('cp.config') local config_module = require('cp.config')
@ -14,9 +13,8 @@ local utils = require('cp.utils')
local current_diff_layout = nil local current_diff_layout = nil
local current_mode = nil local current_mode = nil
local _run_gen = 0 local io_view_running = false
---@return nil
function M.disable() function M.disable()
local active_panel = state.get_active_panel() local active_panel = state.get_active_panel()
if not active_panel then if not active_panel then
@ -28,8 +26,6 @@ function M.disable()
M.toggle_panel() M.toggle_panel()
elseif active_panel == 'interactive' then elseif active_panel == 'interactive' then
M.toggle_interactive() M.toggle_interactive()
elseif active_panel == 'stress' then
require('cp.stress').toggle()
else else
logger.log(('Unknown panel type: %s'):format(tostring(active_panel))) logger.log(('Unknown panel type: %s'):format(tostring(active_panel)))
end end
@ -54,7 +50,7 @@ function M.toggle_interactive(interactor_cmd)
end end
if state.get_active_panel() then if state.get_active_panel() then
logger.log('Another panel is already active.', { level = vim.log.levels.WARN }) logger.log('Another panel is already active.', vim.log.levels.WARN)
return return
end end
@ -63,7 +59,7 @@ function M.toggle_interactive(interactor_cmd)
if not platform or not contest_id or not problem_id then if not platform or not contest_id or not problem_id then
logger.log( logger.log(
'No platform/contest/problem configured. Use :CP <platform> <contest> [...] first.', 'No platform/contest/problem configured. Use :CP <platform> <contest> [...] first.',
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
@ -71,14 +67,11 @@ function M.toggle_interactive(interactor_cmd)
cache.load() cache.load()
local contest_data = cache.get_contest_data(platform, contest_id) local contest_data = cache.get_contest_data(platform, contest_id)
if if
contest_data not contest_data
and contest_data.index_map or not contest_data.index_map
and not contest_data.problems[contest_data.index_map[problem_id]].interactive or not contest_data.problems[contest_data.index_map[problem_id]].interactive
then then
logger.log( logger.log('This problem is interactive. Use :CP interact.', vim.log.levels.ERROR)
'This problem is not interactive. Use :CP {run,panel}.',
{ level = vim.log.levels.ERROR }
)
return return
end end
@ -107,7 +100,7 @@ function M.toggle_interactive(interactor_cmd)
local binary = state.get_binary_file() local binary = state.get_binary_file()
if not binary or binary == '' then if not binary or binary == '' then
logger.log('No binary produced.', { level = vim.log.levels.ERROR }) logger.log('No binary produced.', vim.log.levels.ERROR)
restore_session() restore_session()
return return
end end
@ -121,29 +114,20 @@ function M.toggle_interactive(interactor_cmd)
if vim.fn.executable(interactor) ~= 1 then if vim.fn.executable(interactor) ~= 1 then
logger.log( logger.log(
("Interactor '%s' is not executable."):format(interactor_cmd), ("Interactor '%s' is not executable."):format(interactor_cmd),
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
restore_session() restore_session()
return return
end end
local orchestrator = local orchestrator =
vim.fn.fnamemodify(utils.get_plugin_path() .. '/scripts/interact.py', ':p') vim.fn.fnamemodify(utils.get_plugin_path() .. '/scripts/interact.py', ':p')
if utils.is_nix_build() then cmdline = table.concat({
cmdline = table.concat({ 'uv',
vim.fn.shellescape(utils.get_nix_python()), 'run',
vim.fn.shellescape(orchestrator), vim.fn.shellescape(orchestrator),
vim.fn.shellescape(interactor), vim.fn.shellescape(interactor),
vim.fn.shellescape(binary), vim.fn.shellescape(binary),
}, ' ') }, ' ')
else
cmdline = table.concat({
'uv',
'run',
vim.fn.shellescape(orchestrator),
vim.fn.shellescape(interactor),
vim.fn.shellescape(binary),
}, ' ')
end
else else
cmdline = vim.fn.shellescape(binary) cmdline = vim.fn.shellescape(binary)
end end
@ -245,6 +229,7 @@ local function get_or_create_io_buffers()
state.set_io_view_state({ state.set_io_view_state({
output_buf = output_buf, output_buf = output_buf,
input_buf = input_buf, input_buf = input_buf,
current_test_index = 1,
source_buf = current_source_buf, source_buf = current_source_buf,
}) })
@ -309,6 +294,49 @@ local function get_or_create_io_buffers()
end, end,
}) })
local cfg = config_module.get_config()
local platform = state.get_platform()
local contest_id = state.get_contest_id()
local problem_id = state.get_problem_id()
local function navigate_test(delta)
local io_view_state = state.get_io_view_state()
if not io_view_state then
return
end
if not platform or not contest_id or not problem_id then
return
end
local test_cases = cache.get_test_cases(platform, contest_id, problem_id)
if not test_cases or #test_cases == 0 then
return
end
local new_index = (io_view_state.current_test_index or 1) + delta
if new_index < 1 or new_index > #test_cases then
return
end
io_view_state.current_test_index = new_index
M.run_io_view(new_index)
end
if cfg.ui.run.next_test_key then
vim.keymap.set('n', cfg.ui.run.next_test_key, function()
navigate_test(1)
end, { buffer = output_buf, silent = true, desc = 'Next test' })
vim.keymap.set('n', cfg.ui.run.next_test_key, function()
navigate_test(1)
end, { buffer = input_buf, silent = true, desc = 'Next test' })
end
if cfg.ui.run.prev_test_key then
vim.keymap.set('n', cfg.ui.run.prev_test_key, function()
navigate_test(-1)
end, { buffer = output_buf, silent = true, desc = 'Previous test' })
vim.keymap.set('n', cfg.ui.run.prev_test_key, function()
navigate_test(-1)
end, { buffer = input_buf, silent = true, desc = 'Previous test' })
end
return output_buf, input_buf return output_buf, input_buf
end end
@ -352,14 +380,13 @@ local function create_window_layout(output_buf, input_buf)
vim.api.nvim_set_current_win(solution_win) vim.api.nvim_set_current_win(solution_win)
end end
---@return nil
function M.ensure_io_view() function M.ensure_io_view()
local platform, contest_id, problem_id = local platform, contest_id, problem_id =
state.get_platform(), state.get_contest_id(), state.get_problem_id() state.get_platform(), state.get_contest_id(), state.get_problem_id()
if not platform or not contest_id or not problem_id then if not platform or not contest_id or not problem_id then
logger.log( logger.log(
'No platform/contest/problem configured. Use :CP <platform> <contest> [...] first.', 'No platform/contest/problem configured. Use :CP <platform> <contest> [...] first.',
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
@ -388,10 +415,7 @@ function M.ensure_io_view()
and contest_data.index_map and contest_data.index_map
and contest_data.problems[contest_data.index_map[problem_id]].interactive and contest_data.problems[contest_data.index_map[problem_id]].interactive
then then
logger.log( logger.log('This problem is not interactive. Use :CP {run,panel}.', vim.log.levels.ERROR)
'This problem is not interactive. Use :CP {run,panel}.',
{ level = vim.log.levels.ERROR }
)
return return
end end
@ -411,12 +435,12 @@ function M.ensure_io_view()
local cfg = config_module.get_config() local cfg = config_module.get_config()
local io = cfg.hooks and cfg.hooks.setup and cfg.hooks.setup.io if cfg.hooks and cfg.hooks.setup_io_output then
if io and io.output then pcall(cfg.hooks.setup_io_output, output_buf, state)
pcall(io.output, output_buf, state)
end end
if io and io.input then
pcall(io.input, input_buf, state) if cfg.hooks and cfg.hooks.setup_io_input then
pcall(cfg.hooks.setup_io_input, input_buf, state)
end end
local test_cases = cache.get_test_cases(platform, contest_id, problem_id) local test_cases = cache.get_test_cases(platform, contest_id, problem_id)
@ -600,17 +624,14 @@ local function render_io_view_results(io_state, test_indices, mode, combined_res
utils.update_buffer_content(io_state.output_buf, output_lines, final_highlights, output_ns) utils.update_buffer_content(io_state.output_buf, output_lines, final_highlights, output_ns)
end end
---@param test_indices_arg integer[]?
---@param debug boolean?
---@param mode? string
function M.run_io_view(test_indices_arg, debug, mode) function M.run_io_view(test_indices_arg, debug, mode)
_run_gen = _run_gen + 1 if io_view_running then
local gen = _run_gen logger.log('Tests already running', vim.log.levels.WARN)
return
end
io_view_running = true
logger.log( logger.log(('%s tests...'):format(debug and 'Debugging' or 'Running'), vim.log.levels.INFO, true)
('%s tests...'):format(debug and 'Debugging' or 'Running'),
{ level = vim.log.levels.INFO, override = true }
)
mode = mode or 'combined' mode = mode or 'combined'
@ -619,15 +640,17 @@ function M.run_io_view(test_indices_arg, debug, mode)
if not platform or not contest_id or not problem_id then if not platform or not contest_id or not problem_id then
logger.log( logger.log(
'No platform/contest/problem configured. Use :CP <platform> <contest> [...] first.', 'No platform/contest/problem configured. Use :CP <platform> <contest> [...] first.',
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
io_view_running = false
return return
end end
cache.load() cache.load()
local contest_data = cache.get_contest_data(platform, contest_id) local contest_data = cache.get_contest_data(platform, contest_id)
if not contest_data or not contest_data.index_map then if not contest_data or not contest_data.index_map then
logger.log('No test cases available.', { level = vim.log.levels.ERROR }) logger.log('No test cases available.', vim.log.levels.ERROR)
io_view_running = false
return return
end end
@ -643,12 +666,14 @@ function M.run_io_view(test_indices_arg, debug, mode)
if mode == 'combined' then if mode == 'combined' then
local combined = cache.get_combined_test(platform, contest_id, problem_id) local combined = cache.get_combined_test(platform, contest_id, problem_id)
if not combined then if not combined then
logger.log('No combined test available', { level = vim.log.levels.ERROR }) logger.log('No combined test available', vim.log.levels.ERROR)
io_view_running = false
return return
end end
else else
if not run.load_test_cases() then if not run.load_test_cases() then
logger.log('No test cases available', { level = vim.log.levels.ERROR }) logger.log('No test cases available', vim.log.levels.ERROR)
io_view_running = false
return return
end end
end end
@ -667,8 +692,9 @@ function M.run_io_view(test_indices_arg, debug, mode)
idx, idx,
#test_state.test_cases #test_state.test_cases
), ),
{ level = vim.log.levels.WARN } vim.log.levels.WARN
) )
io_view_running = false
return return
end end
end end
@ -686,6 +712,7 @@ function M.run_io_view(test_indices_arg, debug, mode)
local io_state = state.get_io_view_state() local io_state = state.get_io_view_state()
if not io_state then if not io_state then
io_view_running = false
return return
end end
@ -698,10 +725,8 @@ function M.run_io_view(test_indices_arg, debug, mode)
local execute = require('cp.runner.execute') local execute = require('cp.runner.execute')
execute.compile_problem(debug, function(compile_result) execute.compile_problem(debug, function(compile_result)
if gen ~= _run_gen then
return
end
if not vim.api.nvim_buf_is_valid(io_state.output_buf) then if not vim.api.nvim_buf_is_valid(io_state.output_buf) then
io_view_running = false
return return
end end
@ -721,64 +746,43 @@ function M.run_io_view(test_indices_arg, debug, mode)
local ns = vim.api.nvim_create_namespace('cp_io_view_compile_error') local ns = vim.api.nvim_create_namespace('cp_io_view_compile_error')
utils.update_buffer_content(io_state.output_buf, lines, highlights, ns) utils.update_buffer_content(io_state.output_buf, lines, highlights, ns)
io_view_running = false
return return
end end
if mode == 'combined' then if mode == 'combined' then
local combined = cache.get_combined_test(platform, contest_id, problem_id) local combined = cache.get_combined_test(platform, contest_id, problem_id)
if not combined then if not combined then
logger.log('No combined test found', { level = vim.log.levels.ERROR }) logger.log('No combined test found', vim.log.levels.ERROR)
io_view_running = false
return return
end end
run.load_test_cases() run.load_test_cases()
run.run_combined_test(debug, function(result) run.run_combined_test(debug, function(result)
if gen ~= _run_gen then
return
end
if not result then if not result then
logger.log('Failed to run combined test', { level = vim.log.levels.ERROR }) logger.log('Failed to run combined test', vim.log.levels.ERROR)
io_view_running = false
return return
end end
if vim.api.nvim_buf_is_valid(io_state.output_buf) then if vim.api.nvim_buf_is_valid(io_state.output_buf) then
render_io_view_results(io_state, test_indices, mode, result, combined.input) render_io_view_results(io_state, test_indices, mode, result, combined.input)
end end
io_view_running = false
end) end)
else else
run.run_all_test_cases(test_indices, debug, nil, function() run.run_all_test_cases(test_indices, debug, nil, function()
if gen ~= _run_gen then
return
end
if vim.api.nvim_buf_is_valid(io_state.output_buf) then if vim.api.nvim_buf_is_valid(io_state.output_buf) then
render_io_view_results(io_state, test_indices, mode, nil, nil) render_io_view_results(io_state, test_indices, mode, nil, nil)
end end
io_view_running = false
end) end)
end end
end) end)
end end
---@return nil
function M.cancel_io_view()
_run_gen = _run_gen + 1
end
---@return nil
function M.cancel_interactive()
if state.interactive_buf and vim.api.nvim_buf_is_valid(state.interactive_buf) then
local job = vim.b[state.interactive_buf].terminal_job_id
if job then
vim.fn.jobstop(job)
end
end
if state.saved_interactive_session then
vim.fn.delete(state.saved_interactive_session)
state.saved_interactive_session = nil
end
state.set_active_panel(nil)
end
---@param panel_opts? PanelOpts ---@param panel_opts? PanelOpts
function M.toggle_panel(panel_opts) function M.toggle_panel(panel_opts)
if state.get_active_panel() == 'run' then if state.get_active_panel() == 'run' then
@ -799,7 +803,7 @@ function M.toggle_panel(panel_opts)
end end
if state.get_active_panel() then if state.get_active_panel() then
logger.log('another panel is already active', { level = vim.log.levels.ERROR }) logger.log('another panel is already active', vim.log.levels.ERROR)
return return
end end
@ -808,7 +812,7 @@ function M.toggle_panel(panel_opts)
if not platform or not contest_id then if not platform or not contest_id then
logger.log( logger.log(
'No platform/contest configured. Use :CP <platform> <contest> [...] first.', 'No platform/contest configured. Use :CP <platform> <contest> [...] first.',
{ level = vim.log.levels.ERROR } vim.log.levels.ERROR
) )
return return
end end
@ -817,13 +821,9 @@ function M.toggle_panel(panel_opts)
local contest_data = cache.get_contest_data(platform, contest_id) local contest_data = cache.get_contest_data(platform, contest_id)
if if
contest_data contest_data
and contest_data.index_map
and contest_data.problems[contest_data.index_map[state.get_problem_id()]].interactive and contest_data.problems[contest_data.index_map[state.get_problem_id()]].interactive
then then
logger.log( logger.log('This is an interactive problem. Use :CP interact instead.', vim.log.levels.WARN)
'This is an interactive problem. Use :CP interact instead.',
{ level = vim.log.levels.WARN }
)
return return
end end
@ -834,17 +834,10 @@ function M.toggle_panel(panel_opts)
logger.log(('run panel: checking test cases for %s'):format(input_file or 'none')) logger.log(('run panel: checking test cases for %s'):format(input_file or 'none'))
if not run.load_test_cases() then if not run.load_test_cases() then
logger.log('no test cases found', { level = vim.log.levels.WARN }) logger.log('no test cases found', vim.log.levels.WARN)
return return
end end
if panel_opts and panel_opts.test_index then
local test_state = run.get_panel_state()
if panel_opts.test_index >= 1 and panel_opts.test_index <= #test_state.test_cases then
test_state.current_index = panel_opts.test_index
end
end
local io_state = state.get_io_view_state() local io_state = state.get_io_view_state()
if io_state then if io_state then
for _, win in ipairs(vim.api.nvim_list_wins()) do for _, win in ipairs(vim.api.nvim_list_wins()) do
@ -956,15 +949,14 @@ function M.toggle_panel(panel_opts)
setup_keybindings_for_buffer(test_buffers.tab_buf) setup_keybindings_for_buffer(test_buffers.tab_buf)
local o = config.hooks and config.hooks.on if config.hooks and config.hooks.before_run then
if o and o.run then vim.schedule_wrap(function()
vim.schedule(function() config.hooks.before_run(state)
o.run(state)
end) end)
end end
if panel_opts and panel_opts.debug and o and o.debug then if panel_opts and panel_opts.debug and config.hooks and config.hooks.before_debug then
vim.schedule(function() vim.schedule_wrap(function()
o.debug(state) config.hooks.before_debug(state)
end) end)
end end

View file

@ -2,11 +2,7 @@ local M = {}
local logger = require('cp.log') local logger = require('cp.log')
local _nix_python = nil local uname = vim.loop.os_uname()
local _nix_submit_cmd = nil
local _nix_discovered = false
local uname = vim.uv.os_uname()
local _time_cached = false local _time_cached = false
local _time_path = nil local _time_path = nil
@ -61,11 +57,7 @@ local function find_gnu_time()
_time_cached = true _time_cached = true
_time_path = nil _time_path = nil
if uname and uname.sysname == 'Darwin' then _time_reason = 'GNU time not found'
_time_reason = 'GNU time not found (install via: brew install coreutils)'
else
_time_reason = 'GNU time not found'
end
return _time_path, _time_reason return _time_path, _time_reason
end end
@ -87,170 +79,7 @@ function M.get_plugin_path()
return vim.fn.fnamemodify(plugin_path, ':h:h:h') return vim.fn.fnamemodify(plugin_path, ':h:h:h')
end end
---@return boolean
function M.is_nix_build()
return _nix_python ~= nil
end
---@return string|nil
function M.get_nix_python()
return _nix_python
end
---@return boolean
function M.is_nix_discovered()
return _nix_discovered
end
---@param module string
---@param plugin_path string
---@return string[]
function M.get_python_cmd(module, plugin_path)
if _nix_python then
return { _nix_python, '-m', 'scrapers.' .. module }
end
return { 'uv', 'run', '--directory', plugin_path, '-m', 'scrapers.' .. module }
end
---@param module string
---@param plugin_path string
---@return string[]
function M.get_python_submit_cmd(module, plugin_path)
if _nix_submit_cmd then
return { _nix_submit_cmd, 'run', '--directory', plugin_path, '-m', 'scrapers.' .. module }
end
return { 'uv', 'run', '--directory', plugin_path, '-m', 'scrapers.' .. module }
end
local python_env_setup = false local python_env_setup = false
local _nix_submit_attempted = false
---@return boolean
local function discover_nix_submit_cmd()
local cache_dir = vim.fn.stdpath('cache') .. '/cp-nvim'
local cache_file = cache_dir .. '/nix-submit'
local f = io.open(cache_file, 'r')
if f then
local cached = f:read('*l')
f:close()
if cached and vim.fn.executable(cached) == 1 then
_nix_submit_cmd = cached
return true
end
end
local plugin_path = M.get_plugin_path()
vim.cmd.redraw()
logger.log(
'Building submit environment...',
{ level = vim.log.levels.INFO, override = true, sync = true }
)
vim.cmd.redraw()
local result = vim
.system(
{ 'nix', 'build', plugin_path .. '#submitEnv', '--no-link', '--print-out-paths' },
{ text = true }
)
:wait()
if result.code ~= 0 then
logger.log(
'nix build #submitEnv failed: ' .. (result.stderr or ''),
{ level = vim.log.levels.WARN }
)
return false
end
local store_path = result.stdout:gsub('%s+$', '')
local submit_cmd = store_path .. '/bin/cp-nvim-submit'
if vim.fn.executable(submit_cmd) ~= 1 then
logger.log('nix submit cmd not executable at ' .. submit_cmd, { level = vim.log.levels.WARN })
return false
end
vim.fn.mkdir(cache_dir, 'p')
f = io.open(cache_file, 'w')
if f then
f:write(submit_cmd)
f:close()
end
_nix_submit_cmd = submit_cmd
return true
end
---@return boolean
function M.setup_nix_submit_env()
if _nix_submit_cmd then
return true
end
if _nix_submit_attempted then
return false
end
_nix_submit_attempted = true
if vim.fn.executable('nix') == 1 then
return discover_nix_submit_cmd()
end
return false
end
---@return boolean
local function discover_nix_python()
local cache_dir = vim.fn.stdpath('cache') .. '/cp-nvim'
local cache_file = cache_dir .. '/nix-python'
local f = io.open(cache_file, 'r')
if f then
local cached = f:read('*l')
f:close()
if cached and vim.fn.executable(cached) == 1 then
_nix_python = cached
return true
end
end
local plugin_path = M.get_plugin_path()
logger.log(
'Building Python environment with nix...',
{ level = vim.log.levels.INFO, override = true, sync = true }
)
vim.cmd.redraw()
local result = vim
.system(
{ 'nix', 'build', plugin_path .. '#pythonEnv', '--no-link', '--print-out-paths' },
{ text = true }
)
:wait()
if result.code ~= 0 then
logger.log(
'nix build #pythonEnv failed: ' .. (result.stderr or ''),
{ level = vim.log.levels.WARN }
)
return false
end
local store_path = result.stdout:gsub('%s+$', '')
local python_path = store_path .. '/bin/python3'
if vim.fn.executable(python_path) ~= 1 then
logger.log('nix python not executable at ' .. python_path, { level = vim.log.levels.WARN })
return false
end
vim.fn.mkdir(cache_dir, 'p')
f = io.open(cache_file, 'w')
if f then
f:write(python_path)
f:close()
end
_nix_python = python_path
_nix_discovered = true
return true
end
---@return boolean success ---@return boolean success
function M.setup_python_env() function M.setup_python_env()
@ -258,23 +87,19 @@ function M.setup_python_env()
return true return true
end end
if _nix_python then local plugin_path = M.get_plugin_path()
logger.log('Python env: nix (python=' .. _nix_python .. ')') local venv_dir = plugin_path .. '/.venv'
python_env_setup = true
return true if vim.fn.executable('uv') == 0 then
logger.log(
'uv is not installed. Install it to enable problem scraping: https://docs.astral.sh/uv/',
vim.log.levels.WARN
)
return false
end end
local on_nixos = vim.fn.filereadable('/etc/NIXOS') == 1 if vim.fn.isdirectory(venv_dir) == 0 then
logger.log('Setting up Python environment for scrapers...')
if not on_nixos and vim.fn.executable('uv') == 1 then
local plugin_path = M.get_plugin_path()
logger.log('Python env: uv sync (dir=' .. plugin_path .. ')')
logger.log(
'Setting up Python environment...',
{ level = vim.log.levels.INFO, override = true, sync = true }
)
vim.cmd.redraw()
local env = vim.fn.environ() local env = vim.fn.environ()
env.VIRTUAL_ENV = '' env.VIRTUAL_ENV = ''
env.PYTHONPATH = '' env.PYTHONPATH = ''
@ -283,38 +108,18 @@ function M.setup_python_env()
.system({ 'uv', 'sync' }, { cwd = plugin_path, text = true, env = env }) .system({ 'uv', 'sync' }, { cwd = plugin_path, text = true, env = env })
:wait() :wait()
if result.code ~= 0 then if result.code ~= 0 then
logger.log( logger.log('Failed to setup Python environment: ' .. result.stderr, vim.log.levels.ERROR)
'Failed to setup Python environment: ' .. (result.stderr or ''),
{ level = vim.log.levels.ERROR }
)
return false return false
end end
if result.stderr and result.stderr ~= '' then logger.log('Python environment setup complete.')
logger.log('uv sync stderr: ' .. result.stderr:gsub('%s+$', ''))
end
python_env_setup = true
return true
end end
if vim.fn.executable('nix') == 1 then python_env_setup = true
logger.log('Python env: nix discovery') return true
if discover_nix_python() then
python_env_setup = true
return true
end
end
logger.log(
'No Python environment available. Install uv (https://docs.astral.sh/uv/) or use nix.',
{ level = vim.log.levels.WARN }
)
return false
end end
--- Configure the buffer with good defaults --- Configure the buffer with good defaults
---@param filetype? string ---@param filetype? string
---@return integer
function M.create_buffer_with_options(filetype) function M.create_buffer_with_options(filetype)
local buf = vim.api.nvim_create_buf(false, true) local buf = vim.api.nvim_create_buf(false, true)
vim.api.nvim_set_option_value('bufhidden', 'hide', { buf = buf }) vim.api.nvim_set_option_value('bufhidden', 'hide', { buf = buf })
@ -346,7 +151,6 @@ function M.update_buffer_content(bufnr, lines, highlights, namespace)
end end
end end
---@return boolean, string?
function M.check_required_runtime() function M.check_required_runtime()
if is_windows() then if is_windows() then
return false, 'Windows is not supported' return false, 'Windows is not supported'
@ -358,12 +162,20 @@ function M.check_required_runtime()
local time = M.time_capability() local time = M.time_capability()
if not time.ok then if not time.ok then
return false, time.reason return false, 'GNU time not found: ' .. (time.reason or '')
end end
local timeout = M.timeout_capability() local timeout = M.timeout_capability()
if not timeout.ok then if not timeout.ok then
return false, timeout.reason return false, 'GNU timeout not found: ' .. (timeout.reason or '')
end
if vim.fn.executable('uv') ~= 1 then
return false, 'uv not found (https://docs.astral.sh/uv/)'
end
if not M.setup_python_env() then
return false, 'failed to set up Python virtual environment'
end end
return true return true
@ -413,29 +225,22 @@ local function find_gnu_timeout()
_timeout_cached = true _timeout_cached = true
_timeout_path = nil _timeout_path = nil
if uname and uname.sysname == 'Darwin' then _timeout_reason = 'GNU timeout not found'
_timeout_reason = 'GNU timeout not found (install via: brew install coreutils)'
else
_timeout_reason = 'GNU timeout not found'
end
return _timeout_path, _timeout_reason return _timeout_path, _timeout_reason
end end
---@return string?
function M.timeout_path() function M.timeout_path()
local path = find_gnu_timeout() local path = find_gnu_timeout()
return path return path
end end
---@return { ok: boolean, path: string|nil, reason: string|nil }
function M.timeout_capability() function M.timeout_capability()
local path, reason = find_gnu_timeout() local path, reason = find_gnu_timeout()
return { ok = path ~= nil, path = path, reason = reason } return { ok = path ~= nil, path = path, reason = reason }
end end
---@return string[]
function M.cwd_executables() function M.cwd_executables()
local uv = vim.uv local uv = vim.uv or vim.loop
local req = uv.fs_scandir('.') local req = uv.fs_scandir('.')
if not req then if not req then
return {} return {}
@ -457,7 +262,6 @@ function M.cwd_executables()
return out return out
end end
---@return nil
function M.ensure_dirs() function M.ensure_dirs()
vim.system({ 'mkdir', '-p', 'build', 'io' }):wait() vim.system({ 'mkdir', '-p', 'build', 'io' }):wait()
end end

0
new Normal file
View file

View file

@ -43,13 +43,9 @@ end, {
vim.list_extend(candidates, platforms) vim.list_extend(candidates, platforms)
table.insert(candidates, 'cache') table.insert(candidates, 'cache')
table.insert(candidates, 'pick') table.insert(candidates, 'pick')
if platform and contest_id then if platform and contest_id then
vim.list_extend( vim.list_extend(candidates, actions)
candidates,
vim.tbl_filter(function(a)
return a ~= 'pick' and a ~= 'cache'
end, actions)
)
local cache = require('cp.cache') local cache = require('cp.cache')
cache.load() cache.load()
local contest_data = cache.get_contest_data(platform, contest_id) local contest_data = cache.get_contest_data(platform, contest_id)
@ -64,14 +60,13 @@ end, {
return filter_candidates(candidates) return filter_candidates(candidates)
elseif num_args == 3 then elseif num_args == 3 then
if vim.tbl_contains(platforms, args[2]) then if vim.tbl_contains(platforms, args[2]) then
local candidates = { 'login', 'logout', 'signup' }
local cache = require('cp.cache') local cache = require('cp.cache')
cache.load() cache.load()
vim.list_extend(candidates, cache.get_cached_contest_ids(args[2])) local contests = cache.get_cached_contest_ids(args[2])
return filter_candidates(candidates) return filter_candidates(contests)
elseif args[2] == 'cache' then elseif args[2] == 'cache' then
return filter_candidates({ 'clear', 'read' }) return filter_candidates({ 'clear', 'read' })
elseif args[2] == 'stress' or args[2] == 'interact' then elseif args[2] == 'interact' then
local utils = require('cp.utils') local utils = require('cp.utils')
return filter_candidates(utils.cwd_executables()) return filter_candidates(utils.cwd_executables())
elseif args[2] == 'edit' then elseif args[2] == 'edit' then
@ -108,8 +103,6 @@ end, {
end end
end end
return filter_candidates(candidates) return filter_candidates(candidates)
elseif args[2] == 'open' then
return filter_candidates({ 'problem', 'contest', 'standings' })
elseif args[2] == 'next' or args[2] == 'prev' or args[2] == 'pick' then elseif args[2] == 'next' or args[2] == 'prev' or args[2] == 'pick' then
return filter_candidates({ '--lang' }) return filter_candidates({ '--lang' })
else else
@ -119,10 +112,7 @@ end, {
end end
end end
elseif num_args == 4 then elseif num_args == 4 then
if args[2] == 'stress' then if args[2] == 'cache' and args[3] == 'clear' then
local utils = require('cp.utils')
return filter_candidates(utils.cwd_executables())
elseif args[2] == 'cache' and args[3] == 'clear' then
local candidates = vim.list_extend({}, platforms) local candidates = vim.list_extend({}, platforms)
table.insert(candidates, '') table.insert(candidates, '')
return filter_candidates(candidates) return filter_candidates(candidates)
@ -136,9 +126,6 @@ end, {
cache.load() cache.load()
local contest_data = cache.get_contest_data(args[2], args[3]) local contest_data = cache.get_contest_data(args[2], args[3])
local candidates = { '--lang' } local candidates = { '--lang' }
if not require('cp.race').status().active then
table.insert(candidates, '--race')
end
if contest_data and contest_data.problems then if contest_data and contest_data.problems then
for _, problem in ipairs(contest_data.problems) do for _, problem in ipairs(contest_data.problems) do
table.insert(candidates, problem.id) table.insert(candidates, problem.id)
@ -153,47 +140,17 @@ end, {
local contests = cache.get_cached_contest_ids(args[4]) local contests = cache.get_cached_contest_ids(args[4])
return filter_candidates(contests) return filter_candidates(contests)
elseif vim.tbl_contains(platforms, args[2]) then elseif vim.tbl_contains(platforms, args[2]) then
if args[3] == '--race' then if args[4] == '--lang' then
return filter_candidates({ '--lang' })
elseif args[4] == '--lang' then
return filter_candidates(get_enabled_languages(args[2])) return filter_candidates(get_enabled_languages(args[2]))
elseif args[3] == '--lang' then
local candidates = {}
if not require('cp.race').status().active then
table.insert(candidates, '--race')
end
return filter_candidates(candidates)
else else
return filter_candidates({ '--lang' }) return filter_candidates({ '--lang' })
end end
end end
elseif num_args == 6 then elseif num_args == 6 then
if vim.tbl_contains(platforms, args[2]) then if vim.tbl_contains(platforms, args[2]) and args[5] == '--lang' then
if args[3] == '--race' and args[4] == '--lang' then return filter_candidates(get_enabled_languages(args[2]))
return filter_candidates(get_enabled_languages(args[2]))
elseif args[3] == '--lang' and args[5] == '--race' then
return {}
elseif args[5] == '--lang' then
return filter_candidates(get_enabled_languages(args[2]))
end
end end
end end
return {} return {}
end, end,
}) })
local function cp_action(action)
return function()
require('cp').handle_command({ fargs = { action } })
end
end
vim.keymap.set('n', '<Plug>(cp-run)', cp_action('run'), { desc = 'CP run tests' })
vim.keymap.set('n', '<Plug>(cp-panel)', cp_action('panel'), { desc = 'CP open panel' })
vim.keymap.set('n', '<Plug>(cp-edit)', cp_action('edit'), { desc = 'CP edit test cases' })
vim.keymap.set('n', '<Plug>(cp-next)', cp_action('next'), { desc = 'CP next problem' })
vim.keymap.set('n', '<Plug>(cp-prev)', cp_action('prev'), { desc = 'CP previous problem' })
vim.keymap.set('n', '<Plug>(cp-pick)', cp_action('pick'), { desc = 'CP pick contest' })
vim.keymap.set('n', '<Plug>(cp-interact)', cp_action('interact'), { desc = 'CP interactive mode' })
vim.keymap.set('n', '<Plug>(cp-stress)', cp_action('stress'), { desc = 'CP stress test' })
vim.keymap.set('n', '<Plug>(cp-submit)', cp_action('submit'), { desc = 'CP submit solution' })

View file

@ -7,11 +7,13 @@ requires-python = ">=3.11"
dependencies = [ dependencies = [
"backoff>=2.2.1", "backoff>=2.2.1",
"beautifulsoup4>=4.13.5", "beautifulsoup4>=4.13.5",
"scrapling[fetchers]>=0.4", "curl-cffi>=0.13.0",
"httpx>=0.28.1", "httpx>=0.28.1",
"ndjson>=0.3.1", "ndjson>=0.3.1",
"pydantic>=2.11.10", "pydantic>=2.11.10",
"requests>=2.32.5", "requests>=2.32.5",
"scrapling[fetchers]>=0.3.5",
"types-requests>=2.32.4.20250913",
] ]
[dependency-groups] [dependency-groups]

View file

@ -2,11 +2,9 @@
import asyncio import asyncio
import json import json
import os
import re import re
import subprocess import sys
import time import time
from pathlib import Path
from typing import Any from typing import Any
import backoff import backoff
@ -16,121 +14,21 @@ from bs4 import BeautifulSoup, Tag
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry from urllib3.util.retry import Retry
from .base import BaseScraper, extract_precision from .base import BaseScraper
from .models import ( from .models import (
CombinedTest,
ContestListResult, ContestListResult,
ContestSummary, ContestSummary,
LoginResult,
MetadataResult, MetadataResult,
ProblemSummary, ProblemSummary,
SubmitResult,
TestCase, TestCase,
TestsResult,
) )
from .timeouts import (
BROWSER_ELEMENT_WAIT,
BROWSER_NAV_TIMEOUT,
BROWSER_SESSION_TIMEOUT,
BROWSER_SETTLE_DELAY,
BROWSER_SUBMIT_NAV_TIMEOUT,
BROWSER_TURNSTILE_POLL,
HTTP_TIMEOUT,
)
_LANGUAGE_ID_EXTENSION: dict[str, str] = {
"6002": "ada",
"6003": "apl",
"6004": "asm",
"6005": "asm",
"6006": "awk",
"6008": "sh",
"6009": "bas",
"6010": "bc",
"6012": "bf",
"6013": "c",
"6014": "c",
"6015": "cs",
"6016": "cs",
"6017": "cc",
"6021": "clj",
"6022": "clj",
"6023": "clj",
"6025": "cljs",
"6026": "cob",
"6027": "lisp",
"6028": "cr",
"6030": "d",
"6031": "d",
"6032": "d",
"6033": "dart",
"6038": "ex",
"6039": "el",
"6041": "erl",
"6042": "fs",
"6043": "factor",
"6044": "fish",
"6045": "fth",
"6046": "f90",
"6047": "f90",
"6048": "f",
"6049": "gleam",
"6050": "go",
"6051": "go",
"6052": "hs",
"6053": "hx",
"6054": "cc",
"6056": "java",
"6057": "js",
"6058": "js",
"6059": "js",
"6060": "jule",
"6061": "kk",
"6062": "kt",
"6065": "lean",
"6066": "ll",
"6067": "lua",
"6068": "lua",
"6071": "nim",
"6072": "nim",
"6073": "ml",
"6074": "m",
"6075": "pas",
"6076": "pl",
"6077": "php",
"6079": "pony",
"6080": "ps1",
"6081": "pro",
"6082": "py",
"6083": "py",
"6084": "r",
"6085": "re",
"6086": "rb",
"6087": "rb",
"6088": "rs",
"6089": "py",
"6090": "scala",
"6091": "scala",
"6092": "scm",
"6093": "scm",
"6094": "sd7",
"6095": "swift",
"6096": "tcl",
"6100": "ts",
"6101": "ts",
"6102": "ts",
"6105": "v",
"6106": "vala",
"6107": "v",
"6109": "wat",
"6111": "zig",
"6114": "jl",
"6115": "py",
"6116": "cc",
"6118": "sql",
}
MIB_TO_MB = 1.048576 MIB_TO_MB = 1.048576
BASE_URL = "https://atcoder.jp" BASE_URL = "https://atcoder.jp"
ARCHIVE_URL = f"{BASE_URL}/contests/archive" ARCHIVE_URL = f"{BASE_URL}/contests/archive"
TIMEOUT_SECONDS = 30
HEADERS = { HEADERS = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
} }
@ -173,7 +71,7 @@ def _retry_after_requests(details):
on_backoff=_retry_after_requests, on_backoff=_retry_after_requests,
) )
def _fetch(url: str) -> str: def _fetch(url: str) -> str:
r = _session.get(url, headers=HEADERS, timeout=HTTP_TIMEOUT) r = _session.get(url, headers=HEADERS, timeout=TIMEOUT_SECONDS)
if r.status_code in RETRY_STATUS: if r.status_code in RETRY_STATUS:
raise requests.HTTPError(response=r) raise requests.HTTPError(response=r)
r.raise_for_status() r.raise_for_status()
@ -196,7 +94,7 @@ def _giveup_httpx(exc: Exception) -> bool:
giveup=_giveup_httpx, giveup=_giveup_httpx,
) )
async def _get_async(client: httpx.AsyncClient, url: str) -> str: async def _get_async(client: httpx.AsyncClient, url: str) -> str:
r = await client.get(url, headers=HEADERS, timeout=HTTP_TIMEOUT) r = await client.get(url, headers=HEADERS, timeout=TIMEOUT_SECONDS)
r.raise_for_status() r.raise_for_status()
return r.text return r.text
@ -223,23 +121,6 @@ def _parse_last_page(html: str) -> int:
return max(nums) if nums else 1 return max(nums) if nums else 1
def _parse_start_time(tr: Tag) -> int | None:
tds = tr.select("td")
if not tds:
return None
time_el = tds[0].select_one("time.fixtime-full")
if not time_el:
return None
text = time_el.get_text(strip=True)
try:
from datetime import datetime
dt = datetime.strptime(text, "%Y-%m-%d %H:%M:%S%z")
return int(dt.timestamp())
except (ValueError, TypeError):
return None
def _parse_archive_contests(html: str) -> list[ContestSummary]: def _parse_archive_contests(html: str) -> list[ContestSummary]:
soup = BeautifulSoup(html, "html.parser") soup = BeautifulSoup(html, "html.parser")
tbody = soup.select_one("table.table-default tbody") or soup.select_one("tbody") tbody = soup.select_one("table.table-default tbody") or soup.select_one("tbody")
@ -258,10 +139,7 @@ def _parse_archive_contests(html: str) -> list[ContestSummary]:
continue continue
cid = m.group(1) cid = m.group(1)
name = a.get_text(strip=True) name = a.get_text(strip=True)
start_time = _parse_start_time(tr) out.append(ContestSummary(id=cid, name=name, display_name=name))
out.append(
ContestSummary(id=cid, name=name, display_name=name, start_time=start_time)
)
return out return out
@ -291,7 +169,7 @@ def _parse_tasks_list(html: str) -> list[dict[str, str]]:
return rows return rows
def _extract_problem_info(html: str) -> tuple[int, float, bool, float | None]: def _extract_problem_info(html: str) -> tuple[int, float, bool]:
soup = BeautifulSoup(html, "html.parser") soup = BeautifulSoup(html, "html.parser")
txt = soup.get_text(" ", strip=True) txt = soup.get_text(" ", strip=True)
timeout_ms = 0 timeout_ms = 0
@ -303,10 +181,9 @@ def _extract_problem_info(html: str) -> tuple[int, float, bool, float | None]:
if ms: if ms:
memory_mb = float(ms.group(1)) * MIB_TO_MB memory_mb = float(ms.group(1)) * MIB_TO_MB
div = soup.select_one("#problem-statement") div = soup.select_one("#problem-statement")
body = div.get_text(" ", strip=True) if div else soup.get_text(" ", strip=True) txt = div.get_text(" ", strip=True) if div else soup.get_text(" ", strip=True)
interactive = "This is an interactive" in body interactive = "This is an interactive" in txt
precision = extract_precision(body) return timeout_ms, memory_mb, interactive
return timeout_ms, memory_mb, interactive, precision
def _extract_samples(html: str) -> list[TestCase]: def _extract_samples(html: str) -> list[TestCase]:
@ -332,215 +209,6 @@ def _extract_samples(html: str) -> list[TestCase]:
return cases return cases
_TURNSTILE_JS = "() => { const el = document.querySelector('[name=\"cf-turnstile-response\"]'); return el && el.value.length > 0; }"
def _solve_turnstile(page) -> None:
if page.evaluate(_TURNSTILE_JS):
return
iframe_loc = page.locator('iframe[src*="challenges.cloudflare.com"]')
if not iframe_loc.count():
return
for _ in range(6):
try:
box = iframe_loc.first.bounding_box()
if box:
page.mouse.click(
box["x"] + box["width"] * 0.15,
box["y"] + box["height"] * 0.5,
)
except Exception:
pass
try:
page.wait_for_function(_TURNSTILE_JS, timeout=BROWSER_TURNSTILE_POLL)
return
except Exception:
pass
raise RuntimeError("Turnstile not solved after multiple attempts")
def _ensure_browser() -> None:
try:
from patchright._impl._driver import compute_driver_executable # type: ignore[import-untyped,unresolved-import]
node, cli = compute_driver_executable()
except Exception:
return
browser_info = subprocess.run(
[node, cli, "install", "--dry-run", "chromium"],
capture_output=True,
text=True,
)
for line in browser_info.stdout.splitlines():
if "Install location:" in line:
install_dir = line.split(":", 1)[1].strip()
if not os.path.isdir(install_dir):
print(json.dumps({"status": "installing_browser"}), flush=True)
subprocess.run([node, cli, "install", "chromium"], check=True)
break
def _login_headless(credentials: dict[str, str]) -> LoginResult:
try:
from scrapling.fetchers import StealthySession # type: ignore[import-untyped,unresolved-import]
except ImportError:
return LoginResult(
success=False,
error="scrapling is required for AtCoder login. Install it: uv add 'scrapling[fetchers]>=0.4'",
)
_ensure_browser()
logged_in = False
login_error: str | None = None
def check_login(page):
nonlocal logged_in
logged_in = page.evaluate(
"() => Array.from(document.querySelectorAll('a')).some(a => a.textContent.trim() === 'Sign Out')"
)
def login_action(page):
nonlocal login_error
try:
_solve_turnstile(page)
page.fill('input[name="username"]', credentials.get("username", ""))
page.fill('input[name="password"]', credentials.get("password", ""))
page.click("#submit")
page.wait_for_url(
lambda url: "/login" not in url, timeout=BROWSER_NAV_TIMEOUT
)
except Exception as e:
login_error = str(e)
try:
with StealthySession(
headless=True,
timeout=BROWSER_SESSION_TIMEOUT,
google_search=False,
) as session:
print(json.dumps({"status": "logging_in"}), flush=True)
session.fetch(
f"{BASE_URL}/login",
page_action=login_action,
solve_cloudflare=True,
)
if login_error:
return LoginResult(success=False, error=f"Login failed: {login_error}")
session.fetch(
f"{BASE_URL}/home", page_action=check_login, network_idle=True
)
if not logged_in:
return LoginResult(
success=False, error="Login failed (bad credentials?)"
)
return LoginResult(success=True, error="")
except Exception as e:
return LoginResult(success=False, error=str(e))
def _submit_headless(
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
) -> "SubmitResult":
try:
from scrapling.fetchers import StealthySession # type: ignore[import-untyped,unresolved-import]
except ImportError:
return SubmitResult(
success=False,
error="scrapling is required for AtCoder submit. Install it: uv add 'scrapling[fetchers]>=0.4'",
)
_ensure_browser()
login_error: str | None = None
submit_error: str | None = None
def login_action(page):
nonlocal login_error
try:
_solve_turnstile(page)
page.fill('input[name="username"]', credentials.get("username", ""))
page.fill('input[name="password"]', credentials.get("password", ""))
page.click("#submit")
page.wait_for_url(
lambda url: "/login" not in url, timeout=BROWSER_NAV_TIMEOUT
)
except Exception as e:
login_error = str(e)
def submit_action(page):
nonlocal submit_error
if "/login" in page.url:
submit_error = "Not logged in after login step"
return
try:
_solve_turnstile(page)
page.select_option(
'select[name="data.TaskScreenName"]',
f"{contest_id}_{problem_id}",
)
page.locator(
f'select[name="data.LanguageId"] option[value="{language_id}"]'
).wait_for(state="attached", timeout=BROWSER_ELEMENT_WAIT)
page.select_option('select[name="data.LanguageId"]', language_id)
ext = _LANGUAGE_ID_EXTENSION.get(
language_id, Path(file_path).suffix.lstrip(".") or "txt"
)
page.set_input_files(
"#input-open-file",
{
"name": f"solution.{ext}",
"mimeType": "text/plain",
"buffer": Path(file_path).read_bytes(),
},
)
page.wait_for_timeout(BROWSER_SETTLE_DELAY)
page.locator('button[type="submit"]').click(no_wait_after=True)
page.wait_for_url(
lambda url: "/submissions/me" in url,
timeout=BROWSER_SUBMIT_NAV_TIMEOUT["atcoder"],
)
except Exception as e:
submit_error = str(e)
try:
with StealthySession(
headless=True,
timeout=BROWSER_SESSION_TIMEOUT,
google_search=False,
) as session:
print(json.dumps({"status": "logging_in"}), flush=True)
session.fetch(
f"{BASE_URL}/login",
page_action=login_action,
solve_cloudflare=True,
)
if login_error:
return SubmitResult(success=False, error=f"Login failed: {login_error}")
print(json.dumps({"status": "submitting"}), flush=True)
session.fetch(
f"{BASE_URL}/contests/{contest_id}/submit",
page_action=submit_action,
solve_cloudflare=True,
)
if submit_error:
return SubmitResult(success=False, error=submit_error)
return SubmitResult(
success=True, error="", submission_id="", verdict="submitted"
)
except Exception as e:
return SubmitResult(success=False, error=str(e))
def _scrape_tasks_sync(contest_id: str) -> list[dict[str, str]]: def _scrape_tasks_sync(contest_id: str) -> list[dict[str, str]]:
html = _fetch(f"{BASE_URL}/contests/{contest_id}/tasks") html = _fetch(f"{BASE_URL}/contests/{contest_id}/tasks")
return _parse_tasks_list(html) return _parse_tasks_list(html)
@ -552,13 +220,12 @@ def _scrape_problem_page_sync(contest_id: str, slug: str) -> dict[str, Any]:
tests = _extract_samples(html) tests = _extract_samples(html)
except Exception: except Exception:
tests = [] tests = []
timeout_ms, memory_mb, interactive, precision = _extract_problem_info(html) timeout_ms, memory_mb, interactive = _extract_problem_info(html)
return { return {
"tests": tests, "tests": tests,
"timeout_ms": timeout_ms, "timeout_ms": timeout_ms,
"memory_mb": memory_mb, "memory_mb": memory_mb,
"interactive": interactive, "interactive": interactive,
"precision": precision,
} }
@ -574,29 +241,14 @@ def _to_problem_summaries(rows: list[dict[str, str]]) -> list[ProblemSummary]:
return out return out
async def _fetch_upcoming_contests_async(
client: httpx.AsyncClient,
) -> list[ContestSummary]:
try:
html = await _get_async(client, f"{BASE_URL}/contests/")
return _parse_archive_contests(html)
except Exception:
return []
async def _fetch_all_contests_async() -> list[ContestSummary]: async def _fetch_all_contests_async() -> list[ContestSummary]:
async with httpx.AsyncClient( async with httpx.AsyncClient(
limits=httpx.Limits(max_connections=100, max_keepalive_connections=100), limits=httpx.Limits(max_connections=100, max_keepalive_connections=100),
) as client: ) as client:
upcoming = await _fetch_upcoming_contests_async(client)
first_html = await _get_async(client, ARCHIVE_URL) first_html = await _get_async(client, ARCHIVE_URL)
last = _parse_last_page(first_html) last = _parse_last_page(first_html)
out = _parse_archive_contests(first_html) out = _parse_archive_contests(first_html)
if last <= 1: if last <= 1:
seen = {c.id for c in out}
for c in upcoming:
if c.id not in seen:
out.append(c)
return out return out
tasks = [ tasks = [
asyncio.create_task(_get_async(client, f"{ARCHIVE_URL}?page={p}")) asyncio.create_task(_get_async(client, f"{ARCHIVE_URL}?page={p}"))
@ -605,10 +257,6 @@ async def _fetch_all_contests_async() -> list[ContestSummary]:
for coro in asyncio.as_completed(tasks): for coro in asyncio.as_completed(tasks):
html = await coro html = await coro
out.extend(_parse_archive_contests(html)) out.extend(_parse_archive_contests(html))
seen = {c.id for c in out}
for c in upcoming:
if c.id not in seen:
out.append(c)
return out return out
@ -631,8 +279,6 @@ class AtcoderScraper(BaseScraper):
contest_id=contest_id, contest_id=contest_id,
problems=problems, problems=problems,
url=f"https://atcoder.jp/contests/{contest_id}/tasks/{contest_id}_%s", url=f"https://atcoder.jp/contests/{contest_id}/tasks/{contest_id}_%s",
contest_url=f"https://atcoder.jp/contests/{contest_id}",
standings_url=f"https://atcoder.jp/contests/{contest_id}/standings",
) )
except Exception as e: except Exception as e:
return self._metadata_error(str(e)) return self._metadata_error(str(e))
@ -673,7 +319,6 @@ class AtcoderScraper(BaseScraper):
"memory_mb": data.get("memory_mb", 0), "memory_mb": data.get("memory_mb", 0),
"interactive": bool(data.get("interactive")), "interactive": bool(data.get("interactive")),
"multi_test": False, "multi_test": False,
"precision": data.get("precision"),
} }
), ),
flush=True, flush=True,
@ -681,28 +326,74 @@ class AtcoderScraper(BaseScraper):
await asyncio.gather(*(emit(r) for r in rows)) await asyncio.gather(*(emit(r) for r in rows))
async def submit(
self,
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
) -> SubmitResult:
return await asyncio.to_thread(
_submit_headless,
contest_id,
problem_id,
file_path,
language_id,
credentials,
)
async def login(self, credentials: dict[str, str]) -> LoginResult: async def main_async() -> int:
if not credentials.get("username") or not credentials.get("password"): if len(sys.argv) < 2:
return self._login_error("Missing username or password") result = MetadataResult(
return await asyncio.to_thread(_login_headless, credentials) success=False,
error="Usage: atcoder.py metadata <contest_id> OR atcoder.py tests <contest_id> OR atcoder.py contests",
url="",
)
print(result.model_dump_json())
return 1
mode: str = sys.argv[1]
scraper = AtcoderScraper()
if mode == "metadata":
if len(sys.argv) != 3:
result = MetadataResult(
success=False,
error="Usage: atcoder.py metadata <contest_id>",
url="",
)
print(result.model_dump_json())
return 1
contest_id = sys.argv[2]
result = await scraper.scrape_contest_metadata(contest_id)
print(result.model_dump_json())
return 0 if result.success else 1
if mode == "tests":
if len(sys.argv) != 3:
tests_result = TestsResult(
success=False,
error="Usage: atcoder.py tests <contest_id>",
problem_id="",
combined=CombinedTest(input="", expected=""),
tests=[],
timeout_ms=0,
memory_mb=0,
)
print(tests_result.model_dump_json())
return 1
contest_id = sys.argv[2]
await scraper.stream_tests_for_category_async(contest_id)
return 0
if mode == "contests":
if len(sys.argv) != 2:
contest_result = ContestListResult(
success=False, error="Usage: atcoder.py contests"
)
print(contest_result.model_dump_json())
return 1
contest_result = await scraper.scrape_contest_list()
print(contest_result.model_dump_json())
return 0 if contest_result.success else 1
result = MetadataResult(
success=False,
error="Unknown mode. Use 'metadata <contest_id>', 'tests <contest_id>', or 'contests'",
url="",
)
print(result.model_dump_json())
return 1
def main() -> None:
sys.exit(asyncio.run(main_async()))
if __name__ == "__main__": if __name__ == "__main__":
AtcoderScraper().run_cli() main()

View file

@ -1,38 +1,8 @@
import asyncio import asyncio
import json
import os
import re
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from .language_ids import get_language_id from .models import CombinedTest, ContestListResult, MetadataResult, TestsResult
from .models import (
CombinedTest,
ContestListResult,
LoginResult,
MetadataResult,
SubmitResult,
TestsResult,
)
_PRECISION_ABS_REL_RE = re.compile(
r"(?:absolute|relative)\s+error[^.]*?10\s*[\^{]\s*\{?\s*[-\u2212]\s*(\d+)\s*\}?",
re.IGNORECASE,
)
_PRECISION_DECIMAL_RE = re.compile(
r"round(?:ed)?\s+to\s+(\d+)\s+decimal\s+place",
re.IGNORECASE,
)
def extract_precision(text: str) -> float | None:
m = _PRECISION_ABS_REL_RE.search(text)
if m:
return 10 ** -int(m.group(1))
m = _PRECISION_DECIMAL_RE.search(text)
if m:
return 10 ** -int(m.group(1))
return None
class BaseScraper(ABC): class BaseScraper(ABC):
@ -49,22 +19,9 @@ class BaseScraper(ABC):
@abstractmethod @abstractmethod
async def stream_tests_for_category_async(self, category_id: str) -> None: ... async def stream_tests_for_category_async(self, category_id: str) -> None: ...
@abstractmethod
async def submit(
self,
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
) -> SubmitResult: ...
@abstractmethod
async def login(self, credentials: dict[str, str]) -> LoginResult: ...
def _usage(self) -> str: def _usage(self) -> str:
name = self.platform_name name = self.platform_name
return f"Usage: {name}.py metadata <id> | tests <id> | contests | login" return f"Usage: {name}.py metadata <id> | tests <id> | contests"
def _metadata_error(self, msg: str) -> MetadataResult: def _metadata_error(self, msg: str) -> MetadataResult:
return MetadataResult(success=False, error=msg, url="") return MetadataResult(success=False, error=msg, url="")
@ -83,12 +40,6 @@ class BaseScraper(ABC):
def _contests_error(self, msg: str) -> ContestListResult: def _contests_error(self, msg: str) -> ContestListResult:
return ContestListResult(success=False, error=msg) return ContestListResult(success=False, error=msg)
def _submit_error(self, msg: str) -> SubmitResult:
return SubmitResult(success=False, error=msg)
def _login_error(self, msg: str) -> LoginResult:
return LoginResult(success=False, error=msg)
async def _run_cli_async(self, args: list[str]) -> int: async def _run_cli_async(self, args: list[str]) -> int:
if len(args) < 2: if len(args) < 2:
print(self._metadata_error(self._usage()).model_dump_json()) print(self._metadata_error(self._usage()).model_dump_json())
@ -120,36 +71,6 @@ class BaseScraper(ABC):
print(result.model_dump_json()) print(result.model_dump_json())
return 0 if result.success else 1 return 0 if result.success else 1
case "submit":
if len(args) != 6:
print(
self._submit_error(
"Usage: <platform> submit <contest_id> <problem_id> <language_id> <file_path>"
).model_dump_json()
)
return 1
creds_raw = os.environ.get("CP_CREDENTIALS", "{}")
try:
credentials = json.loads(creds_raw)
except json.JSONDecodeError:
credentials = {}
language_id = get_language_id(self.platform_name, args[4]) or args[4]
result = await self.submit(
args[2], args[3], args[5], language_id, credentials
)
print(result.model_dump_json())
return 0 if result.success else 1
case "login":
creds_raw = os.environ.get("CP_CREDENTIALS", "{}")
try:
credentials = json.loads(creds_raw)
except json.JSONDecodeError:
credentials = {}
result = await self.login(credentials)
print(result.model_dump_json())
return 0 if result.success else 1
case _: case _:
print( print(
self._metadata_error( self._metadata_error(

View file

@ -3,285 +3,55 @@
import asyncio import asyncio
import json import json
import re import re
from datetime import datetime
from pathlib import Path
from typing import Any from typing import Any
import httpx import httpx
from scrapling.fetchers import Fetcher
from .base import BaseScraper from .base import BaseScraper
from .timeouts import BROWSER_SESSION_TIMEOUT, HTTP_TIMEOUT
from .models import ( from .models import (
ContestListResult, ContestListResult,
ContestSummary, ContestSummary,
LoginResult,
MetadataResult, MetadataResult,
ProblemSummary, ProblemSummary,
SubmitResult,
TestCase, TestCase,
) )
BASE_URL = "https://www.codechef.com" BASE_URL = "https://www.codechef.com"
API_CONTESTS_ALL = "/api/list/contests/all" API_CONTESTS_ALL = "/api/list/contests/all"
API_CONTESTS_PAST = "/api/list/contests/past"
API_CONTEST = "/api/contests/{contest_id}" API_CONTEST = "/api/contests/{contest_id}"
API_PROBLEM = "/api/contests/{contest_id}/problems/{problem_id}" API_PROBLEM = "/api/contests/{contest_id}/problems/{problem_id}"
PROBLEM_URL = "https://www.codechef.com/problems/{problem_id}"
HEADERS = { HEADERS = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
} }
TIMEOUT_S = 15.0
CONNECTIONS = 8 CONNECTIONS = 8
MEMORY_LIMIT_RE = re.compile(
_COOKIE_PATH = Path.home() / ".cache" / "cp-nvim" / "codechef-cookies.json" r"Memory\s+[Ll]imit.*?([0-9.]+)\s*(MB|GB)", re.IGNORECASE | re.DOTALL
)
_CC_CHECK_LOGIN_JS = "() => !!document.querySelector('a[href*=\"/users/\"]')"
_CC_LANG_IDS: dict[str, str] = {
"C++": "42",
"PYTH 3": "116",
"JAVA": "10",
"PYPY3": "109",
"GO": "114",
"rust": "93",
"KTLN": "47",
"NODEJS": "56",
"TS": "35",
}
async def fetch_json(client: httpx.AsyncClient, path: str) -> dict[str, Any]: async def fetch_json(client: httpx.AsyncClient, path: str) -> dict:
r = await client.get(BASE_URL + path, headers=HEADERS, timeout=HTTP_TIMEOUT) r = await client.get(BASE_URL + path, headers=HEADERS, timeout=TIMEOUT_S)
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()
def _login_headless_codechef(credentials: dict[str, str]) -> LoginResult: def _extract_memory_limit(html: str) -> float:
try: m = MEMORY_LIMIT_RE.search(html)
from scrapling.fetchers import StealthySession # type: ignore[import-untyped,unresolved-import] if not m:
except ImportError: return 256.0
return LoginResult( value = float(m.group(1))
success=False, unit = m.group(2).upper()
error="scrapling is required for CodeChef login", if unit == "GB":
) return value * 1024.0
return value
from .atcoder import _ensure_browser
_ensure_browser()
_COOKIE_PATH.parent.mkdir(parents=True, exist_ok=True)
logged_in = False
login_error: str | None = None
def check_login(page):
nonlocal logged_in
logged_in = "dashboard" in page.url or page.evaluate(_CC_CHECK_LOGIN_JS)
def login_action(page):
nonlocal login_error
try:
page.locator('input[name="name"]').fill(credentials.get("username", ""))
page.locator('input[name="pass"]').fill(credentials.get("password", ""))
page.locator("input.cc-login-btn").click()
try:
page.wait_for_url(lambda url: "/login" not in url, timeout=3000)
except Exception:
login_error = "Login failed (bad credentials?)"
return
except Exception as e:
login_error = str(e)
try:
with StealthySession(
headless=True,
timeout=BROWSER_SESSION_TIMEOUT,
google_search=False,
) as session:
print(json.dumps({"status": "logging_in"}), flush=True)
session.fetch(f"{BASE_URL}/login", page_action=login_action)
if login_error:
return LoginResult(success=False, error=f"Login failed: {login_error}")
session.fetch(f"{BASE_URL}/", page_action=check_login, network_idle=True)
if not logged_in:
return LoginResult(
success=False, error="Login failed (bad credentials?)"
)
try:
browser_cookies = session.context.cookies()
if browser_cookies:
_COOKIE_PATH.write_text(json.dumps(browser_cookies))
except Exception:
pass
return LoginResult(success=True, error="")
except Exception as e:
return LoginResult(success=False, error=str(e))
def _submit_headless_codechef( def _fetch_html_sync(url: str) -> str:
contest_id: str, response = Fetcher.get(url)
problem_id: str, return str(response.body)
file_path: str,
language_id: str,
credentials: dict[str, str],
_retried: bool = False,
) -> SubmitResult:
source_code = Path(file_path).read_text()
try:
from scrapling.fetchers import StealthySession # type: ignore[import-untyped,unresolved-import]
except ImportError:
return SubmitResult(
success=False,
error="scrapling is required for CodeChef submit",
)
from .atcoder import _ensure_browser
_ensure_browser()
_COOKIE_PATH.parent.mkdir(parents=True, exist_ok=True)
saved_cookies: list[dict[str, Any]] = []
if _COOKIE_PATH.exists() and not _retried:
try:
saved_cookies = json.loads(_COOKIE_PATH.read_text())
except Exception:
pass
logged_in = bool(saved_cookies) and not _retried
login_error: str | None = None
submit_error: str | None = None
needs_relogin = False
def check_login(page):
nonlocal logged_in
logged_in = "dashboard" in page.url or page.evaluate(_CC_CHECK_LOGIN_JS)
def login_action(page):
nonlocal login_error
try:
page.locator('input[name="name"]').fill(credentials.get("username", ""))
page.locator('input[name="pass"]').fill(credentials.get("password", ""))
page.locator("input.cc-login-btn").click()
try:
page.wait_for_url(lambda url: "/login" not in url, timeout=3000)
except Exception:
login_error = "Login failed (bad credentials?)"
return
except Exception as e:
login_error = str(e)
def submit_action(page):
nonlocal submit_error, needs_relogin
if "/login" in page.url:
needs_relogin = True
return
try:
page.wait_for_timeout(2000)
page.locator('[aria-haspopup="listbox"]').click()
page.wait_for_selector('[role="option"]', timeout=5000)
page.locator(f'[role="option"][data-value="{language_id}"]').click()
page.wait_for_timeout(2000)
page.locator(".ace_editor").click()
page.keyboard.press("Control+a")
page.wait_for_timeout(200)
page.evaluate(
"""(code) => {
const textarea = document.querySelector('.ace_text-input');
const dt = new DataTransfer();
dt.setData('text/plain', code);
textarea.dispatchEvent(new ClipboardEvent('paste', {
clipboardData: dt, bubbles: true, cancelable: true
}));
}""",
source_code,
)
page.wait_for_timeout(1000)
page.evaluate(
"() => document.getElementById('submit_btn').scrollIntoView({block:'center'})"
)
page.wait_for_timeout(200)
page.locator("#submit_btn").dispatch_event("click")
page.wait_for_timeout(3000)
dialog_text = page.evaluate("""() => {
const d = document.querySelector('[role="dialog"], .swal2-popup');
return d ? d.textContent.trim() : null;
}""")
if dialog_text and "not available for accepting solutions" in dialog_text:
submit_error = "PRACTICE_FALLBACK"
elif dialog_text:
submit_error = dialog_text
except Exception as e:
submit_error = str(e)
try:
with StealthySession(
headless=True,
timeout=BROWSER_SESSION_TIMEOUT,
google_search=False,
cookies=saved_cookies if (saved_cookies and not _retried) else [],
) as session:
if not logged_in:
print(json.dumps({"status": "checking_login"}), flush=True)
session.fetch(
f"{BASE_URL}/", page_action=check_login, network_idle=True
)
if not logged_in:
print(json.dumps({"status": "logging_in"}), flush=True)
session.fetch(f"{BASE_URL}/login", page_action=login_action)
if login_error:
return SubmitResult(
success=False, error=f"Login failed: {login_error}"
)
print(json.dumps({"status": "submitting"}), flush=True)
submit_url = (
f"{BASE_URL}/submit/{problem_id}"
if contest_id == "PRACTICE"
else f"{BASE_URL}/{contest_id}/submit/{problem_id}"
)
session.fetch(submit_url, page_action=submit_action)
try:
browser_cookies = session.context.cookies()
if browser_cookies and logged_in:
_COOKIE_PATH.write_text(json.dumps(browser_cookies))
except Exception:
pass
if needs_relogin and not _retried:
_COOKIE_PATH.unlink(missing_ok=True)
return _submit_headless_codechef(
contest_id,
problem_id,
file_path,
language_id,
credentials,
_retried=True,
)
if submit_error == "PRACTICE_FALLBACK" and not _retried:
return _submit_headless_codechef(
"PRACTICE",
problem_id,
file_path,
language_id,
credentials,
_retried=True,
)
if submit_error:
return SubmitResult(success=False, error=submit_error)
return SubmitResult(success=True, error="", submission_id="")
except Exception as e:
return SubmitResult(success=False, error=str(e))
class CodeChefScraper(BaseScraper): class CodeChefScraper(BaseScraper):
@ -295,19 +65,12 @@ class CodeChefScraper(BaseScraper):
data = await fetch_json( data = await fetch_json(
client, API_CONTEST.format(contest_id=contest_id) client, API_CONTEST.format(contest_id=contest_id)
) )
problems_raw = data.get("problems") if not data.get("problems"):
if not problems_raw and isinstance(data.get("child_contests"), dict):
for div in ("div_4", "div_3", "div_2", "div_1"):
child = data["child_contests"].get(div, {})
child_code = child.get("contest_code")
if child_code:
return await self.scrape_contest_metadata(child_code)
if not problems_raw:
return self._metadata_error( return self._metadata_error(
f"No problems found for contest {contest_id}" f"No problems found for contest {contest_id}"
) )
problems = [] problems = []
for problem_code, problem_data in problems_raw.items(): for problem_code, problem_data in data["problems"].items():
if problem_data.get("category_name") == "main": if problem_data.get("category_name") == "main":
problems.append( problems.append(
ProblemSummary( ProblemSummary(
@ -320,120 +83,67 @@ class CodeChefScraper(BaseScraper):
error="", error="",
contest_id=contest_id, contest_id=contest_id,
problems=problems, problems=problems,
url=f"{BASE_URL}/problems/%s", url=f"{BASE_URL}/{contest_id}",
contest_url=f"{BASE_URL}/{contest_id}",
standings_url=f"{BASE_URL}/{contest_id}/rankings",
) )
except Exception as e: except Exception as e:
return self._metadata_error(f"Failed to fetch contest {contest_id}: {e}") return self._metadata_error(f"Failed to fetch contest {contest_id}: {e}")
async def scrape_contest_list(self) -> ContestListResult: async def scrape_contest_list(self) -> ContestListResult:
async with httpx.AsyncClient( async with httpx.AsyncClient() as client:
limits=httpx.Limits(max_connections=CONNECTIONS)
) as client:
try: try:
data = await fetch_json(client, API_CONTESTS_ALL) data = await fetch_json(client, API_CONTESTS_ALL)
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
return self._contests_error(f"Failed to fetch contests: {e}") return self._contests_error(f"Failed to fetch contests: {e}")
all_contests = data.get("future_contests", []) + data.get(
present = data.get("present_contests", []) "past_contests", []
future = data.get("future_contests", []) )
max_num = 0
async def fetch_past_page(offset: int) -> list[dict[str, Any]]: for contest in all_contests:
r = await client.get( contest_code = contest.get("contest_code", "")
BASE_URL + API_CONTESTS_PAST, if contest_code.startswith("START"):
params={ match = re.match(r"START(\d+)", contest_code)
"sort_by": "START", if match:
"sorting_order": "desc", num = int(match.group(1))
"offset": offset, max_num = max(max_num, num)
}, if max_num == 0:
headers=HEADERS, return self._contests_error("No Starters contests found")
timeout=HTTP_TIMEOUT, contests = []
)
r.raise_for_status()
return r.json().get("contests", [])
past: list[dict[str, Any]] = []
offset = 0
while True:
page = await fetch_past_page(offset)
past.extend(
c for c in page if re.match(r"^START\d+", c.get("contest_code", ""))
)
if len(page) < 20:
break
offset += 20
raw: list[dict[str, Any]] = []
seen_raw: set[str] = set()
for c in present + future + past:
code = c.get("contest_code", "")
if not code or code in seen_raw:
continue
seen_raw.add(code)
raw.append(c)
sem = asyncio.Semaphore(CONNECTIONS) sem = asyncio.Semaphore(CONNECTIONS)
async def expand(c: dict[str, Any]) -> list[ContestSummary]: async def fetch_divisions(i: int) -> list[ContestSummary]:
code = c["contest_code"] parent_id = f"START{i}"
name = c.get("contest_name", code) async with sem:
start_time: int | None = None
iso = c.get("contest_start_date_iso")
if iso:
try: try:
start_time = int(datetime.fromisoformat(iso).timestamp()) parent_data = await fetch_json(
except Exception: client, API_CONTEST.format(contest_id=parent_id)
pass
base_name = re.sub(r"\s*\(.*?\)\s*$", "", name).strip()
try:
async with sem:
detail = await fetch_json(
client, API_CONTEST.format(contest_id=code)
) )
children = detail.get("child_contests") except Exception as e:
if children and isinstance(children, dict): import sys
divs: list[ContestSummary] = []
for div_key in ("div_1", "div_2", "div_3", "div_4"): print(f"Error fetching {parent_id}: {e}", file=sys.stderr)
child = children.get(div_key) return []
if not child: child_contests = parent_data.get("child_contests", {})
continue if not child_contests:
child_code = child.get("contest_code") return []
div_num = child.get("div", {}).get( base_name = f"Starters {i}"
"div_number", div_key[-1] divisions = []
for div_key, div_data in child_contests.items():
div_code = div_data.get("contest_code", "")
div_num = div_data.get("div", {}).get("div_number", "")
if div_code and div_num:
divisions.append(
ContestSummary(
id=div_code,
name=base_name,
display_name=f"{base_name} (Div. {div_num})",
) )
if child_code: )
display = f"{base_name} (Div. {div_num})" return divisions
divs.append(
ContestSummary(
id=child_code,
name=display,
display_name=display,
start_time=start_time,
)
)
if divs:
return divs
except Exception:
pass
return [
ContestSummary(
id=code, name=name, display_name=name, start_time=start_time
)
]
results = await asyncio.gather(*[expand(c) for c in raw]) tasks = [fetch_divisions(i) for i in range(1, max_num + 1)]
for coro in asyncio.as_completed(tasks):
contests: list[ContestSummary] = [] divisions = await coro
seen: set[str] = set() contests.extend(divisions)
for group in results:
for entry in group:
if entry.id not in seen:
seen.add(entry.id)
contests.append(entry)
if not contests:
return self._contests_error("No contests found")
return ContestListResult(success=True, error="", contests=contests) return ContestListResult(success=True, error="", contests=contests)
async def stream_tests_for_category_async(self, category_id: str) -> None: async def stream_tests_for_category_async(self, category_id: str) -> None:
@ -453,15 +163,6 @@ class CodeChefScraper(BaseScraper):
) )
return return
all_problems = contest_data.get("problems", {}) all_problems = contest_data.get("problems", {})
if not all_problems and isinstance(
contest_data.get("child_contests"), dict
):
for div in ("div_4", "div_3", "div_2", "div_1"):
child = contest_data["child_contests"].get(div, {})
child_code = child.get("contest_code")
if child_code:
await self.stream_tests_for_category_async(child_code)
return
if not all_problems: if not all_problems:
print( print(
json.dumps( json.dumps(
@ -510,15 +211,18 @@ class CodeChefScraper(BaseScraper):
] ]
time_limit_str = problem_data.get("max_timelimit", "1") time_limit_str = problem_data.get("max_timelimit", "1")
timeout_ms = int(float(time_limit_str) * 1000) timeout_ms = int(float(time_limit_str) * 1000)
memory_mb = 256.0 problem_url = PROBLEM_URL.format(problem_id=problem_code)
loop = asyncio.get_event_loop()
html = await loop.run_in_executor(
None, _fetch_html_sync, problem_url
)
memory_mb = _extract_memory_limit(html)
interactive = False interactive = False
precision = None
except Exception: except Exception:
tests = [] tests = []
timeout_ms = 1000 timeout_ms = 1000
memory_mb = 256.0 memory_mb = 256.0
interactive = False interactive = False
precision = None
combined_input = "\n".join(t.input for t in tests) if tests else "" combined_input = "\n".join(t.input for t in tests) if tests else ""
combined_expected = ( combined_expected = (
"\n".join(t.expected for t in tests) if tests else "" "\n".join(t.expected for t in tests) if tests else ""
@ -536,7 +240,6 @@ class CodeChefScraper(BaseScraper):
"memory_mb": memory_mb, "memory_mb": memory_mb,
"interactive": interactive, "interactive": interactive,
"multi_test": False, "multi_test": False,
"precision": precision,
} }
tasks = [run_one(problem_code) for problem_code in problems.keys()] tasks = [run_one(problem_code) for problem_code in problems.keys()]
@ -544,30 +247,6 @@ class CodeChefScraper(BaseScraper):
payload = await coro payload = await coro
print(json.dumps(payload), flush=True) print(json.dumps(payload), flush=True)
async def submit(
self,
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
) -> SubmitResult:
if not credentials.get("username") or not credentials.get("password"):
return self._submit_error("Missing credentials. Use :CP codechef login")
return await asyncio.to_thread(
_submit_headless_codechef,
contest_id,
problem_id,
file_path,
language_id,
credentials,
)
async def login(self, credentials: dict[str, str]) -> LoginResult:
if not credentials.get("username") or not credentials.get("password"):
return self._login_error("Missing username or password")
return await asyncio.to_thread(_login_headless_codechef, credentials)
if __name__ == "__main__": if __name__ == "__main__":
CodeChefScraper().run_cli() CodeChefScraper().run_cli()

View file

@ -2,31 +2,30 @@
import asyncio import asyncio
import json import json
import logging
import re import re
from typing import Any from typing import Any
import requests import requests
from bs4 import BeautifulSoup, Tag from bs4 import BeautifulSoup, Tag
from scrapling.fetchers import Fetcher
from .base import BaseScraper, extract_precision from .base import BaseScraper
from .models import ( from .models import (
ContestListResult, ContestListResult,
ContestSummary, ContestSummary,
LoginResult,
MetadataResult, MetadataResult,
ProblemSummary, ProblemSummary,
SubmitResult,
TestCase, TestCase,
) )
from .timeouts import (
BROWSER_NAV_TIMEOUT, # suppress scrapling logging - https://github.com/D4Vinci/Scrapling/issues/31)
BROWSER_SESSION_TIMEOUT, logging.getLogger("scrapling").setLevel(logging.CRITICAL)
BROWSER_SUBMIT_NAV_TIMEOUT,
HTTP_TIMEOUT,
)
BASE_URL = "https://codeforces.com" BASE_URL = "https://codeforces.com"
API_CONTEST_LIST_URL = f"{BASE_URL}/api/contest.list" API_CONTEST_LIST_URL = f"{BASE_URL}/api/contest.list"
TIMEOUT_SECONDS = 30
HEADERS = { HEADERS = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
} }
@ -84,7 +83,7 @@ def _extract_title(block: Tag) -> tuple[str, str]:
def _extract_samples(block: Tag) -> tuple[list[TestCase], bool]: def _extract_samples(block: Tag) -> tuple[list[TestCase], bool]:
st = block.find("div", class_="sample-test") st = block.find("div", class_="sample-test")
if not isinstance(st, Tag): if not st:
return [], False return [], False
input_pres: list[Tag] = [ input_pres: list[Tag] = [
@ -140,30 +139,11 @@ def _is_interactive(block: Tag) -> bool:
def _fetch_problems_html(contest_id: str) -> str: def _fetch_problems_html(contest_id: str) -> str:
try:
from scrapling.fetchers import StealthySession # type: ignore[import-untyped,unresolved-import]
except ImportError:
raise RuntimeError("scrapling is required for Codeforces metadata")
from .atcoder import _ensure_browser
_ensure_browser()
url = f"{BASE_URL}/contest/{contest_id}/problems" url = f"{BASE_URL}/contest/{contest_id}/problems"
html = "" page = Fetcher.get(
url,
def page_action(page): )
nonlocal html return page.html_content
html = page.content()
with StealthySession(
headless=True,
timeout=BROWSER_SESSION_TIMEOUT,
google_search=False,
) as session:
session.fetch(url, page_action=page_action, solve_cloudflare=True)
return html
def _parse_all_blocks(html: str) -> list[dict[str, Any]]: def _parse_all_blocks(html: str) -> list[dict[str, Any]]:
@ -179,7 +159,6 @@ def _parse_all_blocks(html: str) -> list[dict[str, Any]]:
raw_samples, is_grouped = _extract_samples(b) raw_samples, is_grouped = _extract_samples(b)
timeout_ms, memory_mb = _extract_limits(b) timeout_ms, memory_mb = _extract_limits(b)
interactive = _is_interactive(b) interactive = _is_interactive(b)
precision = extract_precision(b.get_text(" ", strip=True))
if is_grouped and raw_samples: if is_grouped and raw_samples:
combined_input = f"{len(raw_samples)}\n" + "\n".join( combined_input = f"{len(raw_samples)}\n" + "\n".join(
@ -206,7 +185,6 @@ def _parse_all_blocks(html: str) -> list[dict[str, Any]]:
"memory_mb": memory_mb, "memory_mb": memory_mb,
"interactive": interactive, "interactive": interactive,
"multi_test": is_grouped, "multi_test": is_grouped,
"precision": precision,
} }
) )
return out return out
@ -242,15 +220,13 @@ class CodeforcesScraper(BaseScraper):
contest_id=contest_id, contest_id=contest_id,
problems=problems, problems=problems,
url=f"https://codeforces.com/contest/{contest_id}/problem/%s", url=f"https://codeforces.com/contest/{contest_id}/problem/%s",
contest_url=f"https://codeforces.com/contest/{contest_id}",
standings_url=f"https://codeforces.com/contest/{contest_id}/standings",
) )
except Exception as e: except Exception as e:
return self._metadata_error(str(e)) return self._metadata_error(str(e))
async def scrape_contest_list(self) -> ContestListResult: async def scrape_contest_list(self) -> ContestListResult:
try: try:
r = requests.get(API_CONTEST_LIST_URL, timeout=HTTP_TIMEOUT) r = requests.get(API_CONTEST_LIST_URL, timeout=TIMEOUT_SECONDS)
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
if data.get("status") != "OK": if data.get("status") != "OK":
@ -258,20 +234,11 @@ class CodeforcesScraper(BaseScraper):
contests: list[ContestSummary] = [] contests: list[ContestSummary] = []
for c in data["result"]: for c in data["result"]:
phase = c.get("phase") if c.get("phase") != "FINISHED":
if phase not in ("FINISHED", "BEFORE", "CODING"):
continue continue
cid = str(c["id"]) cid = str(c["id"])
name = c["name"] name = c["name"]
start_time = c.get("startTimeSeconds") if phase != "FINISHED" else None contests.append(ContestSummary(id=cid, name=name, display_name=name))
contests.append(
ContestSummary(
id=cid,
name=name,
display_name=name,
start_time=start_time,
)
)
if not contests: if not contests:
return self._contests_error("No contests found") return self._contests_error("No contests found")
@ -302,283 +269,11 @@ class CodeforcesScraper(BaseScraper):
"memory_mb": b.get("memory_mb", 0), "memory_mb": b.get("memory_mb", 0),
"interactive": bool(b.get("interactive")), "interactive": bool(b.get("interactive")),
"multi_test": bool(b.get("multi_test", False)), "multi_test": bool(b.get("multi_test", False)),
"precision": b.get("precision"),
} }
), ),
flush=True, flush=True,
) )
async def submit(
self,
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
) -> SubmitResult:
return await asyncio.to_thread(
_submit_headless,
contest_id,
problem_id,
file_path,
language_id,
credentials,
)
async def login(self, credentials: dict[str, str]) -> LoginResult:
if not credentials.get("username") or not credentials.get("password"):
return self._login_error("Missing username or password")
return await asyncio.to_thread(_login_headless_cf, credentials)
def _login_headless_cf(credentials: dict[str, str]) -> LoginResult:
from pathlib import Path
try:
from scrapling.fetchers import StealthySession # type: ignore[import-untyped,unresolved-import]
except ImportError:
return LoginResult(
success=False,
error="scrapling is required for Codeforces login",
)
from .atcoder import _ensure_browser
_ensure_browser()
cookie_cache = Path.home() / ".cache" / "cp-nvim" / "codeforces-cookies.json"
cookie_cache.parent.mkdir(parents=True, exist_ok=True)
logged_in = False
login_error: str | None = None
def check_login(page):
nonlocal logged_in
logged_in = page.evaluate(
"() => Array.from(document.querySelectorAll('a'))"
".some(a => a.textContent.includes('Logout'))"
)
def login_action(page):
nonlocal login_error
try:
page.fill(
'input[name="handleOrEmail"]',
credentials.get("username", ""),
)
page.fill(
'input[name="password"]',
credentials.get("password", ""),
)
page.locator('#enterForm input[type="submit"]').click()
page.wait_for_url(
lambda url: "/enter" not in url, timeout=BROWSER_NAV_TIMEOUT
)
except Exception as e:
login_error = str(e)
try:
with StealthySession(
headless=True,
timeout=BROWSER_SESSION_TIMEOUT,
google_search=False,
) as session:
print(json.dumps({"status": "logging_in"}), flush=True)
session.fetch(
f"{BASE_URL}/enter",
page_action=login_action,
solve_cloudflare=True,
)
if login_error:
return LoginResult(success=False, error=f"Login failed: {login_error}")
session.fetch(
f"{BASE_URL}/",
page_action=check_login,
network_idle=True,
)
if not logged_in:
return LoginResult(
success=False, error="Login failed (bad credentials?)"
)
try:
browser_cookies = session.context.cookies()
if any(c.get("name") == "X-User-Handle" for c in browser_cookies):
cookie_cache.write_text(json.dumps(browser_cookies))
except Exception:
pass
return LoginResult(success=True, error="")
except Exception as e:
return LoginResult(success=False, error=str(e))
def _submit_headless(
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
_retried: bool = False,
) -> SubmitResult:
from pathlib import Path
source_code = Path(file_path).read_text()
try:
from scrapling.fetchers import StealthySession # type: ignore[import-untyped,unresolved-import]
except ImportError:
return SubmitResult(
success=False,
error="scrapling is required for Codeforces submit",
)
from .atcoder import _ensure_browser, _solve_turnstile
_ensure_browser()
cookie_cache = Path.home() / ".cache" / "cp-nvim" / "codeforces-cookies.json"
cookie_cache.parent.mkdir(parents=True, exist_ok=True)
saved_cookies: list[dict[str, Any]] = []
if cookie_cache.exists():
try:
saved_cookies = json.loads(cookie_cache.read_text())
except Exception:
pass
logged_in = cookie_cache.exists() and not _retried
login_error: str | None = None
submit_error: str | None = None
needs_relogin = False
def check_login(page):
nonlocal logged_in
logged_in = page.evaluate(
"() => Array.from(document.querySelectorAll('a'))"
".some(a => a.textContent.includes('Logout'))"
)
def login_action(page):
nonlocal login_error
try:
page.fill(
'input[name="handleOrEmail"]',
credentials.get("username", ""),
)
page.fill(
'input[name="password"]',
credentials.get("password", ""),
)
page.locator('#enterForm input[type="submit"]').click()
page.wait_for_url(
lambda url: "/enter" not in url, timeout=BROWSER_NAV_TIMEOUT
)
except Exception as e:
login_error = str(e)
def submit_action(page):
nonlocal submit_error, needs_relogin
if "/enter" in page.url or "/login" in page.url:
needs_relogin = True
return
_solve_turnstile(page)
try:
page.select_option(
'select[name="submittedProblemIndex"]',
problem_id.upper(),
)
page.select_option('select[name="programTypeId"]', language_id)
page.evaluate(
"""(code) => {
const cm = document.querySelector('.CodeMirror');
if (cm && cm.CodeMirror) {
cm.CodeMirror.setValue(code);
}
const ta = document.querySelector('textarea[name="source"]');
if (ta) ta.value = code;
}""",
source_code,
)
page.locator("form.submit-form input.submit").click(no_wait_after=True)
try:
page.wait_for_url(
lambda url: "/my" in url or "/status" in url,
timeout=BROWSER_SUBMIT_NAV_TIMEOUT["codeforces"],
)
except Exception:
err_el = page.query_selector("span.error")
if err_el:
submit_error = err_el.inner_text().strip()
else:
submit_error = "Submit failed: page did not navigate"
except Exception as e:
submit_error = str(e)
try:
with StealthySession(
headless=True,
timeout=BROWSER_SESSION_TIMEOUT,
google_search=False,
cookies=saved_cookies if (cookie_cache.exists() and not _retried) else [],
) as session:
if not (cookie_cache.exists() and not _retried):
print(json.dumps({"status": "checking_login"}), flush=True)
session.fetch(
f"{BASE_URL}/",
page_action=check_login,
network_idle=True,
)
if not logged_in:
print(json.dumps({"status": "logging_in"}), flush=True)
session.fetch(
f"{BASE_URL}/enter",
page_action=login_action,
solve_cloudflare=True,
)
if login_error:
return SubmitResult(
success=False, error=f"Login failed: {login_error}"
)
print(json.dumps({"status": "submitting"}), flush=True)
session.fetch(
f"{BASE_URL}/contest/{contest_id}/submit",
page_action=submit_action,
solve_cloudflare=False,
)
try:
browser_cookies = session.context.cookies()
if any(c.get("name") == "X-User-Handle" for c in browser_cookies):
cookie_cache.write_text(json.dumps(browser_cookies))
except Exception:
pass
if needs_relogin and not _retried:
cookie_cache.unlink(missing_ok=True)
return _submit_headless(
contest_id,
problem_id,
file_path,
language_id,
credentials,
_retried=True,
)
if submit_error:
return SubmitResult(success=False, error=submit_error)
return SubmitResult(
success=True,
error="",
submission_id="",
verdict="submitted",
)
except Exception as e:
return SubmitResult(success=False, error=str(e))
if __name__ == "__main__": if __name__ == "__main__":
CodeforcesScraper().run_cli() CodeforcesScraper().run_cli()

View file

@ -1,51 +1,30 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio import asyncio
import base64
import json import json
import re import re
from typing import Any from typing import Any
import httpx import httpx
from .base import BaseScraper, extract_precision from .base import BaseScraper
from .timeouts import HTTP_TIMEOUT
from .models import ( from .models import (
ContestListResult, ContestListResult,
ContestSummary, ContestSummary,
LoginResult,
MetadataResult, MetadataResult,
ProblemSummary, ProblemSummary,
SubmitResult,
TestCase, TestCase,
) )
BASE_URL = "https://cses.fi" BASE_URL = "https://cses.fi"
API_URL = "https://cses.fi/api"
SUBMIT_SCOPE = "courses/problemset"
INDEX_PATH = "/problemset" INDEX_PATH = "/problemset"
TASK_PATH = "/problemset/task/{id}" TASK_PATH = "/problemset/task/{id}"
HEADERS = { HEADERS = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
} }
TIMEOUT_S = 15.0
CONNECTIONS = 8 CONNECTIONS = 8
CSES_LANGUAGES: dict[str, dict[str, str]] = {
"C++17": {"name": "C++", "option": "C++17"},
"Python3": {"name": "Python3", "option": "CPython3"},
"PyPy3": {"name": "Python3", "option": "PyPy3"},
"Java": {"name": "Java", "option": "Java"},
"Rust2021": {"name": "Rust", "option": "2021"},
}
EXTENSIONS: dict[str, str] = {
"C++17": "cpp",
"Python3": "py",
"PyPy3": "py",
"Java": "java",
"Rust2021": "rs",
}
def normalize_category_name(category_name: str) -> str: def normalize_category_name(category_name: str) -> str:
return category_name.lower().replace(" ", "_").replace("&", "and") return category_name.lower().replace(" ", "_").replace("&", "and")
@ -85,7 +64,7 @@ def snake_to_title(name: str) -> str:
async def fetch_text(client: httpx.AsyncClient, path: str) -> str: async def fetch_text(client: httpx.AsyncClient, path: str) -> str:
r = await client.get(BASE_URL + path, headers=HEADERS, timeout=HTTP_TIMEOUT) r = await client.get(BASE_URL + path, headers=HEADERS, timeout=TIMEOUT_S)
r.raise_for_status() r.raise_for_status()
return r.text return r.text
@ -150,21 +129,17 @@ def parse_category_problems(category_id: str, html: str) -> list[ProblemSummary]
return [] return []
def _extract_problem_info(html: str) -> tuple[int, int, bool, float | None]: def _extract_problem_info(html: str) -> tuple[int, int, bool]:
tm = TIME_RE.search(html) tm = TIME_RE.search(html)
mm = MEM_RE.search(html) mm = MEM_RE.search(html)
t = int(round(float(tm.group(1)) * 1000)) if tm else 0 t = int(round(float(tm.group(1)) * 1000)) if tm else 0
m = int(mm.group(1)) if mm else 0 m = int(mm.group(1)) if mm else 0
md = MD_BLOCK_RE.search(html) md = MD_BLOCK_RE.search(html)
interactive = False interactive = False
precision = None
if md: if md:
body = md.group(1) body = md.group(1)
interactive = "This is an interactive problem." in body interactive = "This is an interactive problem." in body
from bs4 import BeautifulSoup return t, m, interactive
precision = extract_precision(BeautifulSoup(body, "html.parser").get_text(" "))
return t, m, interactive, precision
def parse_title(html: str) -> str: def parse_title(html: str) -> str:
@ -224,8 +199,6 @@ class CSESScraper(BaseScraper):
contest_id=contest_id, contest_id=contest_id,
problems=problems, problems=problems,
url="https://cses.fi/problemset/task/%s", url="https://cses.fi/problemset/task/%s",
contest_url="https://cses.fi/problemset",
standings_url="",
) )
async def scrape_contest_list(self) -> ContestListResult: async def scrape_contest_list(self) -> ContestListResult:
@ -234,35 +207,9 @@ class CSESScraper(BaseScraper):
cats = parse_categories(html) cats = parse_categories(html)
if not cats: if not cats:
return ContestListResult( return ContestListResult(
success=False, success=False, error=f"{self.platform_name}: No contests found"
error=f"{self.platform_name}: No contests found",
supports_countdown=False,
)
return ContestListResult(
success=True, error="", contests=cats, supports_countdown=False
)
async def login(self, credentials: dict[str, str]) -> LoginResult:
username = credentials.get("username", "")
password = credentials.get("password", "")
if not username or not password:
return self._login_error("Missing username or password")
async with httpx.AsyncClient(follow_redirects=True) as client:
print(json.dumps({"status": "logging_in"}), flush=True)
token = await self._web_login(client, username, password)
if not token:
return self._login_error("Login failed (bad credentials?)")
return LoginResult(
success=True,
error="",
credentials={
"username": username,
"password": password,
"token": token,
},
) )
return ContestListResult(success=True, error="", contests=cats)
async def stream_tests_for_category_async(self, category_id: str) -> None: async def stream_tests_for_category_async(self, category_id: str) -> None:
async with httpx.AsyncClient( async with httpx.AsyncClient(
@ -280,17 +227,10 @@ class CSESScraper(BaseScraper):
try: try:
html = await fetch_text(client, task_path(pid)) html = await fetch_text(client, task_path(pid))
tests = parse_tests(html) tests = parse_tests(html)
timeout_ms, memory_mb, interactive, precision = ( timeout_ms, memory_mb, interactive = _extract_problem_info(html)
_extract_problem_info(html)
)
except Exception: except Exception:
tests = [] tests = []
timeout_ms, memory_mb, interactive, precision = ( timeout_ms, memory_mb, interactive = 0, 0, False
0,
0,
False,
None,
)
combined_input = "\n".join(t.input for t in tests) if tests else "" combined_input = "\n".join(t.input for t in tests) if tests else ""
combined_expected = ( combined_expected = (
@ -310,7 +250,6 @@ class CSESScraper(BaseScraper):
"memory_mb": memory_mb, "memory_mb": memory_mb,
"interactive": interactive, "interactive": interactive,
"multi_test": False, "multi_test": False,
"precision": precision,
} }
tasks = [run_one(p.id) for p in problems] tasks = [run_one(p.id) for p in problems]
@ -318,156 +257,6 @@ class CSESScraper(BaseScraper):
payload = await coro payload = await coro
print(json.dumps(payload), flush=True) print(json.dumps(payload), flush=True)
async def _web_login(
self,
client: httpx.AsyncClient,
username: str,
password: str,
) -> str | None:
login_page = await client.get(
f"{BASE_URL}/login", headers=HEADERS, timeout=HTTP_TIMEOUT
)
csrf_match = re.search(r'name="csrf_token" value="([^"]+)"', login_page.text)
if not csrf_match:
return None
login_resp = await client.post(
f"{BASE_URL}/login",
data={
"csrf_token": csrf_match.group(1),
"nick": username,
"pass": password,
},
headers=HEADERS,
timeout=HTTP_TIMEOUT,
)
if "Invalid username or password" in login_resp.text:
return None
api_resp = await client.post(
f"{API_URL}/login", headers=HEADERS, timeout=HTTP_TIMEOUT
)
api_data = api_resp.json()
token: str | None = api_data.get("X-Auth-Token")
auth_url: str | None = api_data.get("authentication_url")
if not token:
raise RuntimeError("CSES API login response missing 'X-Auth-Token'")
if not auth_url:
raise RuntimeError("CSES API login response missing 'authentication_url'")
auth_page = await client.get(auth_url, headers=HEADERS, timeout=HTTP_TIMEOUT)
auth_csrf = re.search(r'name="csrf_token" value="([^"]+)"', auth_page.text)
form_token = re.search(r'name="token" value="([^"]+)"', auth_page.text)
if not auth_csrf or not form_token:
return None
await client.post(
auth_url,
data={
"csrf_token": auth_csrf.group(1),
"token": form_token.group(1),
},
headers=HEADERS,
timeout=HTTP_TIMEOUT,
)
check = await client.get(
f"{API_URL}/login",
headers={"X-Auth-Token": token, **HEADERS},
timeout=HTTP_TIMEOUT,
)
if check.status_code != 200:
return None
return token
async def _check_token(self, client: httpx.AsyncClient, token: str) -> bool:
try:
r = await client.get(
f"{API_URL}/login",
headers={"X-Auth-Token": token, **HEADERS},
timeout=HTTP_TIMEOUT,
)
return r.status_code == 200
except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError):
raise
async def submit(
self,
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
) -> SubmitResult:
from pathlib import Path
source_code = Path(file_path).read_text()
username = credentials.get("username", "")
password = credentials.get("password", "")
if not username or not password:
return self._submit_error("Missing credentials. Use :CP login cses")
async with httpx.AsyncClient(follow_redirects=True) as client:
token = credentials.get("token")
if token:
print(json.dumps({"status": "checking_login"}), flush=True)
if not await self._check_token(client, token):
token = None
if not token:
print(json.dumps({"status": "logging_in"}), flush=True)
token = await self._web_login(client, username, password)
if not token:
return self._submit_error("Login failed (bad credentials?)")
print(
json.dumps(
{
"credentials": {
"username": username,
"password": password,
"token": token,
}
}
),
flush=True,
)
print(json.dumps({"status": "submitting"}), flush=True)
ext = EXTENSIONS.get(language_id, "cpp")
lang = CSES_LANGUAGES.get(language_id, {"name": "C++", "option": "C++17"})
content_b64 = base64.b64encode(source_code.encode()).decode()
payload: dict[str, Any] = {
"language": lang,
"filename": f"{problem_id}.{ext}",
"content": content_b64,
}
r = await client.post(
f"{API_URL}/{SUBMIT_SCOPE}/submissions",
json=payload,
params={"task": problem_id},
headers={
"X-Auth-Token": token,
"Content-Type": "application/json",
**HEADERS,
},
timeout=HTTP_TIMEOUT,
)
if r.status_code not in range(200, 300):
try:
err = r.json().get("message", r.text)
except Exception:
err = r.text
return self._submit_error(f"Submit request failed: {err}")
submission_id = str(r.json().get("id", ""))
return SubmitResult(success=True, error="", submission_id=submission_id)
if __name__ == "__main__": if __name__ == "__main__":
CSESScraper().run_cli() CSESScraper().run_cli()

View file

@ -1,414 +0,0 @@
#!/usr/bin/env python3
import asyncio
import io
import json
import re
import zipfile
from datetime import datetime
from pathlib import Path
import httpx
from .base import BaseScraper, extract_precision
from .timeouts import HTTP_TIMEOUT
from .models import (
ContestListResult,
ContestSummary,
LoginResult,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "https://open.kattis.com"
HEADERS = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
}
CONNECTIONS = 8
_COOKIE_PATH = Path.home() / ".cache" / "cp-nvim" / "kattis-cookies.json"
TIME_RE = re.compile(
r"CPU Time limit</span>\s*<span[^>]*>\s*(\d+)\s*seconds?\s*</span>",
re.DOTALL,
)
MEM_RE = re.compile(
r"Memory limit</span>\s*<span[^>]*>\s*(\d+)\s*MB\s*</span>",
re.DOTALL,
)
async def _fetch_text(client: httpx.AsyncClient, url: str) -> str:
r = await client.get(url, headers=HEADERS, timeout=HTTP_TIMEOUT)
r.raise_for_status()
return r.text
async def _fetch_bytes(client: httpx.AsyncClient, url: str) -> bytes:
r = await client.get(url, headers=HEADERS, timeout=HTTP_TIMEOUT)
r.raise_for_status()
return r.content
def _parse_limits(html: str) -> tuple[int, int]:
tm = TIME_RE.search(html)
mm = MEM_RE.search(html)
timeout_ms = int(tm.group(1)) * 1000 if tm else 1000
memory_mb = int(mm.group(1)) if mm else 1024
return timeout_ms, memory_mb
def _parse_samples_html(html: str) -> list[TestCase]:
tests: list[TestCase] = []
tables = re.finditer(r'<table\s+class="sample"[^>]*>.*?</table>', html, re.DOTALL)
for table_match in tables:
table_html = table_match.group(0)
pres = re.findall(r"<pre>(.*?)</pre>", table_html, re.DOTALL)
if len(pres) >= 2:
inp = pres[0].strip()
out = pres[1].strip()
tests.append(TestCase(input=inp, expected=out))
return tests
def _parse_samples_zip(data: bytes) -> list[TestCase]:
try:
zf = zipfile.ZipFile(io.BytesIO(data))
except zipfile.BadZipFile:
return []
inputs: dict[str, str] = {}
outputs: dict[str, str] = {}
for name in zf.namelist():
content = zf.read(name).decode("utf-8").strip()
if name.endswith(".in"):
key = name[: -len(".in")]
inputs[key] = content
elif name.endswith(".ans"):
key = name[: -len(".ans")]
outputs[key] = content
tests: list[TestCase] = []
for key in sorted(set(inputs) & set(outputs)):
tests.append(TestCase(input=inputs[key], expected=outputs[key]))
return tests
def _is_interactive(html: str) -> bool:
return "This is an interactive problem" in html
def _parse_contests_page(html: str) -> list[ContestSummary]:
results: list[ContestSummary] = []
seen: set[str] = set()
for row_m in re.finditer(r"<tr[^>]*>(.*?)</tr>", html, re.DOTALL):
row = row_m.group(1)
link_m = re.search(r'href="/contests/([a-z0-9]+)"[^>]*>([^<]+)</a>', row)
if not link_m:
continue
cid = link_m.group(1)
name = link_m.group(2).strip()
if cid in seen:
continue
seen.add(cid)
start_time: int | None = None
ts_m = re.search(r'data-timestamp="(\d+)"', row)
if ts_m:
start_time = int(ts_m.group(1))
else:
time_m = re.search(r'<time[^>]+datetime="([^"]+)"', row)
if time_m:
try:
dt = datetime.fromisoformat(time_m.group(1).replace("Z", "+00:00"))
start_time = int(dt.timestamp())
except Exception:
pass
results.append(
ContestSummary(id=cid, name=name, display_name=name, start_time=start_time)
)
return results
def _parse_contest_problem_list(html: str) -> list[tuple[str, str]]:
if "The problems will become available when the contest starts" in html:
return []
results: list[tuple[str, str]] = []
seen: set[str] = set()
for row_m in re.finditer(r"<tr[^>]*>(.*?)</tr>", html, re.DOTALL):
row = row_m.group(1)
link_m = re.search(
r'href="/contests/[^/]+/problems/([^"]+)"[^>]*>([^<]+)</a>', row
)
if not link_m:
continue
slug = link_m.group(1)
name = link_m.group(2).strip()
if slug in seen:
continue
seen.add(slug)
label_m = re.search(r"<td[^>]*>\s*([A-Z])\s*</td>", row)
label = label_m.group(1) if label_m else ""
display = f"{label} - {name}" if label else name
results.append((slug, display))
return results
async def _fetch_contest_slugs(
client: httpx.AsyncClient, contest_id: str
) -> list[tuple[str, str]]:
try:
html = await _fetch_text(client, f"{BASE_URL}/contests/{contest_id}/problems")
return _parse_contest_problem_list(html)
except httpx.HTTPStatusError:
return []
except Exception:
return []
async def _stream_single_problem(client: httpx.AsyncClient, slug: str) -> None:
try:
html = await _fetch_text(client, f"{BASE_URL}/problems/{slug}")
except Exception:
return
timeout_ms, memory_mb = _parse_limits(html)
interactive = _is_interactive(html)
precision = extract_precision(html)
tests: list[TestCase] = []
try:
zip_data = await _fetch_bytes(
client,
f"{BASE_URL}/problems/{slug}/file/statement/samples.zip",
)
tests = _parse_samples_zip(zip_data)
except Exception:
tests = _parse_samples_html(html)
combined_input = "\n".join(t.input for t in tests) if tests else ""
combined_expected = "\n".join(t.expected for t in tests) if tests else ""
print(
json.dumps(
{
"problem_id": slug,
"combined": {
"input": combined_input,
"expected": combined_expected,
},
"tests": [{"input": t.input, "expected": t.expected} for t in tests],
"timeout_ms": timeout_ms,
"memory_mb": memory_mb,
"interactive": interactive,
"multi_test": False,
"precision": precision,
}
),
flush=True,
)
async def _load_kattis_cookies(client: httpx.AsyncClient) -> None:
if not _COOKIE_PATH.exists():
return
try:
for k, v in json.loads(_COOKIE_PATH.read_text()).items():
client.cookies.set(k, v)
except Exception:
pass
async def _save_kattis_cookies(client: httpx.AsyncClient) -> None:
cookies = {k: v for k, v in client.cookies.items()}
if cookies:
_COOKIE_PATH.parent.mkdir(parents=True, exist_ok=True)
_COOKIE_PATH.write_text(json.dumps(cookies))
async def _do_kattis_login(
client: httpx.AsyncClient, username: str, password: str
) -> bool:
client.cookies.clear()
r = await client.post(
f"{BASE_URL}/login",
data={"user": username, "password": password, "script": "true"},
headers=HEADERS,
timeout=HTTP_TIMEOUT,
)
return r.status_code == 200
class KattisScraper(BaseScraper):
@property
def platform_name(self) -> str:
return "kattis"
async def scrape_contest_metadata(self, contest_id: str) -> MetadataResult:
try:
async with httpx.AsyncClient() as client:
slugs = await _fetch_contest_slugs(client, contest_id)
if slugs:
return MetadataResult(
success=True,
error="",
contest_id=contest_id,
problems=[
ProblemSummary(id=slug, name=name) for slug, name in slugs
],
url=f"{BASE_URL}/problems/%s",
contest_url=f"{BASE_URL}/contests/{contest_id}",
standings_url=f"{BASE_URL}/contests/{contest_id}/standings",
)
try:
html = await _fetch_text(
client, f"{BASE_URL}/problems/{contest_id}"
)
except Exception as e:
return self._metadata_error(str(e))
title_m = re.search(r"<title>([^<]+)</title>", html)
name = (
title_m.group(1).split("\u2013")[0].strip()
if title_m
else contest_id
)
return MetadataResult(
success=True,
error="",
contest_id=contest_id,
problems=[ProblemSummary(id=contest_id, name=name)],
url=f"{BASE_URL}/problems/%s",
contest_url=f"{BASE_URL}/problems/{contest_id}",
standings_url="",
)
except Exception as e:
return self._metadata_error(str(e))
async def scrape_contest_list(self) -> ContestListResult:
try:
async with httpx.AsyncClient() as client:
html = await _fetch_text(
client,
f"{BASE_URL}/contests?kattis_original=on&kattis_recycled=off&user_created=off",
)
contests = _parse_contests_page(html)
if not contests:
return self._contests_error("No contests found")
return ContestListResult(success=True, error="", contests=contests)
except Exception as e:
return self._contests_error(str(e))
async def stream_tests_for_category_async(self, category_id: str) -> None:
async with httpx.AsyncClient(
limits=httpx.Limits(max_connections=CONNECTIONS)
) as client:
slugs = await _fetch_contest_slugs(client, category_id)
if slugs:
sem = asyncio.Semaphore(CONNECTIONS)
async def emit_one(slug: str, _name: str) -> None:
async with sem:
await _stream_single_problem(client, slug)
await asyncio.gather(*(emit_one(s, n) for s, n in slugs))
return
await _stream_single_problem(client, category_id)
async def submit(
self,
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
) -> SubmitResult:
source = Path(file_path).read_bytes()
username = credentials.get("username", "")
password = credentials.get("password", "")
if not username or not password:
return self._submit_error("Missing credentials. Use :CP kattis login")
async with httpx.AsyncClient(follow_redirects=True) as client:
await _load_kattis_cookies(client)
if not client.cookies:
print(json.dumps({"status": "logging_in"}), flush=True)
ok = await _do_kattis_login(client, username, password)
if not ok:
return self._submit_error("Login failed (bad credentials?)")
await _save_kattis_cookies(client)
print(json.dumps({"status": "submitting"}), flush=True)
lang_lower = language_id.lower()
mainclass = Path(file_path).stem if "java" in lang_lower else ""
data: dict[str, str] = {
"submit": "true",
"script": "true",
"language": language_id,
"problem": problem_id,
"mainclass": mainclass,
"submit_ctr": "2",
}
if contest_id != problem_id:
data["contest"] = contest_id
async def _do_submit() -> httpx.Response:
return await client.post(
f"{BASE_URL}/submit",
data=data,
files={"sub_file[]": (Path(file_path).name, source, "text/plain")},
headers=HEADERS,
timeout=HTTP_TIMEOUT,
)
try:
r = await _do_submit()
r.raise_for_status()
except Exception as e:
return self._submit_error(f"Submit request failed: {e}")
if r.text == "Request validation failed":
_COOKIE_PATH.unlink(missing_ok=True)
print(json.dumps({"status": "logging_in"}), flush=True)
ok = await _do_kattis_login(client, username, password)
if not ok:
return self._submit_error("Login failed (bad credentials?)")
await _save_kattis_cookies(client)
try:
r = await _do_submit()
r.raise_for_status()
except Exception as e:
return self._submit_error(f"Submit request failed: {e}")
sid_m = re.search(r"Submission ID:\s*(\d+)", r.text, re.IGNORECASE)
if not sid_m:
return self._submit_error(
r.text.strip() or "Submit failed (no submission ID)"
)
return SubmitResult(
success=True,
error="",
submission_id=sid_m.group(1),
verdict="submitted",
)
async def login(self, credentials: dict[str, str]) -> LoginResult:
username = credentials.get("username", "")
password = credentials.get("password", "")
if not username or not password:
return self._login_error("Missing username or password")
async with httpx.AsyncClient(follow_redirects=True) as client:
print(json.dumps({"status": "logging_in"}), flush=True)
ok = await _do_kattis_login(client, username, password)
if not ok:
return self._login_error("Login failed (bad credentials?)")
await _save_kattis_cookies(client)
return LoginResult(
success=True,
error="",
credentials={"username": username, "password": password},
)
if __name__ == "__main__":
KattisScraper().run_cli()

View file

@ -1,146 +0,0 @@
LANGUAGE_IDS = {
"atcoder": {
"cpp": "6017",
"python": "6082",
"java": "6056",
"rust": "6088",
"c": "6014",
"go": "6051",
"haskell": "6052",
"csharp": "6015",
"kotlin": "6062",
"ruby": "6087",
"javascript": "6059",
"typescript": "6100",
"scala": "6090",
"ocaml": "6073",
"dart": "6033",
"elixir": "6038",
"erlang": "6041",
"fsharp": "6042",
"swift": "6095",
"zig": "6111",
"nim": "6072",
"lua": "6067",
"perl": "6076",
"php": "6077",
"pascal": "6075",
"crystal": "6028",
"d": "6030",
"julia": "6114",
"r": "6084",
"commonlisp": "6027",
"scheme": "6092",
"clojure": "6022",
"ada": "6002",
"bash": "6008",
"fortran": "6047",
"gleam": "6049",
"lean": "6065",
"vala": "6106",
"v": "6105",
},
"codeforces": {
"cpp": "89",
"python": "70",
"java": "87",
"kotlin": "99",
"rust": "75",
"go": "32",
"csharp": "96",
"haskell": "12",
"javascript": "55",
"ruby": "67",
"scala": "20",
"ocaml": "19",
"d": "28",
"perl": "13",
"php": "6",
"pascal": "4",
"fsharp": "97",
},
"cses": {
"cpp": "C++17",
"python": "Python3",
"java": "Java",
"rust": "Rust2021",
},
"usaco": {
"cpp": "cpp",
"python": "python",
"java": "java",
},
"kattis": {
"cpp": "C++",
"python": "Python 3",
"java": "Java",
"rust": "Rust",
"ada": "Ada",
"algol60": "Algol 60",
"algol68": "Algol 68",
"apl": "APL",
"bash": "Bash",
"bcpl": "BCPL",
"bqn": "BQN",
"c": "C",
"cobol": "COBOL",
"commonlisp": "Common Lisp",
"crystal": "Crystal",
"csharp": "C#",
"d": "D",
"dart": "Dart",
"elixir": "Elixir",
"erlang": "Erlang",
"forth": "Forth",
"fortran": "Fortran",
"fortran77": "Fortran 77",
"fsharp": "F#",
"gerbil": "Gerbil",
"go": "Go",
"haskell": "Haskell",
"icon": "Icon",
"javascript": "JavaScript (Node.js)",
"julia": "Julia",
"kotlin": "Kotlin",
"lua": "Lua",
"modula2": "Modula-2",
"nim": "Nim",
"objectivec": "Objective-C",
"ocaml": "OCaml",
"octave": "Octave",
"odin": "Odin",
"pascal": "Pascal",
"perl": "Perl",
"php": "PHP",
"pli": "PL/I",
"prolog": "Prolog",
"racket": "Racket",
"ruby": "Ruby",
"scala": "Scala",
"simula": "Simula 67",
"smalltalk": "Smalltalk",
"snobol": "SNOBOL",
"swift": "Swift",
"typescript": "TypeScript",
"visualbasic": "Visual Basic",
"zig": "Zig",
},
"codechef": {
"cpp": "C++",
"python": "PYTH 3",
"java": "JAVA",
"rust": "rust",
"c": "C",
"go": "GO",
"kotlin": "KTLN",
"javascript": "NODEJS",
"typescript": "TS",
"csharp": "C#",
"php": "PHP",
"r": "R",
},
}
def get_language_id(platform: str, language: str) -> str | None:
return LANGUAGE_IDS.get(platform, {}).get(language)

View file

@ -26,7 +26,6 @@ class ContestSummary(BaseModel):
id: str id: str
name: str name: str
display_name: str | None = None display_name: str | None = None
start_time: int | None = None
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
@ -42,15 +41,12 @@ class MetadataResult(ScrapingResult):
contest_id: str = "" contest_id: str = ""
problems: list[ProblemSummary] = Field(default_factory=list) problems: list[ProblemSummary] = Field(default_factory=list)
url: str url: str
contest_url: str = ""
standings_url: str = ""
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
class ContestListResult(ScrapingResult): class ContestListResult(ScrapingResult):
contests: list[ContestSummary] = Field(default_factory=list) contests: list[ContestSummary] = Field(default_factory=list)
supports_countdown: bool = True
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
@ -67,19 +63,6 @@ class TestsResult(ScrapingResult):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
class LoginResult(ScrapingResult):
credentials: dict[str, str] = Field(default_factory=dict)
model_config = ConfigDict(extra="forbid")
class SubmitResult(ScrapingResult):
submission_id: str = ""
verdict: str = ""
model_config = ConfigDict(extra="forbid")
class ScraperConfig(BaseModel): class ScraperConfig(BaseModel):
timeout_seconds: int = 30 timeout_seconds: int = 30
max_retries: int = 3 max_retries: int = 3

View file

@ -1,14 +0,0 @@
from collections import defaultdict
HTTP_TIMEOUT = 15.0
BROWSER_SESSION_TIMEOUT = 15000
BROWSER_NAV_TIMEOUT = 10000
BROWSER_SUBMIT_NAV_TIMEOUT: defaultdict[str, int] = defaultdict(
lambda: BROWSER_NAV_TIMEOUT
)
BROWSER_SUBMIT_NAV_TIMEOUT["atcoder"] = BROWSER_NAV_TIMEOUT * 2
BROWSER_SUBMIT_NAV_TIMEOUT["codeforces"] = BROWSER_NAV_TIMEOUT * 2
BROWSER_TURNSTILE_POLL = 5000
BROWSER_ELEMENT_WAIT = 10000
BROWSER_SETTLE_DELAY = 500

View file

@ -1,534 +0,0 @@
#!/usr/bin/env python3
import asyncio
import json
import re
from pathlib import Path
from typing import Any, cast
import httpx
from .base import BaseScraper, extract_precision
from .timeouts import HTTP_TIMEOUT
from .models import (
ContestListResult,
ContestSummary,
LoginResult,
MetadataResult,
ProblemSummary,
SubmitResult,
TestCase,
)
BASE_URL = "http://www.usaco.org"
_AUTH_BASE = "https://usaco.org"
HEADERS = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
}
CONNECTIONS = 4
_COOKIE_PATH = Path.home() / ".cache" / "cp-nvim" / "usaco-cookies.json"
_LOGIN_PATH = "/current/tpcm/login-session.php"
_SUBMIT_PATH = "/current/tpcm/submit-solution.php"
_LANG_KEYWORDS: dict[str, list[str]] = {
"cpp": ["c++17", "c++ 17", "g++17", "c++", "cpp"],
"python": ["python3", "python 3", "python"],
"java": ["java"],
}
MONTHS = [
"dec",
"jan",
"feb",
"mar",
"open",
]
DIVISION_HEADING_RE = re.compile(
r"<h2>.*?USACO\s+(\d{4})\s+(\w+)\s+Contest,\s+(\w+)\s*</h2>",
re.IGNORECASE,
)
PROBLEM_BLOCK_RE = re.compile(
r"<b>([^<]+)</b>\s*<br\s*/?>.*?"
r"viewproblem2&cpid=(\d+)",
re.DOTALL,
)
SAMPLE_IN_RE = re.compile(r"<pre\s+class=['\"]in['\"]>(.*?)</pre>", re.DOTALL)
SAMPLE_OUT_RE = re.compile(r"<pre\s+class=['\"]out['\"]>(.*?)</pre>", re.DOTALL)
TIME_NOTE_RE = re.compile(
r"time\s+limit\s+(?:for\s+this\s+problem\s+is\s+)?(\d+)s",
re.IGNORECASE,
)
MEMORY_NOTE_RE = re.compile(
r"memory\s+limit\s+(?:for\s+this\s+problem\s+is\s+)?(\d+)\s*MB",
re.IGNORECASE,
)
RESULTS_PAGE_RE = re.compile(
r'href="index\.php\?page=([a-z]+\d{2,4}results)"',
re.IGNORECASE,
)
async def _fetch_text(client: httpx.AsyncClient, url: str) -> str:
r = await client.get(
url, headers=HEADERS, timeout=HTTP_TIMEOUT, follow_redirects=True
)
r.raise_for_status()
return r.text
def _parse_results_page(html: str) -> dict[str, list[tuple[str, str]]]:
sections: dict[str, list[tuple[str, str]]] = {}
current_div: str | None = None
parts = re.split(r"(<h2>.*?</h2>)", html, flags=re.DOTALL)
for part in parts:
heading_m = DIVISION_HEADING_RE.search(part)
if heading_m:
div = heading_m.group(3)
if div:
key = div.lower()
current_div = key
sections.setdefault(key, [])
continue
if current_div is not None:
for m in PROBLEM_BLOCK_RE.finditer(part):
name = m.group(1).strip()
cpid = m.group(2)
sections[current_div].append((cpid, name))
return sections
def _parse_contest_id(contest_id: str) -> tuple[str, str]:
parts = contest_id.rsplit("_", 1)
if len(parts) != 2:
return contest_id, ""
return parts[0], parts[1].lower()
def _results_page_slug(month_year: str) -> str:
return f"{month_year}results"
def _parse_problem_page(html: str) -> dict[str, Any]:
inputs = SAMPLE_IN_RE.findall(html)
outputs = SAMPLE_OUT_RE.findall(html)
tests: list[TestCase] = []
for inp, out in zip(inputs, outputs):
tests.append(
TestCase(
input=inp.strip().replace("\r", ""),
expected=out.strip().replace("\r", ""),
)
)
tm = TIME_NOTE_RE.search(html)
mm = MEMORY_NOTE_RE.search(html)
timeout_ms = int(tm.group(1)) * 1000 if tm else 4000
memory_mb = int(mm.group(1)) if mm else 256
interactive = "interactive problem" in html.lower()
precision = extract_precision(html)
return {
"tests": tests,
"timeout_ms": timeout_ms,
"memory_mb": memory_mb,
"interactive": interactive,
"precision": precision,
}
def _pick_lang_option(select_body: str, language_id: str) -> str | None:
keywords = _LANG_KEYWORDS.get(language_id.lower(), [language_id.lower()])
options = [
(m.group(1), m.group(2).strip().lower())
for m in re.finditer(
r'<option\b[^>]*\bvalue=["\']([^"\']*)["\'][^>]*>([^<]+)',
select_body,
re.IGNORECASE,
)
]
for kw in keywords:
for val, text in options:
if kw in text:
return val
return None
def _parse_submit_form(
html: str, language_id: str
) -> tuple[str, dict[str, str], str | None]:
form_action = _AUTH_BASE + _SUBMIT_PATH
hidden: dict[str, str] = {}
lang_val: str | None = None
for form_m in re.finditer(
r'<form\b[^>]*action=["\']([^"\']+)["\'][^>]*>(.*?)</form>',
html,
re.DOTALL | re.IGNORECASE,
):
action, body = form_m.group(1), form_m.group(2)
if "sourcefile" not in body.lower():
continue
if action.startswith("http"):
form_action = action
elif action.startswith("/"):
form_action = _AUTH_BASE + action
else:
form_action = _AUTH_BASE + "/" + action
for input_m in re.finditer(
r'<input\b[^>]*\btype=["\']hidden["\'][^>]*/?>',
body,
re.IGNORECASE,
):
tag = input_m.group(0)
name_m = re.search(r'\bname=["\']([^"\']+)["\']', tag, re.IGNORECASE)
val_m = re.search(r'\bvalue=["\']([^"\']*)["\']', tag, re.IGNORECASE)
if name_m and val_m:
hidden[name_m.group(1)] = val_m.group(1)
for sel_m in re.finditer(
r'<select\b[^>]*\bname=["\']([^"\']+)["\'][^>]*>(.*?)</select>',
body,
re.DOTALL | re.IGNORECASE,
):
name, sel_body = sel_m.group(1), sel_m.group(2)
if "lang" in name.lower():
lang_val = _pick_lang_option(sel_body, language_id)
break
break
return form_action, hidden, lang_val
async def _load_usaco_cookies(client: httpx.AsyncClient) -> None:
if not _COOKIE_PATH.exists():
return
try:
for k, v in json.loads(_COOKIE_PATH.read_text()).items():
client.cookies.set(k, v)
except Exception:
pass
async def _save_usaco_cookies(client: httpx.AsyncClient) -> None:
cookies = {k: v for k, v in client.cookies.items()}
if cookies:
_COOKIE_PATH.parent.mkdir(parents=True, exist_ok=True)
_COOKIE_PATH.write_text(json.dumps(cookies))
async def _check_usaco_login(client: httpx.AsyncClient, username: str) -> bool:
try:
r = await client.get(
f"{_AUTH_BASE}/index.php",
headers=HEADERS,
timeout=HTTP_TIMEOUT,
)
text = r.text.lower()
return username.lower() in text or "logout" in text
except Exception:
return False
async def _do_usaco_login(
client: httpx.AsyncClient, username: str, password: str
) -> bool:
r = await client.post(
f"{_AUTH_BASE}{_LOGIN_PATH}",
data={"uname": username, "password": password},
headers=HEADERS,
timeout=HTTP_TIMEOUT,
)
r.raise_for_status()
try:
return r.json().get("code") == 1
except Exception:
return False
class USACOScraper(BaseScraper):
@property
def platform_name(self) -> str:
return "usaco"
async def scrape_contest_metadata(self, contest_id: str) -> MetadataResult:
try:
month_year, division = _parse_contest_id(contest_id)
if not division:
return self._metadata_error(
f"Invalid contest ID '{contest_id}'. "
"Expected format: <monthYY>_<division> (e.g. dec24_gold)"
)
slug = _results_page_slug(month_year)
async with httpx.AsyncClient() as client:
html = await _fetch_text(client, f"{BASE_URL}/index.php?page={slug}")
sections = _parse_results_page(html)
problems_raw = sections.get(division, [])
if not problems_raw:
return self._metadata_error(
f"No problems found for {contest_id} (division: {division})"
)
problems = [
ProblemSummary(id=cpid, name=name) for cpid, name in problems_raw
]
return MetadataResult(
success=True,
error="",
contest_id=contest_id,
problems=problems,
url=f"{BASE_URL}/index.php?page=viewproblem2&cpid=%s",
)
except Exception as e:
return self._metadata_error(str(e))
async def scrape_contest_list(self) -> ContestListResult:
try:
async with httpx.AsyncClient(
limits=httpx.Limits(max_connections=CONNECTIONS)
) as client:
html = await _fetch_text(client, f"{BASE_URL}/index.php?page=contests")
page_slugs: set[str] = set()
for m in RESULTS_PAGE_RE.finditer(html):
page_slugs.add(m.group(1))
recent_patterns = []
for year in range(15, 27):
for month in MONTHS:
recent_patterns.append(f"{month}{year:02d}results")
page_slugs.update(recent_patterns)
contests: list[ContestSummary] = []
sem = asyncio.Semaphore(CONNECTIONS)
async def check_page(slug: str) -> list[ContestSummary]:
async with sem:
try:
page_html = await _fetch_text(
client, f"{BASE_URL}/index.php?page={slug}"
)
except Exception:
return []
sections = _parse_results_page(page_html)
if not sections:
return []
month_year = slug.replace("results", "")
out: list[ContestSummary] = []
for div in sections:
cid = f"{month_year}_{div}"
year_m = re.search(r"\d{2,4}", month_year)
month_m = re.search(r"[a-z]+", month_year)
year_str = year_m.group() if year_m else ""
month_str = month_m.group().capitalize() if month_m else ""
if len(year_str) == 2:
year_str = f"20{year_str}"
display = (
f"USACO {year_str} {month_str} - {div.capitalize()}"
)
out.append(
ContestSummary(id=cid, name=cid, display_name=display)
)
return out
tasks = [check_page(slug) for slug in sorted(page_slugs)]
for coro in asyncio.as_completed(tasks):
contests.extend(await coro)
if not contests:
return ContestListResult(
success=False, error="No contests found", supports_countdown=False
)
return ContestListResult(
success=True, error="", contests=contests, supports_countdown=False
)
except Exception as e:
return ContestListResult(
success=False, error=str(e), supports_countdown=False
)
async def stream_tests_for_category_async(self, category_id: str) -> None:
month_year, division = _parse_contest_id(category_id)
if not division:
return
slug = _results_page_slug(month_year)
async with httpx.AsyncClient(
limits=httpx.Limits(max_connections=CONNECTIONS)
) as client:
try:
html = await _fetch_text(client, f"{BASE_URL}/index.php?page={slug}")
except Exception:
return
sections = _parse_results_page(html)
problems_raw = sections.get(division, [])
if not problems_raw:
return
sem = asyncio.Semaphore(CONNECTIONS)
async def run_one(cpid: str) -> dict[str, Any]:
async with sem:
try:
problem_html = await _fetch_text(
client,
f"{BASE_URL}/index.php?page=viewproblem2&cpid={cpid}",
)
info = _parse_problem_page(problem_html)
except Exception:
info = {
"tests": [],
"timeout_ms": 4000,
"memory_mb": 256,
"interactive": False,
"precision": None,
}
tests = cast(list[TestCase], info["tests"])
combined_input = "\n".join(t.input for t in tests) if tests else ""
combined_expected = (
"\n".join(t.expected for t in tests) if tests else ""
)
return {
"problem_id": cpid,
"combined": {
"input": combined_input,
"expected": combined_expected,
},
"tests": [
{"input": t.input, "expected": t.expected} for t in tests
],
"timeout_ms": info["timeout_ms"],
"memory_mb": info["memory_mb"],
"interactive": info["interactive"],
"multi_test": False,
"precision": info["precision"],
}
tasks = [run_one(cpid) for cpid, _ in problems_raw]
for coro in asyncio.as_completed(tasks):
payload = await coro
print(json.dumps(payload), flush=True)
async def submit(
self,
contest_id: str,
problem_id: str,
file_path: str,
language_id: str,
credentials: dict[str, str],
) -> SubmitResult:
source = Path(file_path).read_bytes()
username = credentials.get("username", "")
password = credentials.get("password", "")
if not username or not password:
return self._submit_error("Missing credentials. Use :CP usaco login")
async with httpx.AsyncClient(follow_redirects=True) as client:
await _load_usaco_cookies(client)
if not client.cookies:
print(json.dumps({"status": "logging_in"}), flush=True)
try:
ok = await _do_usaco_login(client, username, password)
except Exception as e:
return self._submit_error(f"Login failed: {e}")
if not ok:
return self._submit_error("Login failed (bad credentials?)")
await _save_usaco_cookies(client)
result = await self._do_submit(client, problem_id, language_id, source)
if result.success or result.error != "auth_failure":
return result
client.cookies.clear()
print(json.dumps({"status": "logging_in"}), flush=True)
try:
ok = await _do_usaco_login(client, username, password)
except Exception as e:
return self._submit_error(f"Login failed: {e}")
if not ok:
return self._submit_error("Login failed (bad credentials?)")
await _save_usaco_cookies(client)
return await self._do_submit(client, problem_id, language_id, source)
async def _do_submit(
self,
client: httpx.AsyncClient,
problem_id: str,
language_id: str,
source: bytes,
) -> SubmitResult:
print(json.dumps({"status": "submitting"}), flush=True)
try:
page_r = await client.get(
f"{_AUTH_BASE}/index.php?page=viewproblem2&cpid={problem_id}",
headers=HEADERS,
timeout=HTTP_TIMEOUT,
)
if "login" in page_r.url.path.lower() or "Login" in page_r.text[:2000]:
return self._submit_error("auth_failure")
form_url, hidden_fields, lang_val = _parse_submit_form(
page_r.text, language_id
)
except Exception:
form_url = _AUTH_BASE + _SUBMIT_PATH
hidden_fields = {}
lang_val = None
data: dict[str, str] = {"cpid": problem_id, **hidden_fields}
data["language"] = lang_val if lang_val is not None else language_id
ext = "py" if "python" in language_id.lower() else "cpp"
try:
r = await client.post(
form_url,
data=data,
files={"sourcefile": (f"solution.{ext}", source, "text/plain")},
headers=HEADERS,
timeout=HTTP_TIMEOUT,
)
r.raise_for_status()
except Exception as e:
return self._submit_error(f"Submit request failed: {e}")
try:
resp = r.json()
if resp.get("code") == 0 and "login" in resp.get("message", "").lower():
return self._submit_error("auth_failure")
sid = str(resp.get("submission_id", resp.get("id", "")))
except Exception:
sid = ""
return SubmitResult(
success=True, error="", submission_id=sid, verdict="submitted"
)
async def login(self, credentials: dict[str, str]) -> LoginResult:
username = credentials.get("username", "")
password = credentials.get("password", "")
if not username or not password:
return self._login_error("Missing username or password")
async with httpx.AsyncClient(follow_redirects=True) as client:
print(json.dumps({"status": "logging_in"}), flush=True)
try:
ok = await _do_usaco_login(client, username, password)
except Exception as e:
return self._login_error(f"Login request failed: {e}")
if not ok:
return self._login_error("Login failed (bad credentials?)")
await _save_usaco_cookies(client)
return LoginResult(
success=True,
error="",
credentials={"username": username, "password": password},
)
if __name__ == "__main__":
USACOScraper().run_cli()

View file

@ -1,13 +0,0 @@
#!/bin/sh
set -eu
nix develop --command stylua --check .
git ls-files '*.lua' | xargs nix develop --command selene --display-style quiet
nix develop --command prettier --check .
nix fmt
git diff --exit-code -- '*.nix'
nix develop --command lua-language-server --check . --checklevel=Warning
nix develop --command ruff format --check .
nix develop --command ruff check .
nix develop --command ty check .
nix develop --command python -m pytest tests/ -v

View file

@ -1,113 +0,0 @@
#!/usr/bin/env python3
import subprocess
import sys
def main() -> None:
argv = sys.argv[1:]
max_iterations = 1000
timeout = 10
positional: list[str] = []
i = 0
while i < len(argv):
if argv[i] == "--max-iterations" and i + 1 < len(argv):
max_iterations = int(argv[i + 1])
i += 2
elif argv[i] == "--timeout" and i + 1 < len(argv):
timeout = int(argv[i + 1])
i += 2
else:
positional.append(argv[i])
i += 1
if len(positional) != 3:
print(
"Usage: stress.py <generator> <brute> <candidate> "
"[--max-iterations N] [--timeout S]",
file=sys.stderr,
)
sys.exit(1)
generator, brute, candidate = positional
for iteration in range(1, max_iterations + 1):
try:
gen_result = subprocess.run(
generator,
capture_output=True,
text=True,
shell=True,
timeout=timeout,
)
except subprocess.TimeoutExpired:
print(
f"[stress] generator timed out on iteration {iteration}",
file=sys.stderr,
)
sys.exit(1)
if gen_result.returncode != 0:
print(
f"[stress] generator failed on iteration {iteration} "
f"(exit code {gen_result.returncode})",
file=sys.stderr,
)
if gen_result.stderr:
print(gen_result.stderr, file=sys.stderr, end="")
sys.exit(1)
test_input = gen_result.stdout
try:
brute_result = subprocess.run(
brute,
input=test_input,
capture_output=True,
text=True,
shell=True,
timeout=timeout,
)
except subprocess.TimeoutExpired:
print(f"[stress] brute timed out on iteration {iteration}", file=sys.stderr)
print(f"\n--- input ---\n{test_input}", end="")
sys.exit(1)
try:
cand_result = subprocess.run(
candidate,
input=test_input,
capture_output=True,
text=True,
shell=True,
timeout=timeout,
)
except subprocess.TimeoutExpired:
print(
f"[stress] candidate timed out on iteration {iteration}",
file=sys.stderr,
)
print(f"\n--- input ---\n{test_input}", end="")
sys.exit(1)
brute_out = brute_result.stdout.strip()
cand_out = cand_result.stdout.strip()
if brute_out != cand_out:
print(f"[stress] mismatch on iteration {iteration}", file=sys.stderr)
print(f"\n--- input ---\n{test_input}", end="")
print(f"\n--- expected (brute) ---\n{brute_out}")
print(f"\n--- actual (candidate) ---\n{cand_out}")
sys.exit(1)
print(f"[stress] iteration {iteration} OK", file=sys.stderr)
print(
f"[stress] all {max_iterations} iterations passed",
file=sys.stderr,
)
sys.exit(0)
if __name__ == "__main__":
main()

View file

@ -1,4 +1 @@
std = 'vim' std = 'vim'
[lints]
bad_string_escape = 'allow'

View file

@ -7,11 +7,10 @@ from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any from typing import Any
import re
import httpx import httpx
import pytest import pytest
import requests import requests
from scrapling import fetchers
ROOT = Path(__file__).resolve().parent.parent ROOT = Path(__file__).resolve().parent.parent
FIX = Path(__file__).resolve().parent / "fixtures" FIX = Path(__file__).resolve().parent / "fixtures"
@ -105,35 +104,6 @@ def run_scraper_offline(fixture_text):
raise AssertionError(f"No fixture for Codeforces url={url!r}") raise AssertionError(f"No fixture for Codeforces url={url!r}")
def _router_kattis(*, url: str) -> str:
url = url.removeprefix("https://open.kattis.com")
if "/contests?" in url:
return fixture_text("kattis/contests.html")
m = re.search(r"/contests/([^/]+)/problems", url)
if m:
try:
return fixture_text(f"kattis/contest_{m.group(1)}_problems.html")
except FileNotFoundError:
return "<html></html>"
if "/problems/" in url and "/file/statement" not in url:
slug = url.rstrip("/").split("/")[-1]
return fixture_text(f"kattis/problem_{slug}.html")
raise AssertionError(f"No fixture for Kattis url={url!r}")
def _router_usaco(*, url: str) -> str:
if "page=contests" in url and "results" not in url:
return fixture_text("usaco/contests.html")
m = re.search(r"page=([a-z]+\d{2,4}results)", url)
if m:
try:
return fixture_text(f"usaco/{m.group(1)}.html")
except FileNotFoundError:
return "<html></html>"
m = re.search(r"page=viewproblem2&cpid=(\d+)", url)
if m:
return fixture_text(f"usaco/problem_{m.group(1)}.html")
raise AssertionError(f"No fixture for USACO url={url!r}")
def _make_offline_fetches(scraper_name: str): def _make_offline_fetches(scraper_name: str):
match scraper_name: match scraper_name:
case "cses": case "cses":
@ -166,10 +136,12 @@ def run_scraper_offline(fixture_text):
case "codeforces": case "codeforces":
def _mock_fetch_problems_html(cid: str) -> str: class MockCodeForcesPage:
return _router_codeforces( def __init__(self, html: str):
url=f"https://codeforces.com/contest/{cid}/problems" self.html_content = html
)
def _mock_stealthy_fetch(url: str, **kwargs):
return MockCodeForcesPage(_router_codeforces(url=url))
def _mock_requests_get(url: str, **kwargs): def _mock_requests_get(url: str, **kwargs):
if "api/contest.list" in url: if "api/contest.list" in url:
@ -200,7 +172,7 @@ def run_scraper_offline(fixture_text):
raise AssertionError(f"Unexpected requests.get call: {url}") raise AssertionError(f"Unexpected requests.get call: {url}")
return { return {
"_fetch_problems_html": _mock_fetch_problems_html, "Fetcher.get": _mock_stealthy_fetch,
"requests.get": _mock_requests_get, "requests.get": _mock_requests_get,
} }
@ -221,9 +193,6 @@ def run_scraper_offline(fixture_text):
if "/api/list/contests/all" in url: if "/api/list/contests/all" in url:
data = json.loads(fixture_text("codechef/contests.json")) data = json.loads(fixture_text("codechef/contests.json"))
return MockResponse(data) return MockResponse(data)
if "/api/list/contests/past" in url:
data = json.loads(fixture_text("codechef/contests_past.json"))
return MockResponse(data)
if "/api/contests/START" in url and "/problems/" not in url: if "/api/contests/START" in url and "/problems/" not in url:
contest_id = url.rstrip("/").split("/")[-1] contest_id = url.rstrip("/").split("/")[-1]
try: try:
@ -243,39 +212,21 @@ def run_scraper_offline(fixture_text):
return MockResponse(data) return MockResponse(data)
raise AssertionError(f"No fixture for CodeChef url={url!r}") raise AssertionError(f"No fixture for CodeChef url={url!r}")
class MockCodeChefPage:
def __init__(self, html: str):
self.body = html
self.status = 200
def _mock_stealthy_fetch(url: str, **kwargs):
if "/problems/" in url:
problem_id = url.rstrip("/").split("/")[-1]
html = fixture_text(f"codechef/{problem_id}.html")
return MockCodeChefPage(html)
raise AssertionError(f"No fixture for CodeChef url={url!r}")
return { return {
"__offline_get_async": __offline_get_async, "__offline_get_async": __offline_get_async,
} "Fetcher.get": _mock_stealthy_fetch,
case "kattis":
async def __offline_get_kattis(client, url: str, **kwargs):
if "/file/statement/samples.zip" in url:
raise httpx.HTTPError("not found")
html = _router_kattis(url=url)
return SimpleNamespace(
text=html,
content=html.encode(),
status_code=200,
raise_for_status=lambda: None,
)
return {
"__offline_get_async": __offline_get_kattis,
}
case "usaco":
async def __offline_get_usaco(client, url: str, **kwargs):
html = _router_usaco(url=url)
return SimpleNamespace(
text=html,
status_code=200,
raise_for_status=lambda: None,
)
return {
"__offline_get_async": __offline_get_usaco,
} }
case _: case _:
@ -286,8 +237,6 @@ def run_scraper_offline(fixture_text):
"atcoder": "AtcoderScraper", "atcoder": "AtcoderScraper",
"codeforces": "CodeforcesScraper", "codeforces": "CodeforcesScraper",
"codechef": "CodeChefScraper", "codechef": "CodeChefScraper",
"kattis": "KattisScraper",
"usaco": "USACOScraper",
} }
def _run(scraper_name: str, mode: str, *args: str): def _run(scraper_name: str, mode: str, *args: str):
@ -296,15 +245,16 @@ def run_scraper_offline(fixture_text):
offline_fetches = _make_offline_fetches(scraper_name) offline_fetches = _make_offline_fetches(scraper_name)
if scraper_name == "codeforces": if scraper_name == "codeforces":
ns._fetch_problems_html = offline_fetches["_fetch_problems_html"] fetchers.Fetcher.get = offline_fetches["Fetcher.get"]
requests.get = offline_fetches["requests.get"] requests.get = offline_fetches["requests.get"]
elif scraper_name == "atcoder": elif scraper_name == "atcoder":
ns._fetch = offline_fetches["_fetch"] ns._fetch = offline_fetches["_fetch"]
ns._get_async = offline_fetches["_get_async"] ns._get_async = offline_fetches["_get_async"]
elif scraper_name == "cses": elif scraper_name == "cses":
httpx.AsyncClient.get = offline_fetches["__offline_fetch_text"] httpx.AsyncClient.get = offline_fetches["__offline_fetch_text"]
elif scraper_name in ("codechef", "kattis", "usaco"): elif scraper_name == "codechef":
httpx.AsyncClient.get = offline_fetches["__offline_get_async"] httpx.AsyncClient.get = offline_fetches["__offline_get_async"]
fetchers.Fetcher.get = offline_fetches["Fetcher.get"]
scraper_class = getattr(ns, scraper_classes[scraper_name]) scraper_class = getattr(ns, scraper_classes[scraper_name])
scraper = scraper_class() scraper = scraper_class()

View file

@ -1,16 +0,0 @@
{
"status": "success",
"message": "past contests list",
"contests": [
{
"contest_code": "START209D",
"contest_name": "Starters 209 Div 4",
"contest_start_date_iso": "2025-01-01T10:30:00+05:30"
},
{
"contest_code": "START208",
"contest_name": "Starters 208",
"contest_start_date_iso": "2024-12-25T10:30:00+05:30"
}
]
}

View file

@ -1,10 +0,0 @@
<html><body><table>
<tr>
<td>A</td>
<td><a href="/contests/open2024/problems/kth2024a">Arithmetic Sequence</a></td>
</tr>
<tr>
<td>B</td>
<td><a href="/contests/open2024/problems/kth2024b">Binary Tree</a></td>
</tr>
</table></body></html>

View file

@ -1,10 +0,0 @@
<html><body><table>
<tr>
<td><a href="/contests/open2024">Open 2024</a></td>
<td data-timestamp="1711800000">2024-03-30</td>
</tr>
<tr>
<td><a href="/contests/icpc2023">ICPC 2023</a></td>
<td data-timestamp="1698768000">2023-10-31</td>
</tr>
</table></body></html>

View file

@ -1,11 +0,0 @@
<html>
<head><title>Hello World</title></head>
<body>
<span>CPU Time limit</span><span class="num">1 second</span>
<span>Memory limit</span><span class="num">256 MB</span>
<table class="sample">
<pre>Hello World</pre>
<pre>Hello World</pre>
</table>
</body>
</html>

View file

@ -1,17 +0,0 @@
<html>
<head><title>Arithmetic Sequence</title></head>
<body>
<span>CPU Time limit</span><span class="num">2 seconds</span>
<span>Memory limit</span><span class="num">512 MB</span>
<table class="sample">
<pre>3
1 2 3</pre>
<pre>YES</pre>
</table>
<table class="sample">
<pre>2
1 3</pre>
<pre>NO</pre>
</table>
</body>
</html>

View file

@ -1,12 +0,0 @@
<html>
<head><title>Binary Tree</title></head>
<body>
<span>CPU Time limit</span><span class="num">1 second</span>
<span>Memory limit</span><span class="num">256 MB</span>
<table class="sample">
<pre>5
1 2 3 4 5</pre>
<pre>3</pre>
</table>
</body>
</html>

View file

@ -1,3 +0,0 @@
<html><body>
<a href="index.php?page=dec24results">December 2024 Results</a>
</body></html>

View file

@ -1,14 +0,0 @@
<html><body>
<h2>USACO 2024 December Contest, Gold</h2>
<b>Farmer John's Favorite Problem</b><br/>
<a href="index.php?page=viewproblem2&cpid=1469">View Problem</a>
<b>Binary Indexed Tree</b><br/>
<a href="index.php?page=viewproblem2&cpid=1470">View Problem</a>
<b>Counting Subsequences</b><br/>
<a href="index.php?page=viewproblem2&cpid=1471">View Problem</a>
</body></html>

View file

@ -1,10 +0,0 @@
<html><body>
<p>Time limit: 4s. Memory limit: 256 MB.</p>
<p>Given N cows, find the answer.</p>
<pre class="in">3
1 2 3</pre>
<pre class="out">6</pre>
<pre class="in">1
5</pre>
<pre class="out">5</pre>
</body></html>

View file

@ -1,7 +0,0 @@
<html><body>
<p>Time limit: 2s. Memory limit: 512 MB.</p>
<p>Build a binary indexed tree.</p>
<pre class="in">4
1 3 2 4</pre>
<pre class="out">10</pre>
</body></html>

View file

@ -1,7 +0,0 @@
<html><body>
<p>Time limit: 4s. Memory limit: 256 MB.</p>
<p>Output the answer with absolute error at most 10^{-6}.</p>
<pre class="in">2
1 2</pre>
<pre class="out">1.500000</pre>
</body></html>

View file

@ -1,12 +1,16 @@
import pytest import pytest
from scrapers.language_ids import LANGUAGE_IDS
from scrapers.models import ( from scrapers.models import (
ContestListResult, ContestListResult,
MetadataResult, MetadataResult,
TestsResult, TestsResult,
) )
MODEL_FOR_MODE = {
"metadata": MetadataResult,
"contests": ContestListResult,
}
MATRIX = { MATRIX = {
"cses": { "cses": {
"metadata": ("introductory_problems",), "metadata": ("introductory_problems",),
@ -28,16 +32,6 @@ MATRIX = {
"tests": ("START209D",), "tests": ("START209D",),
"contests": tuple(), "contests": tuple(),
}, },
"kattis": {
"metadata": ("hello",),
"tests": ("hello",),
"contests": tuple(),
},
"usaco": {
"metadata": ("dec24_gold",),
"tests": ("dec24_gold",),
"contests": tuple(),
},
} }
@ -49,16 +43,17 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode):
assert rc in (0, 1), f"Bad exit code {rc}" assert rc in (0, 1), f"Bad exit code {rc}"
assert objs, f"No JSON output for {scraper}:{mode}" assert objs, f"No JSON output for {scraper}:{mode}"
if mode == "metadata": if mode in ("metadata", "contests"):
model = MetadataResult.model_validate(objs[-1]) Model = MODEL_FOR_MODE[mode]
model = Model.model_validate(objs[-1])
assert model is not None
assert model.success is True assert model.success is True
assert model.url if mode == "metadata":
assert len(model.problems) >= 1 assert model.url
assert all(isinstance(p.id, str) and p.id for p in model.problems) assert len(model.problems) >= 1
elif mode == "contests": assert all(isinstance(p.id, str) and p.id for p in model.problems)
model = ContestListResult.model_validate(objs[-1]) else:
assert model.success is True assert len(model.contests) >= 1
assert len(model.contests) >= 1
else: else:
assert len(objs) >= 1, "No test objects returned" assert len(objs) >= 1, "No test objects returned"
validated_any = False validated_any = False
@ -96,61 +91,5 @@ def test_scraper_offline_fixture_matrix(run_scraper_offline, scraper, mode):
) )
assert "multi_test" in obj, "Missing multi_test field in raw JSON" assert "multi_test" in obj, "Missing multi_test field in raw JSON"
assert isinstance(obj["multi_test"], bool), "multi_test not boolean" assert isinstance(obj["multi_test"], bool), "multi_test not boolean"
assert "precision" in obj, "Missing precision field in raw JSON"
assert obj["precision"] is None or isinstance(
obj["precision"], float
), "precision must be None or float"
validated_any = True validated_any = True
assert validated_any, "No valid tests payloads validated" assert validated_any, "No valid tests payloads validated"
def test_kattis_contest_metadata(run_scraper_offline):
rc, objs = run_scraper_offline("kattis", "metadata", "open2024")
assert rc == 0
assert objs
model = MetadataResult.model_validate(objs[-1])
assert model.success is True
assert len(model.problems) == 2
assert model.contest_url != ""
assert model.standings_url != ""
def test_usaco_precision_extracted(run_scraper_offline):
rc, objs = run_scraper_offline("usaco", "tests", "dec24_gold")
assert rc == 0
precisions = [obj["precision"] for obj in objs if "problem_id" in obj]
assert any(p is not None for p in precisions), (
"Expected at least one problem with precision"
)
@pytest.mark.parametrize(
"scraper,contest_id",
[
("cses", "nonexistent_category_xyz"),
("usaco", "badformat"),
("kattis", "nonexistent_problem_xyz"),
],
)
def test_scraper_metadata_error(run_scraper_offline, scraper, contest_id):
rc, objs = run_scraper_offline(scraper, "metadata", contest_id)
assert rc == 1
assert objs
assert objs[-1].get("success") is False
assert objs[-1].get("error")
def test_language_ids_coverage():
expected_platforms = {
"atcoder",
"codeforces",
"cses",
"usaco",
"kattis",
"codechef",
}
assert set(LANGUAGE_IDS.keys()) == expected_platforms
for platform, langs in LANGUAGE_IDS.items():
assert {"cpp", "python"} <= set(langs.keys()), f"{platform} missing cpp/python"
for lang, lid in langs.items():
assert isinstance(lid, str) and lid, f"{platform}/{lang} empty ID"

1174
uv.lock generated

File diff suppressed because it is too large Load diff

30
vim.toml Normal file
View file

@ -0,0 +1,30 @@
[selene]
base = "lua51"
name = "vim"
[vim]
any = true
[jit]
any = true
[assert]
any = true
[describe]
any = true
[it]
any = true
[before_each]
any = true
[after_each]
any = true
[spy]
any = true
[stub]
any = true

View file

@ -1,26 +0,0 @@
---
base: lua51
name: vim
lua_versions:
- luajit
globals:
vim:
any: true
jit:
any: true
assert:
any: true
describe:
any: true
it:
any: true
before_each:
any: true
after_each:
any: true
spy:
any: true
stub:
any: true
bit:
any: true