diff --git a/lua/cp/stress.lua b/lua/cp/stress.lua new file mode 100644 index 0000000..3e51881 --- /dev/null +++ b/lua/cp/stress.lua @@ -0,0 +1,235 @@ +local M = {} + +local logger = require('cp.log') +local state = require('cp.state') +local utils = require('cp.utils') + +local GENERATOR_PATTERNS = { + 'gen.py', + 'gen.cc', + 'gen.cpp', + 'generator.py', + 'generator.cc', + 'generator.cpp', +} + +local BRUTE_PATTERNS = { + 'brute.py', + 'brute.cc', + 'brute.cpp', + 'slow.py', + 'slow.cc', + 'slow.cpp', +} + +local function find_file(patterns) + for _, pattern in ipairs(patterns) do + if vim.fn.filereadable(pattern) == 1 then + return pattern + end + end + return nil +end + +local function compile_cpp(source, output) + local result = vim.system({ 'sh', '-c', 'g++ -O2 -o ' .. output .. ' ' .. source }):wait() + if result.code ~= 0 then + logger.log( + ('Failed to compile %s: %s'):format(source, result.stderr or ''), + vim.log.levels.ERROR + ) + return false + end + return true +end + +local function build_run_cmd(file) + local ext = file:match('%.([^%.]+)$') + if ext == 'cc' or ext == 'cpp' then + local base = file:gsub('%.[^%.]+$', '') + local bin = base .. '_bin' + if not compile_cpp(file, bin) then + return nil + end + return './' .. bin + elseif ext == 'py' then + return 'python3 ' .. file + end + return './' .. file +end + +function M.toggle(generator_cmd, brute_cmd) + if state.get_active_panel() == 'stress' then + if state.stress_buf and vim.api.nvim_buf_is_valid(state.stress_buf) then + local job = vim.b[state.stress_buf].terminal_job_id + if job then + vim.fn.jobstop(job) + end + end + if state.saved_stress_session then + vim.cmd.source(state.saved_stress_session) + vim.fn.delete(state.saved_stress_session) + state.saved_stress_session = nil + end + state.set_active_panel(nil) + return + end + + if state.get_active_panel() then + logger.log('Another panel is already active.', vim.log.levels.WARN) + return + end + + local gen_file = generator_cmd + local brute_file = brute_cmd + + if not gen_file then + gen_file = find_file(GENERATOR_PATTERNS) + end + if not brute_file then + brute_file = find_file(BRUTE_PATTERNS) + end + + if not gen_file then + logger.log( + 'No generator found. Pass generator as first arg or add gen.{py,cc,cpp}.', + vim.log.levels.ERROR + ) + return + end + if not brute_file then + logger.log( + 'No brute solution found. Pass brute as second arg or add brute.{py,cc,cpp}.', + vim.log.levels.ERROR + ) + return + end + + local gen_cmd = build_run_cmd(gen_file) + if not gen_cmd then + return + end + + local brute_run_cmd = build_run_cmd(brute_file) + if not brute_run_cmd then + return + end + + state.saved_stress_session = vim.fn.tempname() + -- selene: allow(mixed_table) + vim.cmd.mksession({ state.saved_stress_session, bang = true }) + vim.cmd.only({ mods = { silent = true } }) + + local execute = require('cp.runner.execute') + + local function restore_session() + if state.saved_stress_session then + vim.cmd.source(state.saved_stress_session) + vim.fn.delete(state.saved_stress_session) + state.saved_stress_session = nil + end + end + + execute.compile_problem(false, function(compile_result) + if not compile_result.success then + local run = require('cp.runner.run') + run.handle_compilation_failure(compile_result.output) + restore_session() + return + end + + local binary = state.get_binary_file() + if not binary or binary == '' then + logger.log('No binary produced.', vim.log.levels.ERROR) + restore_session() + return + end + + local script = vim.fn.fnamemodify(utils.get_plugin_path() .. '/scripts/stress.py', ':p') + + local cmdline + if utils.is_nix_build() then + cmdline = table.concat({ + vim.fn.shellescape(utils.get_nix_python()), + vim.fn.shellescape(script), + vim.fn.shellescape(gen_cmd), + vim.fn.shellescape(brute_run_cmd), + vim.fn.shellescape(binary), + }, ' ') + else + cmdline = table.concat({ + 'uv', + 'run', + vim.fn.shellescape(script), + vim.fn.shellescape(gen_cmd), + vim.fn.shellescape(brute_run_cmd), + vim.fn.shellescape(binary), + }, ' ') + end + + vim.cmd.terminal(cmdline) + local term_buf = vim.api.nvim_get_current_buf() + local term_win = vim.api.nvim_get_current_win() + + local cleaned = false + local function cleanup() + if cleaned then + return + end + cleaned = true + if term_buf and vim.api.nvim_buf_is_valid(term_buf) then + local job = vim.b[term_buf] and vim.b[term_buf].terminal_job_id or nil + if job then + pcall(vim.fn.jobstop, job) + end + end + restore_session() + state.stress_buf = nil + state.stress_win = nil + state.set_active_panel(nil) + end + + vim.api.nvim_create_autocmd({ 'BufWipeout', 'BufUnload' }, { + buffer = term_buf, + callback = cleanup, + }) + + vim.api.nvim_create_autocmd('WinClosed', { + callback = function() + if cleaned then + return + end + local any = false + for _, win in ipairs(vim.api.nvim_list_wins()) do + if vim.api.nvim_win_is_valid(win) and vim.api.nvim_win_get_buf(win) == term_buf then + any = true + break + end + end + if not any then + cleanup() + end + end, + }) + + vim.api.nvim_create_autocmd('TermClose', { + buffer = term_buf, + callback = function() + vim.b[term_buf].cp_stress_exited = true + end, + }) + + vim.keymap.set('t', '', function() + cleanup() + end, { buffer = term_buf, silent = true }) + vim.keymap.set('n', '', function() + cleanup() + end, { buffer = term_buf, silent = true }) + + state.stress_buf = term_buf + state.stress_win = term_win + state.set_active_panel('stress') + end) +end + +return M diff --git a/scripts/stress.py b/scripts/stress.py new file mode 100755 index 0000000..429ca26 --- /dev/null +++ b/scripts/stress.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +import subprocess +import sys + + +def main() -> None: + argv = sys.argv[1:] + max_iterations = 1000 + timeout = 10 + + positional: list[str] = [] + i = 0 + while i < len(argv): + if argv[i] == "--max-iterations" and i + 1 < len(argv): + max_iterations = int(argv[i + 1]) + i += 2 + elif argv[i] == "--timeout" and i + 1 < len(argv): + timeout = int(argv[i + 1]) + i += 2 + else: + positional.append(argv[i]) + i += 1 + + if len(positional) != 3: + print( + "Usage: stress.py " + "[--max-iterations N] [--timeout S]", + file=sys.stderr, + ) + sys.exit(1) + + generator, brute, candidate = positional + + for iteration in range(1, max_iterations + 1): + try: + gen_result = subprocess.run( + generator, + capture_output=True, + text=True, + shell=True, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + print(f"[stress] generator timed out on iteration {iteration}", file=sys.stderr) + sys.exit(1) + + if gen_result.returncode != 0: + print( + f"[stress] generator failed on iteration {iteration} " + f"(exit code {gen_result.returncode})", + file=sys.stderr, + ) + if gen_result.stderr: + print(gen_result.stderr, file=sys.stderr, end="") + sys.exit(1) + + test_input = gen_result.stdout + + try: + brute_result = subprocess.run( + brute, + input=test_input, + capture_output=True, + text=True, + shell=True, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + print(f"[stress] brute timed out on iteration {iteration}", file=sys.stderr) + print(f"\n--- input ---\n{test_input}", end="") + sys.exit(1) + + try: + cand_result = subprocess.run( + candidate, + input=test_input, + capture_output=True, + text=True, + shell=True, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + print(f"[stress] candidate timed out on iteration {iteration}", file=sys.stderr) + print(f"\n--- input ---\n{test_input}", end="") + sys.exit(1) + + brute_out = brute_result.stdout.strip() + cand_out = cand_result.stdout.strip() + + if brute_out != cand_out: + print(f"[stress] mismatch on iteration {iteration}", file=sys.stderr) + print(f"\n--- input ---\n{test_input}", end="") + print(f"\n--- expected (brute) ---\n{brute_out}") + print(f"\n--- actual (candidate) ---\n{cand_out}") + sys.exit(1) + + print(f"[stress] iteration {iteration} OK", file=sys.stderr) + + print( + f"[stress] all {max_iterations} iterations passed", + file=sys.stderr, + ) + sys.exit(0) + + +if __name__ == "__main__": + main()