diff --git a/lua/cp/scraper.lua b/lua/cp/scraper.lua index 7ddca0a..fc8ba69 100644 --- a/lua/cp/scraper.lua +++ b/lua/cp/scraper.lua @@ -298,9 +298,12 @@ function M.submit( stdin = source_code, env_extra = { CP_CREDENTIALS = vim.json.encode(credentials) }, on_event = function(ev) + if ev.credentials ~= nil then + require('cp.cache').set_credentials(platform, ev.credentials) + end if ev.status ~= nil then if type(on_status) == 'function' then - on_status(ev.status) + on_status(ev) end elseif ev.success ~= nil then done = true diff --git a/lua/cp/submit.lua b/lua/cp/submit.lua index 4efe25e..7dc9a71 100644 --- a/lua/cp/submit.lua +++ b/lua/cp/submit.lua @@ -25,6 +25,7 @@ local function prompt_credentials(platform, callback) vim.fn.inputsave() local password = vim.fn.inputsecret(platform .. ' password: ') vim.fn.inputrestore() + vim.cmd.redraw() if not password or password == '' then logger.log('Submit cancelled', vim.log.levels.WARN) return @@ -64,9 +65,9 @@ function M.submit(opts) language, source_code, creds, - function(status) + function(ev) vim.schedule(function() - vim.notify('[cp.nvim] ' .. (STATUS_MSGS[status] or status), vim.log.levels.INFO) + vim.notify('[cp.nvim] ' .. (STATUS_MSGS[ev.status] or ev.status), vim.log.levels.INFO) end) end, function(result) diff --git a/scrapers/cses.py b/scrapers/cses.py index 473558f..b2e845a 100644 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import asyncio +import base64 import json import re from typing import Any @@ -18,6 +19,8 @@ from .models import ( ) BASE_URL = "https://cses.fi" +API_URL = "https://cses.fi/api" +SUBMIT_SCOPE = "courses/problemset" INDEX_PATH = "/problemset" TASK_PATH = "/problemset/task/{id}" HEADERS = { @@ -26,6 +29,16 @@ HEADERS = { TIMEOUT_S = 15.0 CONNECTIONS = 8 +CSES_LANGUAGES: dict[str, dict[str, str]] = { + "C++17": {"name": "C++", "option": "C++17"}, + "Python3": {"name": "Python", "option": "CPython3"}, +} + +EXTENSIONS: dict[str, str] = { + "C++17": "cpp", + "Python3": "py", +} + def normalize_category_name(category_name: str) -> str: return category_name.lower().replace(" ", "_").replace("&", "and") @@ -270,6 +283,65 @@ class CSESScraper(BaseScraper): payload = await coro print(json.dumps(payload), flush=True) + async def _web_login( + self, + client: httpx.AsyncClient, + username: str, + password: str, + ) -> str | None: + login_page = await client.get( + f"{BASE_URL}/login", headers=HEADERS, timeout=TIMEOUT_S + ) + csrf_match = re.search(r'name="csrf_token" value="([^"]+)"', login_page.text) + if not csrf_match: + return None + + login_resp = await client.post( + f"{BASE_URL}/login", + data={ + "csrf_token": csrf_match.group(1), + "nick": username, + "pass": password, + }, + headers=HEADERS, + timeout=TIMEOUT_S, + ) + + if "Invalid username or password" in login_resp.text: + return None + + api_resp = await client.post( + f"{API_URL}/login", headers=HEADERS, timeout=TIMEOUT_S + ) + api_data = api_resp.json() + token: str = api_data["X-Auth-Token"] + auth_url: str = api_data["authentication_url"] + + auth_page = await client.get(auth_url, headers=HEADERS, timeout=TIMEOUT_S) + auth_csrf = re.search(r'name="csrf_token" value="([^"]+)"', auth_page.text) + form_token = re.search(r'name="token" value="([^"]+)"', auth_page.text) + if not auth_csrf or not form_token: + return None + + await client.post( + auth_url, + data={ + "csrf_token": auth_csrf.group(1), + "token": form_token.group(1), + }, + headers=HEADERS, + timeout=TIMEOUT_S, + ) + + check = await client.get( + f"{API_URL}/login", + headers={"X-Auth-Token": token, **HEADERS}, + timeout=TIMEOUT_S, + ) + if check.status_code != 200: + return None + return token + async def submit( self, contest_id: str, @@ -278,12 +350,83 @@ class CSESScraper(BaseScraper): language_id: str, credentials: dict[str, str], ) -> SubmitResult: - return SubmitResult( - success=False, - error="CSES submit not yet implemented", - submission_id="", - verdict="", - ) + username = credentials.get("username", "") + password = credentials.get("password", "") + if not username or not password: + return self._submit_error("Missing credentials. Use :CP login cses") + + async with httpx.AsyncClient(follow_redirects=True) as client: + print(json.dumps({"status": "logging_in"}), flush=True) + + token = await self._web_login(client, username, password) + if not token: + return self._submit_error("Login failed (bad credentials?)") + + print(json.dumps({"status": "submitting"}), flush=True) + + ext = EXTENSIONS.get(language_id, "cpp") + lang = CSES_LANGUAGES.get(language_id, {}) + content_b64 = base64.b64encode(source_code.encode()).decode() + + payload: dict[str, Any] = { + "language": lang, + "filename": f"{problem_id}.{ext}", + "content": content_b64, + } + + r = await client.post( + f"{API_URL}/{SUBMIT_SCOPE}/submissions", + json=payload, + params={"task": problem_id}, + headers={ + "X-Auth-Token": token, + "Content-Type": "application/json", + **HEADERS, + }, + timeout=TIMEOUT_S, + ) + + if r.status_code not in range(200, 300): + try: + err = r.json().get("message", r.text) + except Exception: + err = r.text + return self._submit_error(f"Submit request failed: {err}") + + info = r.json() + submission_id = str(info.get("id", "")) + + for _ in range(60): + await asyncio.sleep(2) + try: + r = await client.get( + f"{API_URL}/{SUBMIT_SCOPE}/submissions/{submission_id}", + params={"poll": "true"}, + headers={ + "X-Auth-Token": token, + **HEADERS, + }, + timeout=30.0, + ) + if r.status_code == 200: + info = r.json() + if not info.get("pending", True): + verdict = info.get("result", "unknown") + return SubmitResult( + success=True, + error="", + submission_id=submission_id, + verdict=verdict, + ) + except Exception: + pass + + return SubmitResult( + success=True, + error="", + submission_id=submission_id, + verdict="submitted (poll timed out)", + ) if __name__ == "__main__": diff --git a/t/1068.cc b/t/1068.cc new file mode 100644 index 0000000..5d3fe37 --- /dev/null +++ b/t/1068.cc @@ -0,0 +1,54 @@ +#include // {{{ + +#include +#ifdef __cpp_lib_ranges_enumerate +#include +namespace rv = std::views; +namespace rs = std::ranges; +#endif + +#pragma GCC optimize("O2,unroll-loops") +#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt") + +using namespace std; + +using i32 = int32_t; +using u32 = uint32_t; +using i64 = int64_t; +using u64 = uint64_t; +using f64 = double; +using f128 = long double; + +#if __cplusplus >= 202002L +template +constexpr T MIN = std::numeric_limits::min(); + +template +constexpr T MAX = std::numeric_limits::max(); +#endif + +#ifdef LOCAL +#define db(...) std::print(__VA_ARGS__) +#define dbln(...) std::println(__VA_ARGS__) +#else +#define db(...) +#define dbln(...) +#endif +// }}} + +void solve() { + cout << "hi\n"; +} + +int main() { // {{{ + std::cin.exceptions(std::cin.failbit); +#ifdef LOCAL + std::cerr.rdbuf(std::cout.rdbuf()); + std::cout.setf(std::ios::unitbuf); + std::cerr.setf(std::ios::unitbuf); +#else + std::cin.tie(nullptr)->sync_with_stdio(false); +#endif + solve(); + return 0; +} // }}}