feat: local state over vim.g
This commit is contained in:
parent
e81ea9ef4d
commit
67d2a8054c
4 changed files with 224 additions and 42 deletions
240
lua/cp/init.lua
240
lua/cp/init.lua
|
|
@ -14,7 +14,6 @@ if not vim.fn.has("nvim-0.10.0") then
|
|||
return {}
|
||||
end
|
||||
|
||||
vim.g.cp = vim.g.cp or {}
|
||||
local user_config = {}
|
||||
local config = config_module.setup(user_config)
|
||||
logger.set_config(config)
|
||||
|
|
@ -28,10 +27,41 @@ local state = {
|
|||
saved_layout = nil,
|
||||
saved_session = nil,
|
||||
temp_output = nil,
|
||||
test_cases = nil,
|
||||
test_states = {},
|
||||
}
|
||||
|
||||
local platforms = { "atcoder", "codeforces", "cses" }
|
||||
local actions = { "run", "debug", "diff", "next", "prev" }
|
||||
local actions = { "run", "debug", "test", "next", "prev" }
|
||||
|
||||
local function get_current_problem_key()
|
||||
if not state.platform or not state.contest_id then
|
||||
return nil
|
||||
end
|
||||
if state.platform == "cses" then
|
||||
return state.contest_id
|
||||
else
|
||||
return state.contest_id .. "_" .. (state.problem_id or "")
|
||||
end
|
||||
end
|
||||
|
||||
local function get_test_states()
|
||||
local problem_key = get_current_problem_key()
|
||||
if not problem_key then
|
||||
return {}
|
||||
end
|
||||
|
||||
if not state.test_states[problem_key] then
|
||||
state.test_states[problem_key] = {}
|
||||
if state.test_cases then
|
||||
for i = 1, #state.test_cases do
|
||||
state.test_states[problem_key][i] = true
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return state.test_states[problem_key]
|
||||
end
|
||||
|
||||
local function set_platform(platform)
|
||||
if not vim.tbl_contains(platforms, platform) then
|
||||
|
|
@ -79,6 +109,11 @@ local function setup_problem(contest_id, problem_id)
|
|||
state.contest_id = contest_id
|
||||
state.problem_id = problem_id
|
||||
|
||||
local cached_test_cases = cache.get_test_cases(state.platform, contest_id, problem_id)
|
||||
if cached_test_cases then
|
||||
state.test_cases = cached_test_cases
|
||||
end
|
||||
|
||||
local ctx = problem.create_context(state.platform, contest_id, problem_id, config)
|
||||
|
||||
local scrape_result = scrape.scrape_problem(ctx)
|
||||
|
|
@ -86,9 +121,15 @@ local function setup_problem(contest_id, problem_id)
|
|||
if not scrape_result.success then
|
||||
logger.log("scraping failed: " .. (scrape_result.error or "unknown error"), vim.log.levels.WARN)
|
||||
logger.log("you can manually add test cases to io/ directory", vim.log.levels.INFO)
|
||||
state.test_cases = nil
|
||||
else
|
||||
local test_count = scrape_result.test_count or 0
|
||||
logger.log(("scraped %d test case(s) for %s"):format(test_count, scrape_result.problem_id))
|
||||
state.test_cases = scrape_result.test_cases
|
||||
|
||||
if scrape_result.test_cases then
|
||||
cache.set_test_cases(state.platform, contest_id, problem_id, scrape_result.test_cases)
|
||||
end
|
||||
end
|
||||
|
||||
vim.cmd.e(ctx.source_file)
|
||||
|
|
@ -186,36 +227,86 @@ local function debug_problem()
|
|||
end)
|
||||
end
|
||||
|
||||
local function diff_problem()
|
||||
if state.diff_mode then
|
||||
local tile_fn = config.tile or window.default_tile
|
||||
window.restore_layout(state.saved_layout, tile_fn)
|
||||
state.diff_mode = false
|
||||
state.saved_layout = nil
|
||||
logger.log("exited diff mode")
|
||||
else
|
||||
local problem_id = get_current_problem()
|
||||
if not problem_id then
|
||||
return
|
||||
end
|
||||
|
||||
local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config)
|
||||
|
||||
if vim.fn.filereadable(ctx.expected_file) == 0 then
|
||||
logger.log(("No expected output file found: %s"):format(ctx.expected_file), vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
state.saved_layout = window.save_layout()
|
||||
|
||||
local result = vim.system({ "awk", "/^\\[[^]]*\\]:/ {exit} {print}", ctx.output_file }, { text = true }):wait()
|
||||
local actual_output = result.stdout
|
||||
|
||||
window.setup_diff_layout(actual_output, ctx.expected_file, ctx.input_file)
|
||||
|
||||
state.diff_mode = true
|
||||
logger.log("entered diff mode")
|
||||
local function test_problem()
|
||||
local problem_id = get_current_problem()
|
||||
if not problem_id then
|
||||
return
|
||||
end
|
||||
|
||||
if not state.test_cases then
|
||||
logger.log("No test case data available. Try scraping the problem first.", vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config)
|
||||
local contest_config = config.contests[state.platform]
|
||||
|
||||
local test_results = execute.run_individual_tests(ctx, state.test_cases, contest_config, false)
|
||||
|
||||
if test_results.compile_error then
|
||||
logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
local buf_name = ("cp-test://%s"):format(problem_id)
|
||||
local existing_buf = vim.fn.bufnr(buf_name)
|
||||
local buf
|
||||
|
||||
if existing_buf ~= -1 then
|
||||
buf = existing_buf
|
||||
else
|
||||
buf = vim.api.nvim_create_buf(false, true)
|
||||
vim.api.nvim_buf_set_name(buf, buf_name)
|
||||
end
|
||||
|
||||
local lines = {}
|
||||
local passed = 0
|
||||
local total = #test_results.results
|
||||
|
||||
local test_states = get_test_states()
|
||||
|
||||
for _, result in ipairs(test_results.results) do
|
||||
local status_icon = result.status == "PASS" and "✓" or "✗"
|
||||
local enabled_icon = test_states[result.id] and "[x]" or "[ ]"
|
||||
local time_str = ("%.1fms"):format(result.time_ms)
|
||||
|
||||
table.insert(lines, ("%s Test %d %s %s (%s) {{{"):format(enabled_icon, result.id, status_icon, result.status, time_str))
|
||||
table.insert(lines, " Input:")
|
||||
for _, line in ipairs(vim.split(result.input, "\n")) do
|
||||
table.insert(lines, " " .. line)
|
||||
end
|
||||
|
||||
if result.status == "PASS" then
|
||||
table.insert(lines, " Output:")
|
||||
for _, line in ipairs(vim.split(result.actual, "\n")) do
|
||||
table.insert(lines, " " .. line)
|
||||
end
|
||||
passed = passed + 1
|
||||
else
|
||||
table.insert(lines, " Expected:")
|
||||
for _, line in ipairs(vim.split(result.expected, "\n")) do
|
||||
table.insert(lines, " " .. line)
|
||||
end
|
||||
table.insert(lines, " Got:")
|
||||
for _, line in ipairs(vim.split(result.actual, "\n")) do
|
||||
table.insert(lines, " " .. line)
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(lines, "}}}")
|
||||
table.insert(lines, "")
|
||||
end
|
||||
|
||||
table.insert(lines, ("Summary: %d/%d passed"):format(passed, total))
|
||||
|
||||
vim.api.nvim_buf_set_lines(buf, 0, -1, false, lines)
|
||||
vim.bo[buf].filetype = "cp-test"
|
||||
vim.bo[buf].modifiable = false
|
||||
|
||||
vim.cmd.split()
|
||||
vim.api.nvim_set_current_buf(buf)
|
||||
|
||||
logger.log(("Test results: %d/%d passed"):format(passed, total))
|
||||
end
|
||||
|
||||
---@param delta number 1 for next, -1 for prev
|
||||
|
|
@ -323,8 +414,8 @@ function M.handle_command(opts)
|
|||
run_problem()
|
||||
elseif cmd.action == "debug" then
|
||||
debug_problem()
|
||||
elseif cmd.action == "diff" then
|
||||
diff_problem()
|
||||
elseif cmd.action == "test" then
|
||||
test_problem()
|
||||
elseif cmd.action == "next" then
|
||||
navigate_problem(1)
|
||||
elseif cmd.action == "prev" then
|
||||
|
|
@ -400,6 +491,81 @@ function M.handle_command(opts)
|
|||
end
|
||||
end
|
||||
|
||||
function M.toggle_test(test_id)
|
||||
local test_states = get_test_states()
|
||||
test_states[test_id] = not test_states[test_id]
|
||||
|
||||
local problem_key = get_current_problem_key()
|
||||
if problem_key then
|
||||
state.test_states[problem_key] = test_states
|
||||
end
|
||||
|
||||
test_problem()
|
||||
end
|
||||
|
||||
function M.run_single_test(test_id)
|
||||
if not state.test_cases or not state.test_cases[test_id] then
|
||||
logger.log("Test case not found", vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config)
|
||||
local contest_config = config.contests[state.platform]
|
||||
|
||||
local single_test = { state.test_cases[test_id] }
|
||||
local test_results = execute.run_individual_tests(ctx, single_test, contest_config, false)
|
||||
|
||||
if test_results.compile_error then
|
||||
logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
local result = test_results.results[1]
|
||||
if result then
|
||||
logger.log(("Test %d: %s (%.1fms)"):format(test_id, result.status, result.time_ms))
|
||||
end
|
||||
end
|
||||
|
||||
function M.run_all_enabled_tests()
|
||||
if not state.test_cases then
|
||||
logger.log("No test cases available", vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
local test_states = get_test_states()
|
||||
local enabled_tests = {}
|
||||
|
||||
for i, test_case in ipairs(state.test_cases) do
|
||||
if test_states[i] then
|
||||
table.insert(enabled_tests, test_case)
|
||||
end
|
||||
end
|
||||
|
||||
if #enabled_tests == 0 then
|
||||
logger.log("No tests enabled", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
local ctx = problem.create_context(state.platform, state.contest_id, state.problem_id, config)
|
||||
local contest_config = config.contests[state.platform]
|
||||
|
||||
local test_results = execute.run_individual_tests(ctx, enabled_tests, contest_config, false)
|
||||
|
||||
if test_results.compile_error then
|
||||
logger.log("Compilation failed: " .. test_results.compile_error, vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
local passed = 0
|
||||
for _, result in ipairs(test_results.results) do
|
||||
if result.status == "PASS" then
|
||||
passed = passed + 1
|
||||
end
|
||||
end
|
||||
|
||||
logger.log(("Enabled tests: %d/%d passed"):format(passed, #enabled_tests))
|
||||
end
|
||||
|
||||
function M.setup(opts)
|
||||
opts = opts or {}
|
||||
user_config = opts
|
||||
|
|
@ -411,6 +577,14 @@ function M.setup(opts)
|
|||
end
|
||||
end
|
||||
|
||||
function M.get_current_context()
|
||||
return {
|
||||
platform = state.platform,
|
||||
contest_id = state.contest_id,
|
||||
problem_id = state.problem_id,
|
||||
}
|
||||
end
|
||||
|
||||
function M.is_initialized()
|
||||
return true
|
||||
end
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue