diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml deleted file mode 100644 index 5742799..0000000 --- a/.github/ISSUE_TEMPLATE/bug_report.yaml +++ /dev/null @@ -1,78 +0,0 @@ -name: Bug Report -description: Report a bug -title: 'bug: ' -labels: [bug] -body: - - type: checkboxes - attributes: - label: Prerequisites - options: - - label: - I have searched [existing - issues](https://github.com/barrettruth/cp.nvim/issues) - required: true - - label: I have updated to the latest version - required: true - - - type: textarea - attributes: - label: 'Neovim version' - description: 'Output of `nvim --version`' - render: text - validations: - required: true - - - type: input - attributes: - label: 'Operating system' - placeholder: 'e.g. Arch Linux, macOS 15, Ubuntu 24.04' - validations: - required: true - - - type: textarea - attributes: - label: Description - description: What happened? What did you expect? - validations: - required: true - - - type: textarea - attributes: - label: Steps to reproduce - description: Minimal steps to trigger the bug - value: | - 1. - 2. - 3. - validations: - required: true - - - type: textarea - attributes: - label: 'Health check' - description: 'Output of `:checkhealth cp`' - render: text - - - type: textarea - attributes: - label: Minimal reproduction - description: | - Save the script below as `repro.lua`, edit if needed, and run: - ``` - nvim -u repro.lua - ``` - Confirm the bug reproduces with this config before submitting. - render: lua - value: | - vim.env.LAZY_STDPATH = '.repro' - load(vim.fn.system('curl -s https://raw.githubusercontent.com/folke/lazy.nvim/main/bootstrap.lua'))() - require('lazy.nvim').setup({ - spec = { - { - 'barrett-ruth/cp.nvim', - opts = {}, - }, - }, - }) - validations: - required: true diff --git a/.github/ISSUE_TEMPLATE/config.yaml b/.github/ISSUE_TEMPLATE/config.yaml deleted file mode 100644 index 12ef1b0..0000000 --- a/.github/ISSUE_TEMPLATE/config.yaml +++ /dev/null @@ -1,5 +0,0 @@ -blank_issues_enabled: false -contact_links: - - name: Questions - url: https://github.com/barrettruth/cp.nvim/discussions - about: Ask questions and discuss ideas diff --git a/.github/ISSUE_TEMPLATE/feature_request.yaml b/.github/ISSUE_TEMPLATE/feature_request.yaml deleted file mode 100644 index 39c6692..0000000 --- a/.github/ISSUE_TEMPLATE/feature_request.yaml +++ /dev/null @@ -1,30 +0,0 @@ -name: Feature Request -description: Suggest a feature -title: 'feat: ' -labels: [enhancement] -body: - - type: checkboxes - attributes: - label: Prerequisites - options: - - label: - I have searched [existing - issues](https://github.com/barrettruth/cp.nvim/issues) - required: true - - - type: textarea - attributes: - label: Problem - description: What problem does this solve? - validations: - required: true - - - type: textarea - attributes: - label: Proposed solution - validations: - required: true - - - type: textarea - attributes: - label: Alternatives considered diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml deleted file mode 100644 index bf6bcea..0000000 --- a/.github/workflows/ci.yaml +++ /dev/null @@ -1,112 +0,0 @@ -name: ci -on: - workflow_call: - pull_request: - branches: [main] - push: - branches: [main] - -jobs: - changes: - runs-on: ubuntu-latest - outputs: - lua: ${{ steps.changes.outputs.lua }} - python: ${{ steps.changes.outputs.python }} - steps: - - uses: actions/checkout@v4 - - uses: dorny/paths-filter@v3 - id: changes - with: - filters: | - lua: - - 'lua/**' - - 'spec/**' - - 'plugin/**' - - 'after/**' - - 'ftdetect/**' - - '*.lua' - - '.luarc.json' - - 'stylua.toml' - - 'selene.toml' - python: - - 'scripts/**' - - 'scrapers/**' - - 'tests/**' - - 'pyproject.toml' - - 'uv.lock' - - lua-format: - runs-on: ubuntu-latest - needs: changes - if: ${{ needs.changes.outputs.lua == 'true' }} - steps: - - uses: actions/checkout@v4 - - uses: JohnnyMorganz/stylua-action@v4 - with: - token: ${{ secrets.GITHUB_TOKEN }} - version: 2.1.0 - args: --check . - - lua-lint: - runs-on: ubuntu-latest - needs: changes - if: ${{ needs.changes.outputs.lua == 'true' }} - steps: - - uses: actions/checkout@v4 - - uses: NTBBloodbath/selene-action@v1.0.0 - with: - token: ${{ secrets.GITHUB_TOKEN }} - args: --display-style quiet . - - lua-typecheck: - runs-on: ubuntu-latest - needs: changes - if: ${{ needs.changes.outputs.lua == 'true' }} - steps: - - uses: actions/checkout@v4 - - uses: mrcjkb/lua-typecheck-action@v0 - with: - checklevel: Warning - directories: lua - configpath: .luarc.json - - python-format: - runs-on: ubuntu-latest - needs: changes - if: ${{ needs.changes.outputs.python == 'true' }} - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/setup-uv@v4 - - run: uv tool install ruff - - run: ruff format --check . - - python-lint: - runs-on: ubuntu-latest - needs: changes - if: ${{ needs.changes.outputs.python == 'true' }} - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/setup-uv@v4 - - run: uv tool install ruff - - run: ruff check . - - python-typecheck: - runs-on: ubuntu-latest - needs: changes - if: ${{ needs.changes.outputs.python == 'true' }} - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/setup-uv@v4 - - run: uv sync --dev - - run: uvx ty check . - - python-test: - runs-on: ubuntu-latest - needs: changes - if: ${{ needs.changes.outputs.python == 'true' }} - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/setup-uv@v4 - - run: uv sync --dev - - run: uv run camoufox fetch - - run: uv run pytest tests/ -v diff --git a/.github/workflows/luarocks.yaml b/.github/workflows/luarocks.yaml index 5be8b55..c64568f 100644 --- a/.github/workflows/luarocks.yaml +++ b/.github/workflows/luarocks.yaml @@ -1,21 +1,18 @@ -name: luarocks +name: Release on: push: tags: - - 'v*' + - '*' + workflow_dispatch: jobs: - ci: - uses: ./.github/workflows/ci.yaml - - publish: - needs: ci + publish-luarocks: + name: Publish to LuaRocks runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 - - - uses: nvim-neorocks/luarocks-tag-release@v7 + - name: Publish to LuaRocks + uses: nvim-neorocks/luarocks-tag-release@v7 env: LUAROCKS_API_KEY: ${{ secrets.LUAROCKS_API_KEY }} diff --git a/.github/workflows/quality.yaml b/.github/workflows/quality.yaml index 731e74b..f6f27bc 100644 --- a/.github/workflows/quality.yaml +++ b/.github/workflows/quality.yaml @@ -1,4 +1,4 @@ -name: quality +name: Code Quality on: pull_request: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index ed7be0c..4c1cc1f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,4 +1,4 @@ -name: tests +name: Tests on: pull_request: diff --git a/.gitignore b/.gitignore index 45bc345..f383808 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,11 @@ -.venv -venv +.venv/ doc/tags *.log build io debug -create - - -.*cache* +venv/ CLAUDE.md __pycache__ .claude/ - node_modules/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 971ffbe..49fe046 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,6 @@ repos: hooks: - id: prettier name: prettier - files: \.(md|toml|ya?ml|sh)$ - repo: local hooks: diff --git a/README.md b/README.md index cf82417..7d439f9 100644 --- a/README.md +++ b/README.md @@ -19,15 +19,6 @@ https://github.com/user-attachments/assets/e81d8dfb-578f-4a79-9989-210164fc0148 - **Language agnostic**: Works with any language - **Diff viewer**: Compare expected vs actual output with 3 diff modes -## Installation - -Install using your package manager of choice or via -[luarocks](https://luarocks.org/modules/barrettruth/cp.nvim): - -``` -luarocks install cp.nvim -``` - ## Optional Dependencies - [uv](https://docs.astral.sh/uv/) for problem scraping diff --git a/cp.nvim-scm-1.rockspec b/cp.nvim-scm-1.rockspec index 8152b7b..e38d924 100644 --- a/cp.nvim-scm-1.rockspec +++ b/cp.nvim-scm-1.rockspec @@ -2,7 +2,7 @@ rockspec_format = '3.0' package = 'cp.nvim' version = 'scm-1' -source = { url = 'git://github.com/barrettruth/cp.nvim' } +source = { url = 'git://github.com/barrett-ruth/cp.nvim' } build = { type = 'builtin' } test_dependencies = { diff --git a/doc/cp.nvim.txt b/doc/cp.nvim.txt index d6d1d73..f99b07b 100644 --- a/doc/cp.nvim.txt +++ b/doc/cp.nvim.txt @@ -205,66 +205,71 @@ Debug Builds ~ ============================================================================== CONFIGURATION *cp-config* -Configuration is done via `vim.g.cp_config`. Set this before using the plugin: +Here's an example configuration with lazy.nvim: >lua - vim.g.cp_config = { - languages = { - cpp = { - extension = 'cc', - commands = { - build = { 'g++', '-std=c++17', '{source}', '-o', '{binary}', - '-fdiagnostics-color=always' }, - run = { '{binary}' }, - debug = { 'g++', '-std=c++17', '-fsanitize=address,undefined', - '{source}', '-o', '{binary}' }, + { + 'barrett-ruth/cp.nvim', + cmd = 'CP', + build = 'uv sync', + opts = { + languages = { + cpp = { + extension = 'cc', + commands = { + build = { 'g++', '-std=c++17', '{source}', '-o', '{binary}', + '-fdiagnostics-color=always' }, + run = { '{binary}' }, + debug = { 'g++', '-std=c++17', '-fsanitize=address,undefined', + '{source}', '-o', '{binary}' }, + }, + }, + python = { + extension = 'py', + commands = { + run = { 'python', '{source}' }, + debug = { 'python', '{source}' }, + }, }, }, - python = { - extension = 'py', - commands = { - run = { 'python', '{source}' }, - debug = { 'python', '{source}' }, + platforms = { + cses = { + enabled_languages = { 'cpp', 'python' }, + default_language = 'cpp', + overrides = { + cpp = { extension = 'cpp', commands = { build = { ... } } } + }, + }, + atcoder = { + enabled_languages = { 'cpp', 'python' }, + default_language = 'cpp', + }, + codeforces = { + enabled_languages = { 'cpp', 'python' }, + default_language = 'cpp', }, }, - }, - platforms = { - cses = { - enabled_languages = { 'cpp', 'python' }, - default_language = 'cpp', - overrides = { - cpp = { extension = 'cpp', commands = { build = { ... } } } + open_url = true, + debug = false, + ui = { + ansi = true, + run = { + width = 0.3, + next_test_key = '', -- or nil to disable + prev_test_key = '', -- or nil to disable }, - }, - atcoder = { - enabled_languages = { 'cpp', 'python' }, - default_language = 'cpp', - }, - codeforces = { - enabled_languages = { 'cpp', 'python' }, - default_language = 'cpp', - }, - }, - open_url = true, - debug = false, - ui = { - ansi = true, - run = { - width = 0.3, - next_test_key = '', -- or nil to disable - prev_test_key = '', -- or nil to disable - }, - panel = { - diff_modes = { 'side-by-side', 'git', 'vim' }, - max_output_lines = 50, - }, - diff = { - git = { - args = { 'diff', '--no-index', '--word-diff=plain', - '--word-diff-regex=.', '--no-prefix' }, + panel = { + diff_mode = 'vim', + max_output_lines = 50, }, + diff = { + git = { + args = { 'diff', '--no-index', '--word-diff=plain', + '--word-diff-regex=.', '--no-prefix' }, + }, + }, + picker = 'telescope', }, - picker = 'telescope', - }, + } } < @@ -274,7 +279,7 @@ the default; per-platform overrides can tweak 'extension' or 'commands'. For example, to run CodeForces contests with Python by default: >lua - vim.g.cp_config = { + { platforms = { codeforces = { default_language = 'python', @@ -285,7 +290,7 @@ For example, to run CodeForces contests with Python by default: Any language is supported provided the proper configuration. For example, to run CSES problems with Rust using the single schema: >lua - vim.g.cp_config = { + { languages = { rust = { extension = 'rs', @@ -373,10 +378,8 @@ run CSES problems with Rust using the single schema: *cp.PanelConfig* Fields: ~ - {diff_modes} (string[], default: {'side-by-side', 'git', 'vim'}) - List of diff modes to cycle through with 't' key. - First element is the default mode. - Valid modes: 'side-by-side', 'git', 'vim'. + {diff_mode} (string, default: "none") Diff backend: "none", + "vim", or "git". {max_output_lines} (number, default: 50) Maximum lines of test output. *cp.DiffConfig* @@ -781,15 +784,12 @@ HIGHLIGHT GROUPS *cp-highlights* Test Status Groups ~ -All test status groups link to builtin highlight groups, automatically adapting -to your colorscheme: - - CpTestAC Links to DiagnosticOk (AC status) - CpTestWA Links to DiagnosticError (WA status) - CpTestTLE Links to DiagnosticWarn (TLE status) - CpTestMLE Links to DiagnosticWarn (MLE status) - CpTestRTE Links to DiagnosticHint (RTE status) - CpTestNA Links to Comment (pending/unknown status) + CpTestAC Green foreground for AC status + CpTestWA Red foreground for WA status + CpTestTLE Orange foreground for TLE status + CpTestMLE Orange foreground for MLE status + CpTestRTE Purple foreground for RTE status + CpTestNA Gray foreground for remaining state ANSI Color Groups ~ @@ -848,20 +848,17 @@ PANEL KEYMAPS *cp-panel-keys* Navigate to next test case Navigate to previous test case -t Cycle through configured diff modes (see |cp.PanelConfig|) +t Cycle through diff modes: none → git → vim q Exit panel and restore layout Exit interactive terminal and restore layout Diff Modes ~ -Three diff modes are available: +Three diff backends are available: - side-by-side Expected and actual output shown side-by-side (default) - vim Built-in vim diff (always available) - git Character-level git word-diff (requires git, more precise) - -Configure which modes to cycle through via |cp.PanelConfig|.diff_modes. -The first element is used as the default mode. + none Nothing + vim Built-in vim diff (default, always available) + git Character-level git word-diff (requires git, more precise) The git backend shows character-level changes with [-removed-] and {+added+} markers. diff --git a/lua/cp/config.lua b/lua/cp/config.lua index dec8878..78f321f 100644 --- a/lua/cp/config.lua +++ b/lua/cp/config.lua @@ -18,7 +18,7 @@ ---@field overrides? table ---@class PanelConfig ----@field diff_modes string[] +---@field diff_mode "none"|"vim"|"git" ---@field max_output_lines integer ---@class DiffGitConfig @@ -173,7 +173,7 @@ M.defaults = { add_test_key = 'ga', save_and_exit_key = 'q', }, - panel = { diff_modes = { 'side-by-side', 'git', 'vim' }, max_output_lines = 50 }, + panel = { diff_mode = 'none', max_output_lines = 50 }, diff = { git = { args = { 'diff', '--no-index', '--word-diff=plain', '--word-diff-regex=.', '--no-prefix' }, @@ -305,24 +305,7 @@ function M.setup(user_config) vim.validate({ hooks = { cfg.hooks, { 'table' } }, ui = { cfg.ui, { 'table' } }, - debug = { cfg.debug, { 'boolean', 'nil' }, true }, open_url = { cfg.open_url, { 'boolean', 'nil' }, true }, - filename = { cfg.filename, { 'function', 'nil' }, true }, - scrapers = { - cfg.scrapers, - function(v) - if type(v) ~= 'table' then - return false - end - for _, s in ipairs(v) do - if not vim.tbl_contains(constants.PLATFORMS, s) then - return false - end - end - return true - end, - ('one of {%s}'):format(table.concat(constants.PLATFORMS, ',')), - }, before_run = { cfg.hooks.before_run, { 'function', 'nil' }, true }, before_debug = { cfg.hooks.before_debug, { 'function', 'nil' }, true }, setup_code = { cfg.hooks.setup_code, { 'function', 'nil' }, true }, @@ -330,23 +313,14 @@ function M.setup(user_config) setup_io_output = { cfg.hooks.setup_io_output, { 'function', 'nil' }, true }, }) - local layouts = require('cp.ui.layouts') vim.validate({ ansi = { cfg.ui.ansi, 'boolean' }, - diff_modes = { - cfg.ui.panel.diff_modes, + diff_mode = { + cfg.ui.panel.diff_mode, function(v) - if type(v) ~= 'table' then - return false - end - for _, mode in ipairs(v) do - if not layouts.DIFF_MODES[mode] then - return false - end - end - return true + return vim.tbl_contains({ 'none', 'vim', 'git' }, v) end, - ('one of {%s}'):format(table.concat(vim.tbl_keys(layouts.DIFF_MODES), ',')), + "diff_mode must be 'none', 'vim', or 'git'", }, max_output_lines = { cfg.ui.panel.max_output_lines, @@ -356,14 +330,6 @@ function M.setup(user_config) 'positive integer', }, git = { cfg.ui.diff.git, { 'table' } }, - git_args = { cfg.ui.diff.git.args, is_string_list, 'string[]' }, - width = { - cfg.ui.run.width, - function(v) - return type(v) == 'number' and v > 0 and v <= 1 - end, - 'decimal between 0 and 1', - }, next_test_key = { cfg.ui.run.next_test_key, function(v) @@ -417,13 +383,6 @@ function M.setup(user_config) end, 'nil or non-empty string', }, - picker = { - cfg.ui.picker, - function(v) - return v == nil or v == 'telescope' or v == 'fzf-lua' - end, - "nil, 'telescope', or 'fzf-lua'", - }, }) for id, lang in pairs(cfg.languages) do @@ -484,18 +443,7 @@ function M.get_language_for_platform(platform_id, language_id) } end - local platform_effective = cfg.runtime.effective[platform_id] - if not platform_effective then - return { - valid = false, - error = string.format( - 'No runtime config for platform %s (plugin not initialized)', - platform_id - ), - } - end - - local effective = platform_effective[language_id] + local effective = cfg.runtime.effective[platform_id][language_id] if not effective then return { valid = false, diff --git a/lua/cp/init.lua b/lua/cp/init.lua index fac3044..64a997d 100644 --- a/lua/cp/init.lua +++ b/lua/cp/init.lua @@ -11,25 +11,25 @@ if vim.fn.has('nvim-0.10.0') == 0 then return {} end +local user_config = {} +local config = nil local initialized = false -local function ensure_initialized() - if initialized then - return - end - local user_config = vim.g.cp_config or {} - local config = config_module.setup(user_config) - config_module.set_current_config(config) - initialized = true -end - ---@return nil function M.handle_command(opts) - ensure_initialized() local commands = require('cp.commands') commands.handle_command(opts) end +function M.setup(opts) + opts = opts or {} + user_config = opts + config = config_module.setup(user_config) + config_module.set_current_config(config) + + initialized = true +end + function M.is_initialized() return initialized end diff --git a/lua/cp/runner/execute.lua b/lua/cp/runner/execute.lua index 76d055a..c1c141e 100644 --- a/lua/cp/runner/execute.lua +++ b/lua/cp/runner/execute.lua @@ -39,27 +39,24 @@ end ---@param compile_cmd string[] ---@param substitutions SubstitutableCommand ----@param on_complete fun(r: {code: integer, stdout: string}) -function M.compile(compile_cmd, substitutions, on_complete) +function M.compile(compile_cmd, substitutions) local cmd = substitute_template(compile_cmd, substitutions) local sh = table.concat(cmd, ' ') .. ' 2>&1' local t0 = vim.uv.hrtime() - vim.system({ 'sh', '-c', sh }, { text = false }, function(r) - local dt = (vim.uv.hrtime() - t0) / 1e6 - local ansi = require('cp.ui.ansi') - r.stdout = ansi.bytes_to_string(r.stdout or '') + local r = vim.system({ 'sh', '-c', sh }, { text = false }):wait() + local dt = (vim.uv.hrtime() - t0) / 1e6 - if r.code == 0 then - logger.log(('Compilation successful in %.1fms.'):format(dt), vim.log.levels.INFO) - else - logger.log(('Compilation failed in %.1fms.'):format(dt)) - end + local ansi = require('cp.ui.ansi') + r.stdout = ansi.bytes_to_string(r.stdout or '') - vim.schedule(function() - on_complete(r) - end) - end) + if r.code == 0 then + logger.log(('Compilation successful in %.1fms.'):format(dt), vim.log.levels.INFO) + else + logger.log(('Compilation failed in %.1fms.'):format(dt)) + end + + return r end local function parse_and_strip_time_v(output) @@ -106,8 +103,7 @@ local function parse_and_strip_time_v(output) return head, peak_mb end ----@param on_complete fun(result: ExecuteResult) -function M.run(cmd, stdin, timeout_ms, memory_mb, on_complete) +function M.run(cmd, stdin, timeout_ms, memory_mb) local time_bin = utils.time_path() local timeout_bin = utils.timeout_path() @@ -121,91 +117,76 @@ function M.run(cmd, stdin, timeout_ms, memory_mb, on_complete) local sh = prefix .. timeout_prefix .. ('%s -v sh -c %q 2>&1'):format(time_bin, prog) local t0 = vim.uv.hrtime() - vim.system({ 'sh', '-c', sh }, { stdin = stdin, text = true }, function(r) - local dt = (vim.uv.hrtime() - t0) / 1e6 + local r = vim + .system({ 'sh', '-c', sh }, { + stdin = stdin, + text = true, + }) + :wait() + local dt = (vim.uv.hrtime() - t0) / 1e6 - local code = r.code or 0 - local raw = r.stdout or '' - local cleaned, peak_mb = parse_and_strip_time_v(raw) - local tled = code == 124 + local code = r.code or 0 + local raw = r.stdout or '' + local cleaned, peak_mb = parse_and_strip_time_v(raw) + local tled = code == 124 - local signal = nil - if code >= 128 then - signal = constants.signal_codes[code] - end + local signal = nil + if code >= 128 then + signal = constants.signal_codes[code] + end - local lower = (cleaned or ''):lower() - local oom_hint = lower:find('std::bad_alloc', 1, true) - or lower:find('cannot allocate memory', 1, true) - or lower:find('out of memory', 1, true) - or lower:find('oom', 1, true) - or lower:find('enomem', 1, true) - local near_cap = peak_mb >= (0.90 * memory_mb) + local lower = (cleaned or ''):lower() + local oom_hint = lower:find('std::bad_alloc', 1, true) + or lower:find('cannot allocate memory', 1, true) + or lower:find('out of memory', 1, true) + or lower:find('oom', 1, true) + or lower:find('enomem', 1, true) + local near_cap = peak_mb >= (0.90 * memory_mb) - local mled = (peak_mb >= memory_mb) or near_cap or (oom_hint ~= nil and not tled) + local mled = (peak_mb >= memory_mb) or near_cap or (oom_hint and not tled) - if tled then - logger.log(('Execution timed out in %.1fms.'):format(dt)) - elseif mled then - logger.log(('Execution memory limit exceeded in %.1fms.'):format(dt)) - elseif code ~= 0 then - logger.log(('Execution failed in %.1fms (exit code %d).'):format(dt, code)) - else - logger.log(('Execution successful in %.1fms.'):format(dt)) - end + if tled then + logger.log(('Execution timed out in %.1fms.'):format(dt)) + elseif mled then + logger.log(('Execution memory limit exceeded in %.1fms.'):format(dt)) + elseif code ~= 0 then + logger.log(('Execution failed in %.1fms (exit code %d).'):format(dt, code)) + else + logger.log(('Execution successful in %.1fms.'):format(dt)) + end - vim.schedule(function() - on_complete({ - stdout = cleaned, - code = code, - time_ms = dt, - tled = tled, - mled = mled, - peak_mb = peak_mb, - signal = signal, - }) - end) - end) + return { + stdout = cleaned, + code = code, + time_ms = dt, + tled = tled, + mled = mled, + peak_mb = peak_mb, + signal = signal, + } end ----@param debug boolean? ----@param on_complete fun(result: {success: boolean, output: string?}) -function M.compile_problem(debug, on_complete) +function M.compile_problem(debug) local state = require('cp.state') local config = require('cp.config').get_config() local platform = state.get_platform() local language = state.get_language() or config.platforms[platform].default_language local eff = config.runtime.effective[platform][language] - local source_file = state.get_source_file() - if source_file then - local buf = vim.fn.bufnr(source_file) - if buf ~= -1 and vim.api.nvim_buf_is_loaded(buf) and vim.bo[buf].modified then - vim.api.nvim_buf_call(buf, function() - vim.cmd.write({ mods = { silent = true, noautocmd = true } }) - end) - end - end - local compile_config = (debug and eff.commands.debug) or eff.commands.build if not compile_config then - on_complete({ success = true, output = nil }) - return + return { success = true, output = nil } end - require('cp.utils').ensure_dirs() - local binary = debug and state.get_debug_file() or state.get_binary_file() local substitutions = { source = state.get_source_file(), binary = binary } + local r = M.compile(compile_config, substitutions) - M.compile(compile_config, substitutions, function(r) - if r.code ~= 0 then - on_complete({ success = false, output = r.stdout or 'unknown error' }) - else - on_complete({ success = true, output = nil }) - end - end) + if r.code ~= 0 then + return { success = false, output = r.stdout or 'unknown error' } + end + return { success = true, output = nil } end return M diff --git a/lua/cp/runner/run.lua b/lua/cp/runner/run.lua index 36a560c..80024a8 100644 --- a/lua/cp/runner/run.lua +++ b/lua/cp/runner/run.lua @@ -101,8 +101,8 @@ end ---@param test_case RanTestCase ---@param debug boolean? ----@param on_complete fun(result: { status: "pass"|"fail"|"tle"|"mle", actual: string, actual_highlights: Highlight[], error: string, stderr: string, time_ms: number, code: integer, ok: boolean, signal: string?, tled: boolean, mled: boolean, rss_mb: number }) -local function run_single_test_case(test_case, debug, on_complete) +---@return { status: "pass"|"fail"|"tle"|"mle", actual: string, actual_highlights: Highlight[], error: string, stderr: string, time_ms: number, code: integer, ok: boolean, signal: string, tled: boolean, mled: boolean, rss_mb: number } +local function run_single_test_case(test_case, debug) local source_file = state.get_source_file() local binary_file = debug and state.get_debug_file() or state.get_binary_file() @@ -117,65 +117,65 @@ local function run_single_test_case(test_case, debug, on_complete) local timeout_ms = (panel_state.constraints and panel_state.constraints.timeout_ms) or 0 local memory_mb = panel_state.constraints and panel_state.constraints.memory_mb or 0 - execute.run(cmd, stdin_content, timeout_ms, memory_mb, function(r) - local ansi = require('cp.ui.ansi') - local out = r.stdout or '' - local highlights = {} - if out ~= '' then - if config.ui.ansi then - local parsed = ansi.parse_ansi_text(out) - out = table.concat(parsed.lines, '\n') - highlights = parsed.highlights - else - out = out:gsub('\027%[[%d;]*[a-zA-Z]', '') - end - end + local r = execute.run(cmd, stdin_content, timeout_ms, memory_mb) - local max_lines = config.ui.panel.max_output_lines - local lines = vim.split(out, '\n') - if #lines > max_lines then - local trimmed = {} - for i = 1, max_lines do - table.insert(trimmed, lines[i]) - end - table.insert(trimmed, string.format('... (output trimmed after %d lines)', max_lines)) - out = table.concat(trimmed, '\n') - end - - local expected = test_case.expected or '' - local ok = normalize_lines(out) == normalize_lines(expected) - - local signal = r.signal - if not signal and r.code and r.code >= 128 then - signal = constants.signal_codes[r.code] - end - - local status - if r.tled then - status = 'tle' - elseif r.mled then - status = 'mle' - elseif ok then - status = 'pass' + local ansi = require('cp.ui.ansi') + local out = r.stdout or '' + local highlights = {} + if out ~= '' then + if config.ui.ansi then + local parsed = ansi.parse_ansi_text(out) + out = table.concat(parsed.lines, '\n') + highlights = parsed.highlights else - status = 'fail' + out = out:gsub('\027%[[%d;]*[a-zA-Z]', '') end + end - on_complete({ - status = status, - actual = out, - actual_highlights = highlights, - error = (r.code ~= 0 and not ok) and out or '', - stderr = '', - time_ms = r.time_ms, - code = r.code, - ok = ok, - signal = signal, - tled = r.tled or false, - mled = r.mled or false, - rss_mb = r.peak_mb or 0, - }) - end) + local max_lines = config.ui.panel.max_output_lines + local lines = vim.split(out, '\n') + if #lines > max_lines then + local trimmed = {} + for i = 1, max_lines do + table.insert(trimmed, lines[i]) + end + table.insert(trimmed, string.format('... (output trimmed after %d lines)', max_lines)) + out = table.concat(trimmed, '\n') + end + + local expected = test_case.expected or '' + local ok = normalize_lines(out) == normalize_lines(expected) + + local signal = r.signal + if not signal and r.code and r.code >= 128 then + signal = constants.signal_codes[r.code] + end + + local status + if r.tled then + status = 'tle' + elseif r.mled then + status = 'mle' + elseif ok then + status = 'pass' + else + status = 'fail' + end + + return { + status = status, + actual = out, + actual_highlights = highlights, + error = (r.code ~= 0 and not ok) and out or '', + stderr = '', + time_ms = r.time_ms, + code = r.code, + ok = ok, + signal = signal, + tled = r.tled or false, + mled = r.mled or false, + rss_mb = r.peak_mb or 0, + } end ---@return boolean @@ -199,8 +199,8 @@ function M.load_test_cases() end ---@param debug boolean? ----@param on_complete fun(result: RanTestCase?) -function M.run_combined_test(debug, on_complete) +---@return RanTestCase? +function M.run_combined_test(debug) local combined = cache.get_combined_test( state.get_platform() or '', state.get_contest_id() or '', @@ -209,8 +209,7 @@ function M.run_combined_test(debug, on_complete) if not combined then logger.log('No combined test found', vim.log.levels.ERROR) - on_complete(nil) - return + return nil end local ran_test = { @@ -229,45 +228,42 @@ function M.run_combined_test(debug, on_complete) selected = true, } - run_single_test_case(ran_test, debug, function(result) - on_complete(result) - end) + local result = run_single_test_case(ran_test, debug) + return result end ---@param index number ---@param debug boolean? ----@param on_complete fun(success: boolean) -function M.run_test_case(index, debug, on_complete) +---@return boolean +function M.run_test_case(index, debug) local tc = panel_state.test_cases[index] if not tc then - on_complete(false) - return + return false end tc.status = 'running' - run_single_test_case(tc, debug, function(r) - tc.status = r.status - tc.actual = r.actual - tc.actual_highlights = r.actual_highlights - tc.error = r.error - tc.stderr = r.stderr - tc.time_ms = r.time_ms - tc.code = r.code - tc.ok = r.ok - tc.signal = r.signal - tc.tled = r.tled - tc.mled = r.mled - tc.rss_mb = r.rss_mb + local r = run_single_test_case(tc, debug) - on_complete(true) - end) + tc.status = r.status + tc.actual = r.actual + tc.actual_highlights = r.actual_highlights + tc.error = r.error + tc.stderr = r.stderr + tc.time_ms = r.time_ms + tc.code = r.code + tc.ok = r.ok + tc.signal = r.signal + tc.tled = r.tled + tc.mled = r.mled + tc.rss_mb = r.rss_mb + + return true end ---@param indices? integer[] ---@param debug boolean? ----@param on_each? fun(index: integer, total: integer) ----@param on_done fun(results: RanTestCase[]) -function M.run_all_test_cases(indices, debug, on_each, on_done) +---@return RanTestCase[] +function M.run_all_test_cases(indices, debug) local to_run = indices if not to_run then to_run = {} @@ -276,26 +272,20 @@ function M.run_all_test_cases(indices, debug, on_each, on_done) end end - local function run_next(pos) - if pos > #to_run then - logger.log( - ('Finished %s %d test cases.'):format(debug and 'debugging' or 'running', #to_run), - vim.log.levels.INFO, - true - ) - on_done(panel_state.test_cases) - return - end - - M.run_test_case(to_run[pos], debug, function() - if on_each then - on_each(pos, #to_run) - end - run_next(pos + 1) - end) + for _, i in ipairs(to_run) do + M.run_test_case(i, debug) end - run_next(1) + logger.log( + ('Finished %s %s test cases.'):format( + debug and 'debugging' or 'running', + #panel_state.test_cases + ), + vim.log.levels.INFO, + true + ) + + return panel_state.test_cases end ---@return PanelState diff --git a/lua/cp/runner/run_render.lua b/lua/cp/runner/run_render.lua index 2dfb45b..714ecd3 100644 --- a/lua/cp/runner/run_render.lua +++ b/lua/cp/runner/run_render.lua @@ -4,10 +4,6 @@ local M = {} -local function strwidth(s) - return vim.api.nvim_strwidth(s) -end - local exit_code_names = { [128] = 'SIGHUP', [129] = 'SIGINT', @@ -30,12 +26,6 @@ local exit_code_names = { ---@param ran_test_case RanTestCase ---@return StatusInfo function M.get_status_info(ran_test_case) - if ran_test_case.status == 'pending' then - return { text = '...', highlight_group = 'CpTestNA' } - elseif ran_test_case.status == 'running' then - return { text = 'RUN', highlight_group = 'CpTestNA' } - end - if ran_test_case.ok then return { text = 'AC', highlight_group = 'CpTestAC' } end @@ -44,7 +34,7 @@ function M.get_status_info(ran_test_case) return { text = 'TLE', highlight_group = 'CpTestTLE' } elseif ran_test_case.mled then return { text = 'MLE', highlight_group = 'CpTestMLE' } - elseif ran_test_case.code and ran_test_case.code >= 128 then + elseif ran_test_case.code > 0 and ran_test_case.code >= 128 then return { text = 'RTE', highlight_group = 'CpTestRTE' } elseif ran_test_case.code == 0 and not ran_test_case.ok then return { text = 'WA', highlight_group = 'CpTestWA' } @@ -73,24 +63,24 @@ local function compute_cols(test_state) for i, tc in ipairs(test_state.test_cases) do local prefix = (i == test_state.current_index) and '>' or ' ' - w.num = math.max(w.num, strwidth(' ' .. prefix .. i .. ' ')) - w.status = math.max(w.status, strwidth(' ' .. M.get_status_info(tc).text .. ' ')) + w.num = math.max(w.num, #(' ' .. prefix .. i .. ' ')) + w.status = math.max(w.status, #(' ' .. M.get_status_info(tc).text .. ' ')) local time_str = tc.time_ms and string.format('%.2f', tc.time_ms) or '—' - w.time = math.max(w.time, strwidth(' ' .. time_str .. ' ')) - w.timeout = math.max(w.timeout, strwidth(' ' .. timeout_str .. ' ')) + w.time = math.max(w.time, #(' ' .. time_str .. ' ')) + w.timeout = math.max(w.timeout, #(' ' .. timeout_str .. ' ')) local rss_str = (tc.rss_mb and string.format('%.0f', tc.rss_mb)) or '—' - w.rss = math.max(w.rss, strwidth(' ' .. rss_str .. ' ')) - w.memory = math.max(w.memory, strwidth(' ' .. memory_str .. ' ')) - w.exit = math.max(w.exit, strwidth(' ' .. format_exit_code(tc.code) .. ' ')) + w.rss = math.max(w.rss, #(' ' .. rss_str .. ' ')) + w.memory = math.max(w.memory, #(' ' .. memory_str .. ' ')) + w.exit = math.max(w.exit, #(' ' .. format_exit_code(tc.code) .. ' ')) end - w.num = math.max(w.num, strwidth(' # ')) - w.status = math.max(w.status, strwidth(' Status ')) - w.time = math.max(w.time, strwidth(' Runtime (ms) ')) - w.timeout = math.max(w.timeout, strwidth(' Time (ms) ')) - w.rss = math.max(w.rss, strwidth(' RSS (MB) ')) - w.memory = math.max(w.memory, strwidth(' Mem (MB) ')) - w.exit = math.max(w.exit, strwidth(' Exit Code ')) + w.num = math.max(w.num, #' # ') + w.status = math.max(w.status, #' Status ') + w.time = math.max(w.time, #' Runtime (ms) ') + w.timeout = math.max(w.timeout, #' Time (ms) ') + w.rss = math.max(w.rss, #' RSS (MB) ') + w.memory = math.max(w.memory, #' Mem (MB) ') + w.exit = math.max(w.exit, #' Exit Code ') local sum = w.num + w.status + w.time + w.timeout + w.rss + w.memory + w.exit local inner = sum + 6 @@ -99,7 +89,7 @@ local function compute_cols(test_state) end local function center(text, width) - local pad = width - strwidth(text) + local pad = width - #text if pad <= 0 then return text end @@ -111,7 +101,7 @@ local function format_num_column(prefix, idx, width) local num_str = tostring(idx) local content = (#num_str == 1) and (' ' .. prefix .. ' ' .. num_str .. ' ') or (' ' .. prefix .. num_str .. ' ') - local total_pad = width - strwidth(content) + local total_pad = width - #content if total_pad <= 0 then return content end @@ -324,10 +314,10 @@ function M.render_test_list(test_state) for _, input_line in ipairs(vim.split(tc.input, '\n', { plain = true, trimempty = false })) do local s = input_line or '' - if strwidth(s) > c.inner then + if #s > c.inner then s = string.sub(s, 1, c.inner) end - local pad = c.inner - strwidth(s) + local pad = c.inner - #s table.insert(lines, '│' .. s .. string.rep(' ', pad) .. '│') end @@ -367,12 +357,14 @@ end ---@return table function M.get_highlight_groups() return { - CpTestAC = { link = 'DiagnosticOk' }, - CpTestWA = { link = 'DiagnosticError' }, - CpTestTLE = { link = 'DiagnosticWarn' }, - CpTestMLE = { link = 'DiagnosticWarn' }, - CpTestRTE = { link = 'DiagnosticHint' }, - CpTestNA = { link = 'Comment' }, + CpTestAC = { fg = '#10b981' }, + CpTestWA = { fg = '#ef4444' }, + CpTestTLE = { fg = '#f59e0b' }, + CpTestMLE = { fg = '#f59e0b' }, + CpTestRTE = { fg = '#8b5cf6' }, + CpTestNA = { fg = '#6b7280' }, + CpDiffRemoved = { fg = '#ef4444', bg = '#1f1f1f' }, + CpDiffAdded = { fg = '#10b981', bg = '#1f1f1f' }, } end diff --git a/lua/cp/scraper.lua b/lua/cp/scraper.lua index c42d8be..3c0af30 100644 --- a/lua/cp/scraper.lua +++ b/lua/cp/scraper.lua @@ -186,7 +186,7 @@ function M.scrape_all_tests(platform, contest_id, callback) return end vim.schedule(function() - require('cp.utils').ensure_dirs() + vim.system({ 'mkdir', '-p', 'build', 'io' }):wait() local config = require('cp.config') local base_name = config.default_filename(contest_id, ev.problem_id) for i, t in ipairs(ev.tests) do diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index fce1a0c..d05b417 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -82,7 +82,7 @@ local function start_tests(platform, contest_id, problems) return not vim.tbl_isempty(cache.get_test_cases(platform, contest_id, p.id)) end, problems) if cached_len ~= #problems then - logger.log(('Fetching %s/%s problem tests...'):format(cached_len, #problems)) + logger.log(('Fetching problem test data... (%d/%d)'):format(cached_len, #problems)) scraper.scrape_all_tests(platform, contest_id, function(ev) local cached_tests = {} if not ev.interactive and vim.tbl_isempty(ev.tests) then @@ -348,8 +348,6 @@ function M.navigate_problem(direction, language) return end - logger.log(('navigate_problem: %s -> %s'):format(current_problem_id, problems[new_index].id)) - local active_panel = state.get_active_panel() if active_panel == 'run' then require('cp.ui.views').disable() diff --git a/lua/cp/state.lua b/lua/cp/state.lua index 6d99cbf..40eed86 100644 --- a/lua/cp/state.lua +++ b/lua/cp/state.lua @@ -10,7 +10,6 @@ ---@field output_buf integer ---@field input_buf integer ---@field current_test_index integer? ----@field source_buf integer? ---@class cp.State ---@field get_platform fun(): string? diff --git a/lua/cp/ui/highlight.lua b/lua/cp/ui/highlight.lua index a0dd17d..02bf1ae 100644 --- a/lua/cp/ui/highlight.lua +++ b/lua/cp/ui/highlight.lua @@ -26,7 +26,7 @@ local function parse_diff_line(text) line = 0, col_start = highlight_start, col_end = #result_text, - highlight_group = 'DiffDelete', + highlight_group = 'CpDiffRemoved', }) pos = removed_end + 1 else @@ -38,7 +38,7 @@ local function parse_diff_line(text) line = 0, col_start = highlight_start, col_end = #result_text, - highlight_group = 'DiffAdd', + highlight_group = 'CpDiffAdded', }) pos = added_end + 1 else diff --git a/lua/cp/ui/layouts.lua b/lua/cp/ui/layouts.lua index 9b40f49..4e737d3 100644 --- a/lua/cp/ui/layouts.lua +++ b/lua/cp/ui/layouts.lua @@ -3,13 +3,7 @@ local M = {} local helpers = require('cp.helpers') local utils = require('cp.utils') -M.DIFF_MODES = { - ['side-by-side'] = 'side-by-side', - vim = 'vim', - git = 'git', -} - -local function create_side_by_side_layout(parent_win, expected_content, actual_content) +local function create_none_diff_layout(parent_win, expected_content, actual_content) local expected_buf = utils.create_buffer_with_options() local actual_buf = utils.create_buffer_with_options() helpers.clearcol(expected_buf) @@ -27,13 +21,8 @@ local function create_side_by_side_layout(parent_win, expected_content, actual_c vim.api.nvim_set_option_value('filetype', 'cp', { buf = expected_buf }) vim.api.nvim_set_option_value('filetype', 'cp', { buf = actual_buf }) - local label = M.DIFF_MODES['side-by-side'] - vim.api.nvim_set_option_value( - 'winbar', - ('expected (diff: %s)'):format(label), - { win = expected_win } - ) - vim.api.nvim_set_option_value('winbar', ('actual (diff: %s)'):format(label), { win = actual_win }) + vim.api.nvim_set_option_value('winbar', 'Expected', { win = expected_win }) + vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) local actual_lines = vim.split(actual_content, '\n', { plain = true }) @@ -44,7 +33,6 @@ local function create_side_by_side_layout(parent_win, expected_content, actual_c return { buffers = { expected_buf, actual_buf }, windows = { expected_win, actual_win }, - mode = 'side-by-side', cleanup = function() pcall(vim.api.nvim_win_close, expected_win, true) pcall(vim.api.nvim_win_close, actual_win, true) @@ -72,13 +60,8 @@ local function create_vim_diff_layout(parent_win, expected_content, actual_conte vim.api.nvim_set_option_value('filetype', 'cp', { buf = expected_buf }) vim.api.nvim_set_option_value('filetype', 'cp', { buf = actual_buf }) - local label = M.DIFF_MODES.vim - vim.api.nvim_set_option_value( - 'winbar', - ('expected (diff: %s)'):format(label), - { win = expected_win } - ) - vim.api.nvim_set_option_value('winbar', ('actual (diff: %s)'):format(label), { win = actual_win }) + vim.api.nvim_set_option_value('winbar', 'Expected', { win = expected_win }) + vim.api.nvim_set_option_value('winbar', 'Actual', { win = actual_win }) local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) local actual_lines = vim.split(actual_content, '\n', { plain = true }) @@ -100,7 +83,6 @@ local function create_vim_diff_layout(parent_win, expected_content, actual_conte return { buffers = { expected_buf, actual_buf }, windows = { expected_win, actual_win }, - mode = 'vim', cleanup = function() pcall(vim.api.nvim_win_close, expected_win, true) pcall(vim.api.nvim_win_close, actual_win, true) @@ -121,8 +103,7 @@ local function create_git_diff_layout(parent_win, expected_content, actual_conte vim.api.nvim_win_set_buf(diff_win, diff_buf) vim.api.nvim_set_option_value('filetype', 'cp', { buf = diff_buf }) - local label = M.DIFF_MODES.git - vim.api.nvim_set_option_value('winbar', ('diff: %s'):format(label), { win = diff_win }) + vim.api.nvim_set_option_value('winbar', 'Expected vs Actual', { win = diff_win }) local diff_backend = require('cp.ui.diff') local backend = diff_backend.get_best_backend('git') @@ -140,7 +121,6 @@ local function create_git_diff_layout(parent_win, expected_content, actual_conte return { buffers = { diff_buf }, windows = { diff_win }, - mode = 'git', cleanup = function() pcall(vim.api.nvim_win_close, diff_win, true) pcall(vim.api.nvim_buf_delete, diff_buf, { force = true }) @@ -163,7 +143,6 @@ local function create_single_layout(parent_win, content) return { buffers = { buf }, windows = { win }, - mode = 'single', cleanup = function() pcall(vim.api.nvim_win_close, win, true) pcall(vim.api.nvim_buf_delete, buf, { force = true }) @@ -174,14 +153,12 @@ end function M.create_diff_layout(mode, parent_win, expected_content, actual_content) if mode == 'single' then return create_single_layout(parent_win, actual_content) - elseif mode == 'side-by-side' then - return create_side_by_side_layout(parent_win, expected_content, actual_content) + elseif mode == 'none' then + return create_none_diff_layout(parent_win, expected_content, actual_content) elseif mode == 'git' then return create_git_diff_layout(parent_win, expected_content, actual_content) - elseif mode == 'vim' then - return create_vim_diff_layout(parent_win, expected_content, actual_content) else - return create_side_by_side_layout(parent_win, expected_content, actual_content) + return create_vim_diff_layout(parent_win, expected_content, actual_content) end end @@ -214,13 +191,12 @@ function M.update_diff_panes( actual_content = actual_content end - local default_mode = config.ui.panel.diff_modes[1] - local desired_mode = is_compilation_failure and 'single' or (current_mode or default_mode) + local desired_mode = is_compilation_failure and 'single' or config.ui.panel.diff_mode local highlight = require('cp.ui.highlight') local diff_namespace = highlight.create_namespace() local ansi_namespace = vim.api.nvim_create_namespace('cp_ansi_highlights') - if current_diff_layout and current_diff_layout.mode ~= desired_mode then + if current_diff_layout and current_mode ~= desired_mode then local saved_pos = vim.api.nvim_win_get_cursor(0) current_diff_layout.cleanup() current_diff_layout = nil @@ -275,7 +251,7 @@ function M.update_diff_panes( ansi_namespace ) end - elseif desired_mode == 'side-by-side' then + elseif desired_mode == 'none' then local expected_lines = vim.split(expected_content, '\n', { plain = true, trimempty = true }) local actual_lines = vim.split(actual_content, '\n', { plain = true }) utils.update_buffer_content(current_diff_layout.buffers[1], expected_lines, {}) diff --git a/lua/cp/ui/views.lua b/lua/cp/ui/views.lua index c1d25fc..4609da8 100644 --- a/lua/cp/ui/views.lua +++ b/lua/cp/ui/views.lua @@ -13,7 +13,6 @@ local utils = require('cp.utils') local current_diff_layout = nil local current_mode = nil -local io_view_running = false function M.disable() local active_panel = state.get_active_panel() @@ -82,145 +81,127 @@ function M.toggle_interactive(interactor_cmd) local execute = require('cp.runner.execute') local run = require('cp.runner.run') + local compile_result = execute.compile_problem() + if not compile_result.success then + run.handle_compilation_failure(compile_result.output) + return + end - local function restore_session() + local binary = state.get_binary_file() + if not binary or binary == '' then + logger.log('No binary produced.', vim.log.levels.ERROR) + return + end + + local cmdline + if interactor_cmd and interactor_cmd ~= '' then + local interactor = interactor_cmd + if not interactor:find('/') then + interactor = './' .. interactor + end + if vim.fn.executable(interactor) ~= 1 then + logger.log( + ("Interactor '%s' is not executable."):format(interactor_cmd), + vim.log.levels.ERROR + ) + if state.saved_interactive_session then + vim.cmd.source(state.saved_interactive_session) + vim.fn.delete(state.saved_interactive_session) + state.saved_interactive_session = nil + end + return + end + local orchestrator = vim.fn.fnamemodify(utils.get_plugin_path() .. '/scripts/interact.py', ':p') + cmdline = table.concat({ + 'uv', + 'run', + vim.fn.shellescape(orchestrator), + vim.fn.shellescape(interactor), + vim.fn.shellescape(binary), + }, ' ') + else + cmdline = vim.fn.shellescape(binary) + end + + vim.cmd.terminal(cmdline) + local term_buf = vim.api.nvim_get_current_buf() + local term_win = vim.api.nvim_get_current_win() + + local cleaned = false + local function cleanup() + if cleaned then + return + end + cleaned = true + if term_buf and vim.api.nvim_buf_is_valid(term_buf) then + local job = vim.b[term_buf] and vim.b[term_buf].terminal_job_id or nil + if job then + pcall(vim.fn.jobstop, job) + end + end if state.saved_interactive_session then vim.cmd.source(state.saved_interactive_session) vim.fn.delete(state.saved_interactive_session) state.saved_interactive_session = nil end + state.interactive_buf = nil + state.interactive_win = nil + state.set_active_panel(nil) end - execute.compile_problem(false, function(compile_result) - if not compile_result.success then - run.handle_compilation_failure(compile_result.output) - restore_session() - return - end + vim.api.nvim_create_autocmd({ 'BufWipeout', 'BufUnload' }, { + buffer = term_buf, + callback = cleanup, + }) - local binary = state.get_binary_file() - if not binary or binary == '' then - logger.log('No binary produced.', vim.log.levels.ERROR) - restore_session() - return - end - - local cmdline - if interactor_cmd and interactor_cmd ~= '' then - local interactor = interactor_cmd - if not interactor:find('/') then - interactor = './' .. interactor - end - if vim.fn.executable(interactor) ~= 1 then - logger.log( - ("Interactor '%s' is not executable."):format(interactor_cmd), - vim.log.levels.ERROR - ) - restore_session() - return - end - local orchestrator = - vim.fn.fnamemodify(utils.get_plugin_path() .. '/scripts/interact.py', ':p') - cmdline = table.concat({ - 'uv', - 'run', - vim.fn.shellescape(orchestrator), - vim.fn.shellescape(interactor), - vim.fn.shellescape(binary), - }, ' ') - else - cmdline = vim.fn.shellescape(binary) - end - - vim.cmd.terminal(cmdline) - local term_buf = vim.api.nvim_get_current_buf() - local term_win = vim.api.nvim_get_current_win() - - local cleaned = false - local function cleanup() + vim.api.nvim_create_autocmd('WinClosed', { + callback = function() if cleaned then return end - cleaned = true - if term_buf and vim.api.nvim_buf_is_valid(term_buf) then - local job = vim.b[term_buf] and vim.b[term_buf].terminal_job_id or nil - if job then - pcall(vim.fn.jobstop, job) + local any = false + for _, win in ipairs(vim.api.nvim_list_wins()) do + if vim.api.nvim_win_is_valid(win) and vim.api.nvim_win_get_buf(win) == term_buf then + any = true + break end end - restore_session() - state.interactive_buf = nil - state.interactive_win = nil - state.set_active_panel(nil) - end + if not any then + cleanup() + end + end, + }) - vim.api.nvim_create_autocmd({ 'BufWipeout', 'BufUnload' }, { - buffer = term_buf, - callback = cleanup, - }) + vim.api.nvim_create_autocmd('TermClose', { + buffer = term_buf, + callback = function() + vim.b[term_buf].cp_interactive_exited = true + end, + }) - vim.api.nvim_create_autocmd('WinClosed', { - callback = function() - if cleaned then - return - end - local any = false - for _, win in ipairs(vim.api.nvim_list_wins()) do - if vim.api.nvim_win_is_valid(win) and vim.api.nvim_win_get_buf(win) == term_buf then - any = true - break - end - end - if not any then - cleanup() - end - end, - }) + vim.keymap.set('t', '', function() + cleanup() + end, { buffer = term_buf, silent = true }) + vim.keymap.set('n', '', function() + cleanup() + end, { buffer = term_buf, silent = true }) - vim.api.nvim_create_autocmd('TermClose', { - buffer = term_buf, - callback = function() - vim.b[term_buf].cp_interactive_exited = true - end, - }) - - vim.keymap.set('t', '', function() - cleanup() - end, { buffer = term_buf, silent = true }) - vim.keymap.set('n', '', function() - cleanup() - end, { buffer = term_buf, silent = true }) - - state.interactive_buf = term_buf - state.interactive_win = term_win - state.set_active_panel('interactive') - end) + state.interactive_buf = term_buf + state.interactive_win = term_win + state.set_active_panel('interactive') end ---@return integer, integer local function get_or_create_io_buffers() local io_state = state.get_io_view_state() - local solution_win = state.get_solution_win() - local current_source_buf = vim.api.nvim_win_get_buf(solution_win) if io_state then local output_valid = io_state.output_buf and vim.api.nvim_buf_is_valid(io_state.output_buf) local input_valid = io_state.input_buf and vim.api.nvim_buf_is_valid(io_state.input_buf) - local same_source = io_state.source_buf == current_source_buf - if output_valid and input_valid and same_source then + if output_valid and input_valid then return io_state.output_buf, io_state.input_buf end - - if io_state.source_buf then - pcall(vim.api.nvim_del_augroup_by_name, 'cp_io_cleanup_buf' .. io_state.source_buf) - end - if output_valid then - pcall(vim.api.nvim_buf_delete, io_state.output_buf, { force = true }) - end - if input_valid then - pcall(vim.api.nvim_buf_delete, io_state.input_buf, { force = true }) - end end local output_buf = utils.create_buffer_with_options('cpout') @@ -230,10 +211,10 @@ local function get_or_create_io_buffers() output_buf = output_buf, input_buf = input_buf, current_test_index = 1, - source_buf = current_source_buf, }) - local source_buf = current_source_buf + local solution_win = state.get_solution_win() + local source_buf = vim.api.nvim_win_get_buf(solution_win) local group_name = 'cp_io_cleanup_buf' .. source_buf vim.api.nvim_create_augroup(group_name, { clear = true }) @@ -268,10 +249,6 @@ local function get_or_create_io_buffers() return end - if io.source_buf ~= source_buf then - return - end - local wins = vim.api.nvim_list_wins() for _, win in ipairs(wins) do if vim.api.nvim_win_get_buf(win) == source_buf then @@ -391,8 +368,6 @@ function M.ensure_io_view() return end - require('cp.utils').ensure_dirs() - local source_file = state.get_source_file() if source_file then local source_file_abs = vim.fn.fnamemodify(source_file, ':p') @@ -472,44 +447,166 @@ function M.ensure_io_view() end end -local function render_io_view_results(io_state, test_indices, mode, combined_result, combined_input) - local run = require('cp.runner.run') - local run_render = require('cp.runner.run_render') - local cfg = config_module.get_config() +function M.run_io_view(test_indices_arg, debug, mode) + logger.log(('%s tests...'):format(debug and 'Debugging' or 'Running'), vim.log.levels.INFO, true) + mode = mode or 'combined' + + local platform, contest_id, problem_id = + state.get_platform(), state.get_contest_id(), state.get_problem_id() + if not platform or not contest_id or not problem_id then + logger.log( + 'No platform/contest/problem configured. Use :CP [...] first.', + vim.log.levels.ERROR + ) + return + end + + cache.load() + local contest_data = cache.get_contest_data(platform, contest_id) + if not contest_data or not contest_data.index_map then + logger.log('No test cases available.', vim.log.levels.ERROR) + return + end + + if mode == 'combined' then + local problem_data = contest_data.problems[contest_data.index_map[problem_id]] + if not problem_data.multi_test then + mode = 'individual' + end + end + + local run = require('cp.runner.run') + + if mode == 'combined' then + local combined = cache.get_combined_test(platform, contest_id, problem_id) + if not combined then + logger.log('No combined test available', vim.log.levels.ERROR) + return + end + else + if not run.load_test_cases() then + logger.log('No test cases available', vim.log.levels.ERROR) + return + end + end + + local test_indices = {} + + if mode == 'individual' then + local test_state = run.get_panel_state() + + if test_indices_arg then + for _, idx in ipairs(test_indices_arg) do + if idx < 1 or idx > #test_state.test_cases then + logger.log( + string.format( + 'Test %d does not exist (only %d tests available)', + idx, + #test_state.test_cases + ), + vim.log.levels.WARN + ) + return + end + end + test_indices = test_indices_arg + else + for i = 1, #test_state.test_cases do + test_indices[i] = i + end + end + end + + if not test_indices_arg then + M.ensure_io_view() + end + + local io_state = state.get_io_view_state() + if not io_state then + return + end + + local config = config_module.get_config() + + if config.ui.ansi then + require('cp.ui.ansi').setup_highlight_groups() + end + + local execute = require('cp.runner.execute') + local compile_result = execute.compile_problem(debug) + if not compile_result.success then + local ansi = require('cp.ui.ansi') + local output = compile_result.output or '' + local lines, highlights + + if config.ui.ansi then + local parsed = ansi.parse_ansi_text(output) + lines = parsed.lines + highlights = parsed.highlights + else + lines = vim.split(output:gsub('\027%[[%d;]*[a-zA-Z]', ''), '\n') + highlights = {} + end + + local ns = vim.api.nvim_create_namespace('cp_io_view_compile_error') + utils.update_buffer_content(io_state.output_buf, lines, highlights, ns) + return + end + + local run_render = require('cp.runner.run_render') run_render.setup_highlights() local input_lines = {} local output_lines = {} local verdict_lines = {} local verdict_highlights = {} - local formatter = cfg.ui.run.format_verdict - local test_state = run.get_panel_state() - if mode == 'combined' and combined_result then - input_lines = vim.split(combined_input, '\n') + local formatter = config.ui.run.format_verdict - if combined_result.actual and combined_result.actual ~= '' then - output_lines = vim.split(combined_result.actual, '\n') + if mode == 'combined' then + local combined = cache.get_combined_test(platform, contest_id, problem_id) + + if not combined then + logger.log('No combined test found', vim.log.levels.ERROR) + return end - local status = run_render.get_status_info(combined_result) + run.load_test_cases() + + local result = run.run_combined_test(debug) + + if not result then + logger.log('Failed to run combined test', vim.log.levels.ERROR) + return + end + + input_lines = vim.split(combined.input, '\n') + + if result.actual and result.actual ~= '' then + output_lines = vim.split(result.actual, '\n') + end + + local status = run_render.get_status_info(result) + local test_state = run.get_panel_state() + + ---@type VerdictFormatData local format_data = { index = 1, status = status, - time_ms = combined_result.time_ms or 0, + time_ms = result.time_ms or 0, time_limit_ms = test_state.constraints and test_state.constraints.timeout_ms or 0, - memory_mb = combined_result.rss_mb or 0, + memory_mb = result.rss_mb or 0, memory_limit_mb = test_state.constraints and test_state.constraints.memory_mb or 0, - exit_code = combined_result.code or 0, - signal = (combined_result.code and combined_result.code >= 128) - and require('cp.constants').signal_codes[combined_result.code] + exit_code = result.code or 0, + signal = (result.code and result.code >= 128) + and require('cp.constants').signal_codes[result.code] or nil, - time_actual_width = #string.format('%.2f', combined_result.time_ms or 0), + time_actual_width = #string.format('%.2f', result.time_ms or 0), time_limit_width = #tostring( test_state.constraints and test_state.constraints.timeout_ms or 0 ), - mem_actual_width = #string.format('%.0f', combined_result.rss_mb or 0), + mem_actual_width = #string.format('%.0f', result.rss_mb or 0), mem_limit_width = #string.format( '%.0f', test_state.constraints and test_state.constraints.memory_mb or 0 @@ -530,7 +627,13 @@ local function render_io_view_results(io_state, test_indices, mode, combined_res end end else - local max_time_actual, max_time_limit, max_mem_actual, max_mem_limit = 0, 0, 0, 0 + run.run_all_test_cases(test_indices, debug) + local test_state = run.get_panel_state() + + local max_time_actual = 0 + local max_time_limit = 0 + local max_mem_actual = 0 + local max_mem_limit = 0 for _, idx in ipairs(test_indices) do local tc = test_state.test_cases[idx] @@ -549,9 +652,11 @@ local function render_io_view_results(io_state, test_indices, mode, combined_res local all_outputs = {} for _, idx in ipairs(test_indices) do local tc = test_state.test_cases[idx] + for _, line in ipairs(vim.split(tc.input, '\n')) do table.insert(input_lines, line) end + if tc.actual then table.insert(all_outputs, tc.actual) end @@ -568,6 +673,7 @@ local function render_io_view_results(io_state, test_indices, mode, combined_res local tc = test_state.test_cases[idx] local status = run_render.get_status_info(tc) + ---@type VerdictFormatData local format_data = { index = idx, status = status, @@ -620,169 +726,11 @@ local function render_io_view_results(io_state, test_indices, mode, combined_res end utils.update_buffer_content(io_state.input_buf, input_lines, nil, nil) + local output_ns = vim.api.nvim_create_namespace('cp_io_view_output') utils.update_buffer_content(io_state.output_buf, output_lines, final_highlights, output_ns) end -function M.run_io_view(test_indices_arg, debug, mode) - if io_view_running then - logger.log('Tests already running', vim.log.levels.WARN) - return - end - io_view_running = true - - logger.log(('%s tests...'):format(debug and 'Debugging' or 'Running'), vim.log.levels.INFO, true) - - mode = mode or 'combined' - - local platform, contest_id, problem_id = - state.get_platform(), state.get_contest_id(), state.get_problem_id() - if not platform or not contest_id or not problem_id then - logger.log( - 'No platform/contest/problem configured. Use :CP [...] first.', - vim.log.levels.ERROR - ) - io_view_running = false - return - end - - cache.load() - local contest_data = cache.get_contest_data(platform, contest_id) - if not contest_data or not contest_data.index_map then - logger.log('No test cases available.', vim.log.levels.ERROR) - io_view_running = false - return - end - - if mode == 'combined' then - local problem_data = contest_data.problems[contest_data.index_map[problem_id]] - if not problem_data.multi_test then - mode = 'individual' - end - end - - local run = require('cp.runner.run') - - if mode == 'combined' then - local combined = cache.get_combined_test(platform, contest_id, problem_id) - if not combined then - logger.log('No combined test available', vim.log.levels.ERROR) - io_view_running = false - return - end - else - if not run.load_test_cases() then - logger.log('No test cases available', vim.log.levels.ERROR) - io_view_running = false - return - end - end - - local test_indices = {} - - if mode == 'individual' then - local test_state = run.get_panel_state() - - if test_indices_arg then - for _, idx in ipairs(test_indices_arg) do - if idx < 1 or idx > #test_state.test_cases then - logger.log( - string.format( - 'Test %d does not exist (only %d tests available)', - idx, - #test_state.test_cases - ), - vim.log.levels.WARN - ) - io_view_running = false - return - end - end - test_indices = test_indices_arg - else - for i = 1, #test_state.test_cases do - test_indices[i] = i - end - end - end - - if not test_indices_arg then - M.ensure_io_view() - end - - local io_state = state.get_io_view_state() - if not io_state then - io_view_running = false - return - end - - local cfg = config_module.get_config() - - if cfg.ui.ansi then - require('cp.ui.ansi').setup_highlight_groups() - end - - local execute = require('cp.runner.execute') - - execute.compile_problem(debug, function(compile_result) - if not vim.api.nvim_buf_is_valid(io_state.output_buf) then - io_view_running = false - return - end - - if not compile_result.success then - local ansi = require('cp.ui.ansi') - local output = compile_result.output or '' - local lines, highlights - - if cfg.ui.ansi then - local parsed = ansi.parse_ansi_text(output) - lines = parsed.lines - highlights = parsed.highlights - else - lines = vim.split(output:gsub('\027%[[%d;]*[a-zA-Z]', ''), '\n') - highlights = {} - end - - local ns = vim.api.nvim_create_namespace('cp_io_view_compile_error') - utils.update_buffer_content(io_state.output_buf, lines, highlights, ns) - io_view_running = false - return - end - - if mode == 'combined' then - local combined = cache.get_combined_test(platform, contest_id, problem_id) - if not combined then - logger.log('No combined test found', vim.log.levels.ERROR) - io_view_running = false - return - end - - run.load_test_cases() - - run.run_combined_test(debug, function(result) - if not result then - logger.log('Failed to run combined test', vim.log.levels.ERROR) - io_view_running = false - return - end - - if vim.api.nvim_buf_is_valid(io_state.output_buf) then - render_io_view_results(io_state, test_indices, mode, result, combined.input) - end - io_view_running = false - end) - else - run.run_all_test_cases(test_indices, debug, nil, function() - if vim.api.nvim_buf_is_valid(io_state.output_buf) then - render_io_view_results(io_state, test_indices, mode, nil, nil) - end - io_view_running = false - end) - end - end) -end - ---@param panel_opts? PanelOpts function M.toggle_panel(panel_opts) if state.get_active_panel() == 'run' then @@ -880,9 +828,6 @@ function M.toggle_panel(panel_opts) end local function refresh_panel() - if state.get_active_panel() ~= 'run' then - return - end if not test_buffers.tab_buf or not vim.api.nvim_buf_is_valid(test_buffers.tab_buf) then return end @@ -908,10 +853,6 @@ function M.toggle_panel(panel_opts) vim.cmd.normal({ 'zz', bang = true }) end) end - - if test_windows.tab_win and vim.api.nvim_win_is_valid(test_windows.tab_win) then - vim.api.nvim_set_current_win(test_windows.tab_win) - end end local function navigate_test_case(delta) @@ -928,15 +869,15 @@ function M.toggle_panel(panel_opts) M.toggle_panel() end, { buffer = buf, silent = true }) vim.keymap.set('n', 't', function() - local modes = config.ui.panel.diff_modes + local modes = { 'none', 'git', 'vim' } local current_idx = 1 for i, mode in ipairs(modes) do - if current_mode == mode then + if config.ui.panel.diff_mode == mode then current_idx = i break end end - current_mode = modes[(current_idx % #modes) + 1] + config.ui.panel.diff_mode = modes[(current_idx % #modes) + 1] refresh_panel() end, { buffer = buf, silent = true }) vim.keymap.set('n', '', function() @@ -960,47 +901,30 @@ function M.toggle_panel(panel_opts) end) end + local execute = require('cp.runner.execute') + local compile_result = execute.compile_problem(panel_opts and panel_opts.debug) + if compile_result.success then + run.run_all_test_cases(nil, panel_opts and panel_opts.debug) + else + run.handle_compilation_failure(compile_result.output) + end + + refresh_panel() + + vim.schedule(function() + if config.ui.ansi then + require('cp.ui.ansi').setup_highlight_groups() + end + if current_diff_layout then + update_diff_panes() + end + end) + vim.api.nvim_set_current_win(test_windows.tab_win) state.test_buffers = test_buffers state.test_windows = test_windows state.set_active_panel('run') logger.log('test panel opened') - - refresh_panel() - - local function finalize_panel() - vim.schedule(function() - if state.get_active_panel() ~= 'run' then - return - end - if config.ui.ansi then - require('cp.ui.ansi').setup_highlight_groups() - end - if current_diff_layout then - update_diff_panes() - end - end) - end - - local execute = require('cp.runner.execute') - execute.compile_problem(panel_opts and panel_opts.debug, function(compile_result) - if not test_buffers.tab_buf or not vim.api.nvim_buf_is_valid(test_buffers.tab_buf) then - return - end - - if compile_result.success then - run.run_all_test_cases(nil, panel_opts and panel_opts.debug, function() - refresh_panel() - end, function() - refresh_panel() - finalize_panel() - end) - else - run.handle_compilation_failure(compile_result.output) - refresh_panel() - finalize_panel() - end - end) end return M diff --git a/lua/cp/utils.lua b/lua/cp/utils.lua index 654c2ef..c2f353a 100644 --- a/lua/cp/utils.lua +++ b/lua/cp/utils.lua @@ -262,8 +262,4 @@ function M.cwd_executables() return out end -function M.ensure_dirs() - vim.system({ 'mkdir', '-p', 'build', 'io' }):wait() -end - return M diff --git a/new b/new deleted file mode 100644 index e69de29..0000000 diff --git a/scrapers/atcoder.py b/scrapers/atcoder.py index 1b946dd..66b95aa 100644 --- a/scrapers/atcoder.py +++ b/scrapers/atcoder.py @@ -266,31 +266,43 @@ class AtcoderScraper(BaseScraper): return "atcoder" async def scrape_contest_metadata(self, contest_id: str) -> MetadataResult: - try: - rows = await asyncio.to_thread(_scrape_tasks_sync, contest_id) + async def impl(cid: str) -> MetadataResult: + try: + rows = await asyncio.to_thread(_scrape_tasks_sync, cid) + except requests.HTTPError as e: + if e.response is not None and e.response.status_code == 404: + return self._create_metadata_error( + f"No problems found for contest {cid}", cid + ) + raise + problems = _to_problem_summaries(rows) if not problems: - return self._metadata_error( - f"No problems found for contest {contest_id}" + return self._create_metadata_error( + f"No problems found for contest {cid}", cid ) + return MetadataResult( success=True, error="", - contest_id=contest_id, + contest_id=cid, problems=problems, url=f"https://atcoder.jp/contests/{contest_id}/tasks/{contest_id}_%s", ) - except Exception as e: - return self._metadata_error(str(e)) + + return await self._safe_execute("metadata", impl, contest_id) async def scrape_contest_list(self) -> ContestListResult: - try: - contests = await _fetch_all_contests_async() + async def impl() -> ContestListResult: + try: + contests = await _fetch_all_contests_async() + except Exception as e: + return self._create_contests_error(str(e)) if not contests: - return self._contests_error("No contests found") + return self._create_contests_error("No contests found") return ContestListResult(success=True, error="", contests=contests) - except Exception as e: - return self._contests_error(str(e)) + + return await self._safe_execute("contests", impl) async def stream_tests_for_category_async(self, category_id: str) -> None: rows = await asyncio.to_thread(_scrape_tasks_sync, category_id) diff --git a/scrapers/base.py b/scrapers/base.py index 4b685d0..6409c9a 100644 --- a/scrapers/base.py +++ b/scrapers/base.py @@ -1,8 +1,9 @@ -import asyncio -import sys from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, ParamSpec, cast -from .models import CombinedTest, ContestListResult, MetadataResult, TestsResult +from .models import ContestListResult, MetadataResult, TestsResult + +P = ParamSpec("P") class BaseScraper(ABC): @@ -19,65 +20,57 @@ class BaseScraper(ABC): @abstractmethod async def stream_tests_for_category_async(self, category_id: str) -> None: ... - def _usage(self) -> str: - name = self.platform_name - return f"Usage: {name}.py metadata | tests | contests" + def _create_metadata_error( + self, error_msg: str, contest_id: str = "" + ) -> MetadataResult: + return MetadataResult( + success=False, + error=f"{self.platform_name}: {error_msg}", + contest_id=contest_id, + problems=[], + url="", + ) - def _metadata_error(self, msg: str) -> MetadataResult: - return MetadataResult(success=False, error=msg, url="") + def _create_tests_error( + self, error_msg: str, problem_id: str = "", url: str = "" + ) -> TestsResult: + from .models import CombinedTest - def _tests_error(self, msg: str) -> TestsResult: return TestsResult( success=False, - error=msg, - problem_id="", + error=f"{self.platform_name}: {error_msg}", + problem_id=problem_id, combined=CombinedTest(input="", expected=""), tests=[], timeout_ms=0, memory_mb=0, + interactive=False, ) - def _contests_error(self, msg: str) -> ContestListResult: - return ContestListResult(success=False, error=msg) + def _create_contests_error(self, error_msg: str) -> ContestListResult: + return ContestListResult( + success=False, + error=f"{self.platform_name}: {error_msg}", + contests=[], + ) - async def _run_cli_async(self, args: list[str]) -> int: - if len(args) < 2: - print(self._metadata_error(self._usage()).model_dump_json()) - return 1 - - mode = args[1] - - match mode: - case "metadata": - if len(args) != 3: - print(self._metadata_error(self._usage()).model_dump_json()) - return 1 - result = await self.scrape_contest_metadata(args[2]) - print(result.model_dump_json()) - return 0 if result.success else 1 - - case "tests": - if len(args) != 3: - print(self._tests_error(self._usage()).model_dump_json()) - return 1 - await self.stream_tests_for_category_async(args[2]) - return 0 - - case "contests": - if len(args) != 2: - print(self._contests_error(self._usage()).model_dump_json()) - return 1 - result = await self.scrape_contest_list() - print(result.model_dump_json()) - return 0 if result.success else 1 - - case _: - print( - self._metadata_error( - f"Unknown mode: {mode}. {self._usage()}" - ).model_dump_json() - ) - return 1 - - def run_cli(self) -> None: - sys.exit(asyncio.run(self._run_cli_async(sys.argv))) + async def _safe_execute( + self, + operation: str, + func: Callable[P, Awaitable[Any]], + *args: P.args, + **kwargs: P.kwargs, + ): + try: + return await func(*args, **kwargs) + except Exception as e: + if operation == "metadata": + contest_id = cast(str, args[0]) if args else "" + return self._create_metadata_error(str(e), contest_id) + elif operation == "tests": + problem_id = cast(str, args[1]) if len(args) > 1 else "" + return self._create_tests_error(str(e), problem_id) + elif operation == "contests": + return self._create_contests_error(str(e)) + else: + raise diff --git a/scrapers/codechef.py b/scrapers/codechef.py index 0687c1e..1680e83 100644 --- a/scrapers/codechef.py +++ b/scrapers/codechef.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 - import asyncio import json import re +import sys from typing import Any import httpx @@ -10,11 +10,13 @@ from scrapling.fetchers import Fetcher from .base import BaseScraper from .models import ( + CombinedTest, ContestListResult, ContestSummary, MetadataResult, ProblemSummary, TestCase, + TestsResult, ) BASE_URL = "https://www.codechef.com" @@ -60,40 +62,42 @@ class CodeChefScraper(BaseScraper): return "codechef" async def scrape_contest_metadata(self, contest_id: str) -> MetadataResult: - try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient() as client: + try: data = await fetch_json( client, API_CONTEST.format(contest_id=contest_id) ) - if not data.get("problems"): - return self._metadata_error( - f"No problems found for contest {contest_id}" + except httpx.HTTPStatusError as e: + return self._create_metadata_error( + f"Failed to fetch contest {contest_id}: {e}", contest_id ) - problems = [] - for problem_code, problem_data in data["problems"].items(): - if problem_data.get("category_name") == "main": - problems.append( - ProblemSummary( - id=problem_code, - name=problem_data.get("name", problem_code), - ) - ) - return MetadataResult( - success=True, - error="", - contest_id=contest_id, - problems=problems, - url=f"{BASE_URL}/{contest_id}", + if not data.get("problems"): + return self._create_metadata_error( + f"No problems found for contest {contest_id}", contest_id ) - except Exception as e: - return self._metadata_error(f"Failed to fetch contest {contest_id}: {e}") + problems = [] + for problem_code, problem_data in data["problems"].items(): + if problem_data.get("category_name") == "main": + problems.append( + ProblemSummary( + id=problem_code, + name=problem_data.get("name", problem_code), + ) + ) + return MetadataResult( + success=True, + error="", + contest_id=contest_id, + problems=problems, + url=f"{BASE_URL}/{contest_id}", + ) async def scrape_contest_list(self) -> ContestListResult: async with httpx.AsyncClient() as client: try: data = await fetch_json(client, API_CONTESTS_ALL) except httpx.HTTPStatusError as e: - return self._contests_error(f"Failed to fetch contests: {e}") + return self._create_contests_error(f"Failed to fetch contests: {e}") all_contests = data.get("future_contests", []) + data.get( "past_contests", [] ) @@ -106,7 +110,7 @@ class CodeChefScraper(BaseScraper): num = int(match.group(1)) max_num = max(max_num, num) if max_num == 0: - return self._contests_error("No Starters contests found") + return self._create_contests_error("No Starters contests found") contests = [] sem = asyncio.Semaphore(CONNECTIONS) @@ -248,5 +252,68 @@ class CodeChefScraper(BaseScraper): print(json.dumps(payload), flush=True) +async def main_async() -> int: + if len(sys.argv) < 2: + result = MetadataResult( + success=False, + error="Usage: codechef.py metadata OR codechef.py tests OR codechef.py contests", + url="", + ) + print(result.model_dump_json()) + return 1 + mode: str = sys.argv[1] + scraper = CodeChefScraper() + if mode == "metadata": + if len(sys.argv) != 3: + result = MetadataResult( + success=False, + error="Usage: codechef.py metadata ", + url="", + ) + print(result.model_dump_json()) + return 1 + contest_id = sys.argv[2] + result = await scraper.scrape_contest_metadata(contest_id) + print(result.model_dump_json()) + return 0 if result.success else 1 + if mode == "tests": + if len(sys.argv) != 3: + tests_result = TestsResult( + success=False, + error="Usage: codechef.py tests ", + problem_id="", + combined=CombinedTest(input="", expected=""), + tests=[], + timeout_ms=0, + memory_mb=0, + ) + print(tests_result.model_dump_json()) + return 1 + contest_id = sys.argv[2] + await scraper.stream_tests_for_category_async(contest_id) + return 0 + if mode == "contests": + if len(sys.argv) != 2: + contest_result = ContestListResult( + success=False, error="Usage: codechef.py contests" + ) + print(contest_result.model_dump_json()) + return 1 + contest_result = await scraper.scrape_contest_list() + print(contest_result.model_dump_json()) + return 0 if contest_result.success else 1 + result = MetadataResult( + success=False, + error=f"Unknown mode: {mode}. Use 'metadata ', 'tests ', or 'contests'", + url="", + ) + print(result.model_dump_json()) + return 1 + + +def main() -> None: + sys.exit(asyncio.run(main_async())) + + if __name__ == "__main__": - CodeChefScraper().run_cli() + main() diff --git a/scrapers/codeforces.py b/scrapers/codeforces.py index cf172b8..24f55f6 100644 --- a/scrapers/codeforces.py +++ b/scrapers/codeforces.py @@ -4,6 +4,7 @@ import asyncio import json import logging import re +import sys from typing import Any import requests @@ -12,11 +13,13 @@ from scrapling.fetchers import Fetcher from .base import BaseScraper from .models import ( + CombinedTest, ContestListResult, ContestSummary, MetadataResult, ProblemSummary, TestCase, + TestsResult, ) # suppress scrapling logging - https://github.com/D4Vinci/Scrapling/issues/31) @@ -86,14 +89,14 @@ def _extract_samples(block: Tag) -> tuple[list[TestCase], bool]: if not st: return [], False - input_pres: list[Tag] = [ - inp.find("pre") - for inp in st.find_all("div", class_="input") + input_pres: list[Tag] = [ # type: ignore[misc] + inp.find("pre") # type: ignore[misc] + for inp in st.find_all("div", class_="input") # type: ignore[union-attr] if isinstance(inp, Tag) and inp.find("pre") ] output_pres: list[Tag] = [ - out.find("pre") - for out in st.find_all("div", class_="output") + out.find("pre") # type: ignore[misc] + for out in st.find_all("div", class_="output") # type: ignore[union-attr] if isinstance(out, Tag) and out.find("pre") ] input_pres = [p for p in input_pres if isinstance(p, Tag)] @@ -206,46 +209,49 @@ class CodeforcesScraper(BaseScraper): return "codeforces" async def scrape_contest_metadata(self, contest_id: str) -> MetadataResult: - try: - problems = await asyncio.to_thread( - _scrape_contest_problems_sync, contest_id - ) + async def impl(cid: str) -> MetadataResult: + problems = await asyncio.to_thread(_scrape_contest_problems_sync, cid) if not problems: - return self._metadata_error( - f"No problems found for contest {contest_id}" + return self._create_metadata_error( + f"No problems found for contest {cid}", cid ) return MetadataResult( success=True, error="", - contest_id=contest_id, + contest_id=cid, problems=problems, url=f"https://codeforces.com/contest/{contest_id}/problem/%s", ) - except Exception as e: - return self._metadata_error(str(e)) + + return await self._safe_execute("metadata", impl, contest_id) async def scrape_contest_list(self) -> ContestListResult: - try: - r = requests.get(API_CONTEST_LIST_URL, timeout=TIMEOUT_SECONDS) - r.raise_for_status() - data = r.json() - if data.get("status") != "OK": - return self._contests_error("Invalid API response") + async def impl() -> ContestListResult: + try: + r = requests.get(API_CONTEST_LIST_URL, timeout=TIMEOUT_SECONDS) + r.raise_for_status() + data = r.json() + if data.get("status") != "OK": + return self._create_contests_error("Invalid API response") - contests: list[ContestSummary] = [] - for c in data["result"]: - if c.get("phase") != "FINISHED": - continue - cid = str(c["id"]) - name = c["name"] - contests.append(ContestSummary(id=cid, name=name, display_name=name)) + contests: list[ContestSummary] = [] + for c in data["result"]: + if c.get("phase") != "FINISHED": + continue + cid = str(c["id"]) + name = c["name"] + contests.append( + ContestSummary(id=cid, name=name, display_name=name) + ) - if not contests: - return self._contests_error("No contests found") + if not contests: + return self._create_contests_error("No contests found") - return ContestListResult(success=True, error="", contests=contests) - except Exception as e: - return self._contests_error(str(e)) + return ContestListResult(success=True, error="", contests=contests) + except Exception as e: + return self._create_contests_error(str(e)) + + return await self._safe_execute("contests", impl) async def stream_tests_for_category_async(self, category_id: str) -> None: html = await asyncio.to_thread(_fetch_problems_html, category_id) @@ -275,5 +281,73 @@ class CodeforcesScraper(BaseScraper): ) +async def main_async() -> int: + if len(sys.argv) < 2: + result = MetadataResult( + success=False, + error="Usage: codeforces.py metadata OR codeforces.py tests OR codeforces.py contests", + url="", + ) + print(result.model_dump_json()) + return 1 + + mode: str = sys.argv[1] + scraper = CodeforcesScraper() + + if mode == "metadata": + if len(sys.argv) != 3: + result = MetadataResult( + success=False, + error="Usage: codeforces.py metadata ", + url="", + ) + print(result.model_dump_json()) + return 1 + contest_id = sys.argv[2] + result = await scraper.scrape_contest_metadata(contest_id) + print(result.model_dump_json()) + return 0 if result.success else 1 + + if mode == "tests": + if len(sys.argv) != 3: + tests_result = TestsResult( + success=False, + error="Usage: codeforces.py tests ", + problem_id="", + combined=CombinedTest(input="", expected=""), + tests=[], + timeout_ms=0, + memory_mb=0, + ) + print(tests_result.model_dump_json()) + return 1 + contest_id = sys.argv[2] + await scraper.stream_tests_for_category_async(contest_id) + return 0 + + if mode == "contests": + if len(sys.argv) != 2: + contest_result = ContestListResult( + success=False, error="Usage: codeforces.py contests" + ) + print(contest_result.model_dump_json()) + return 1 + contest_result = await scraper.scrape_contest_list() + print(contest_result.model_dump_json()) + return 0 if contest_result.success else 1 + + result = MetadataResult( + success=False, + error="Unknown mode. Use 'metadata ', 'tests ', or 'contests'", + url="", + ) + print(result.model_dump_json()) + return 1 + + +def main() -> None: + sys.exit(asyncio.run(main_async())) + + if __name__ == "__main__": - CodeforcesScraper().run_cli() + main() diff --git a/scrapers/cses.py b/scrapers/cses.py index 5440b34..620cb7f 100644 --- a/scrapers/cses.py +++ b/scrapers/cses.py @@ -3,17 +3,20 @@ import asyncio import json import re +import sys from typing import Any import httpx from .base import BaseScraper from .models import ( + CombinedTest, ContestListResult, ContestSummary, MetadataResult, ProblemSummary, TestCase, + TestsResult, ) BASE_URL = "https://cses.fi" @@ -258,5 +261,73 @@ class CSESScraper(BaseScraper): print(json.dumps(payload), flush=True) +async def main_async() -> int: + if len(sys.argv) < 2: + result = MetadataResult( + success=False, + error="Usage: cses.py metadata OR cses.py tests OR cses.py contests", + url="", + ) + print(result.model_dump_json()) + return 1 + + mode: str = sys.argv[1] + scraper = CSESScraper() + + if mode == "metadata": + if len(sys.argv) != 3: + result = MetadataResult( + success=False, + error="Usage: cses.py metadata ", + url="", + ) + print(result.model_dump_json()) + return 1 + category_id = sys.argv[2] + result = await scraper.scrape_contest_metadata(category_id) + print(result.model_dump_json()) + return 0 if result.success else 1 + + if mode == "tests": + if len(sys.argv) != 3: + tests_result = TestsResult( + success=False, + error="Usage: cses.py tests ", + problem_id="", + combined=CombinedTest(input="", expected=""), + tests=[], + timeout_ms=0, + memory_mb=0, + ) + print(tests_result.model_dump_json()) + return 1 + category = sys.argv[2] + await scraper.stream_tests_for_category_async(category) + return 0 + + if mode == "contests": + if len(sys.argv) != 2: + contest_result = ContestListResult( + success=False, error="Usage: cses.py contests" + ) + print(contest_result.model_dump_json()) + return 1 + contest_result = await scraper.scrape_contest_list() + print(contest_result.model_dump_json()) + return 0 if contest_result.success else 1 + + result = MetadataResult( + success=False, + error=f"Unknown mode: {mode}. Use 'metadata ', 'tests ', or 'contests'", + url="", + ) + print(result.model_dump_json()) + return 1 + + +def main() -> None: + sys.exit(asyncio.run(main_async())) + + if __name__ == "__main__": - CSESScraper().run_cli() + main() diff --git a/tests/conftest.py b/tests/conftest.py index aaefec8..63e6108 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -232,35 +232,33 @@ def run_scraper_offline(fixture_text): case _: raise AssertionError(f"Unknown scraper: {scraper_name}") - scraper_classes = { - "cses": "CSESScraper", - "atcoder": "AtcoderScraper", - "codeforces": "CodeforcesScraper", - "codechef": "CodeChefScraper", - } - def _run(scraper_name: str, mode: str, *args: str): mod_path = ROOT / "scrapers" / f"{scraper_name}.py" ns = _load_scraper_module(mod_path, scraper_name) offline_fetches = _make_offline_fetches(scraper_name) if scraper_name == "codeforces": - fetchers.Fetcher.get = offline_fetches["Fetcher.get"] + fetchers.Fetcher.get = offline_fetches["Fetcher.get"] # type: ignore[assignment] requests.get = offline_fetches["requests.get"] elif scraper_name == "atcoder": ns._fetch = offline_fetches["_fetch"] ns._get_async = offline_fetches["_get_async"] elif scraper_name == "cses": - httpx.AsyncClient.get = offline_fetches["__offline_fetch_text"] + httpx.AsyncClient.get = offline_fetches["__offline_fetch_text"] # type: ignore[assignment] elif scraper_name == "codechef": - httpx.AsyncClient.get = offline_fetches["__offline_get_async"] - fetchers.Fetcher.get = offline_fetches["Fetcher.get"] + httpx.AsyncClient.get = offline_fetches["__offline_get_async"] # type: ignore[assignment] + fetchers.Fetcher.get = offline_fetches["Fetcher.get"] # type: ignore[assignment] - scraper_class = getattr(ns, scraper_classes[scraper_name]) - scraper = scraper_class() + main_async = getattr(ns, "main_async") + assert callable(main_async), f"main_async not found in {scraper_name}" argv = [str(mod_path), mode, *args] - rc, out = _capture_stdout(scraper._run_cli_async(argv)) + old_argv = sys.argv + sys.argv = argv + try: + rc, out = _capture_stdout(main_async()) + finally: + sys.argv = old_argv json_lines: list[Any] = [] for line in (_line for _line in out.splitlines() if _line.strip()):