From e7d56e3bbe3427d26f3e4772afa2816874d39675 Mon Sep 17 00:00:00 2001 From: Barrett Ruth <62671086+barrettruth@users.noreply.github.com> Date: Thu, 5 Mar 2026 11:14:31 -0500 Subject: [PATCH] feat(highlight): wire highlights.context into treesitter pipeline (#151) ## Problem `highlights.context.enabled` and `highlights.context.lines` were defined, validated, and range-checked but never read during highlighting. Hunks inside incomplete constructs (e.g., a table literal or function body whose opening is beyond the hunk's own context lines) parsed incorrectly because treesitter had no surrounding code. ## Solution `compute_hunk_context` in `init.lua` reads the working tree file using the hunk's `@@ +start,count @@` line numbers to collect up to `lines` (default 25) surrounding code lines in each direction. Files are read once via `io.open` and cached across hunks in the same file. `highlight_treesitter` in `highlight.lua` accepts an optional context parameter that prepends/appends context lines to the parse string and offsets capture rows by the prefix count, so extmarks only land on actual hunk lines. Wired through `highlight_hunk` for the two code-language treesitter calls (not headers, not `highlight_text`, not vim syntax). Closes #148. --- doc/diffs.nvim.txt | 32 ++-- lua/diffs/highlight.lua | 68 ++++++- lua/diffs/init.lua | 67 +++++++ lua/diffs/parser.lua | 2 + spec/context_spec.lua | 382 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 534 insertions(+), 17 deletions(-) create mode 100644 spec/context_spec.lua diff --git a/doc/diffs.nvim.txt b/doc/diffs.nvim.txt index 6798173..cfcd1b6 100644 --- a/doc/diffs.nvim.txt +++ b/doc/diffs.nvim.txt @@ -225,16 +225,20 @@ Configuration is done via `vim.g.diffs`. Set this before the plugin loads: *diffs.ContextConfig* Context config fields: ~ {enabled} (boolean, default: true) - Read lines from disk before and after each hunk - to provide surrounding syntax context. 
Improves - accuracy at hunk boundaries where incomplete - constructs (e.g., a function definition with no - body) would otherwise confuse the parser. + Read surrounding code from the working tree + file and feed it into the treesitter string + parser. Uses the hunk's `@@ +start,count @@` + line numbers to read lines before and after + the hunk from disk. Improves syntax accuracy + when the hunk is inside an incomplete construct + (e.g., a table literal or function body whose + opening is not visible in the hunk's own + context lines). {lines} (integer, default: 25) - Number of context lines to read in each - direction. Lines are read with early exit — - cost scales with this value, not file size. + Max context lines to read in each direction. + Files are read once per parse and cached across + hunks in the same file. *diffs.PrioritiesConfig* Priorities config fields: ~ @@ -695,10 +699,14 @@ KNOWN LIMITATIONS *diffs-limitations* Incomplete Syntax Context ~ *diffs-syntax-context* -Treesitter parses each diff hunk in isolation. Context lines within the hunk -(lines with a ` ` prefix) provide syntactic context for the parser. In rare -cases, hunks that start or end mid-expression may produce imperfect highlights -due to treesitter error recovery. +Treesitter parses each diff hunk in isolation. When `highlights.context` is +enabled (the default), surrounding code is read from the working tree file +and fed into the parser to improve accuracy at hunk boundaries. This helps +when a hunk is inside a table, function body, or loop whose opening is +beyond the hunk's own context lines. Requires `repo_root` and +`file_new_start` to be available on the hunk (true for standard unified +diffs). In rare cases, hunks that start or end mid-expression may still +produce imperfect highlights due to treesitter error recovery. 
Syntax Highlighting Flash ~ *diffs-flash* diff --git a/lua/diffs/highlight.lua b/lua/diffs/highlight.lua index 3190641..493bcfc 100644 --- a/lua/diffs/highlight.lua +++ b/lua/diffs/highlight.lua @@ -67,6 +67,10 @@ end ---@field defer_vim_syntax? boolean ---@field syntax_only? boolean +---@class diffs.TSContext +---@field before string[]? +---@field after string[]? + ---@param bufnr integer ---@param ns integer ---@param code_lines string[] @@ -76,6 +80,7 @@ end ---@param covered_lines? table ---@param priorities diffs.PrioritiesConfig ---@param force_high_priority? boolean +---@param context? diffs.TSContext ---@return integer local function highlight_treesitter( bufnr, @@ -86,9 +91,34 @@ local function highlight_treesitter( col_offset, covered_lines, priorities, - force_high_priority + force_high_priority, + context ) - local code = table.concat(code_lines, '\n') + local prefix_count = 0 + local parse_lines = code_lines + if context then + local before = context.before + local after = context.after + if (before and #before > 0) or (after and #after > 0) then + parse_lines = {} + if before then + prefix_count = #before + for _, l in ipairs(before) do + parse_lines[#parse_lines + 1] = l + end + end + for _, l in ipairs(code_lines) do + parse_lines[#parse_lines + 1] = l + end + if after then + for _, l in ipairs(after) do + parse_lines[#parse_lines + 1] = l + end + end + end + end + + local code = table.concat(parse_lines, '\n') if code == '' then return 0 end @@ -118,6 +148,8 @@ local function highlight_treesitter( if capture ~= 'spell' and capture ~= 'nospell' then local capture_name = '@' .. capture .. '.' .. 
tree_lang local sr, sc, er, ec = node:range() + sr = sr - prefix_count + er = er - prefix_count local buf_sr = line_map[sr] if buf_sr then @@ -329,10 +361,36 @@ function M.highlight_hunk(bufnr, ns, hunk, opts) end end - extmark_count = - highlight_treesitter(bufnr, ns, new_code, hunk.lang, new_map, pw + qw, covered_lines, p) + local ts_context = nil + if opts.highlights.context.enabled and (hunk.context_before or hunk.context_after) then + ts_context = { before = hunk.context_before, after = hunk.context_after } + end + + extmark_count = highlight_treesitter( + bufnr, + ns, + new_code, + hunk.lang, + new_map, + pw + qw, + covered_lines, + p, + nil, + ts_context + ) extmark_count = extmark_count - + highlight_treesitter(bufnr, ns, old_code, hunk.lang, old_map, pw + qw, covered_lines, p) + + highlight_treesitter( + bufnr, + ns, + old_code, + hunk.lang, + old_map, + pw + qw, + covered_lines, + p, + nil, + ts_context + ) if hunk.header_context and hunk.header_context_col then local header_extmarks = highlight_text( diff --git a/lua/diffs/init.lua b/lua/diffs/init.lua index 0bfc373..77f521a 100644 --- a/lua/diffs/init.lua +++ b/lua/diffs/init.lua @@ -297,6 +297,69 @@ local function carry_forward_highlighted(old_entry, new_hunks) return highlighted end +---@param path string +---@return string[]? 
+local function read_file_lines(path) + local f = io.open(path, 'r') + if not f then + return nil + end + local lines = {} + for line in f:lines() do + lines[#lines + 1] = line + end + f:close() + return lines +end + +---@param hunks diffs.Hunk[] +---@param max_lines integer +local function compute_hunk_context(hunks, max_lines) + ---@type table + local file_cache = {} + + for _, hunk in ipairs(hunks) do + if not hunk.repo_root or not hunk.filename or not hunk.file_new_start then + goto continue + end + + local path = vim.fs.joinpath(hunk.repo_root, hunk.filename) + local file_lines = file_cache[path] + if file_lines == nil then + file_lines = read_file_lines(path) or false + file_cache[path] = file_lines + end + if not file_lines then + goto continue + end + + local new_start = hunk.file_new_start + local new_count = hunk.file_new_count or 1 -- unified diff: an omitted count means one line + local total = #file_lines + + local before_start = math.max(1, new_start - max_lines) + if before_start < new_start then + local before = {} + for i = before_start, new_start - 1 do + before[#before + 1] = file_lines[i] + end + hunk.context_before = before + end + + local after_start = new_start + new_count + local after_end = math.min(total, after_start + max_lines - 1) + if after_start <= total then + local after = {} + for i = after_start, after_end do + after[#after + 1] = file_lines[i] + end + hunk.context_after = after + end + + ::continue:: + end +end + ---@param bufnr integer local function ensure_cache(bufnr) if not vim.api.nvim_buf_is_valid(bufnr) then @@ -321,6 +384,9 @@ local function ensure_cache(bufnr) local lc = vim.api.nvim_buf_line_count(bufnr) local bc = vim.api.nvim_buf_get_offset(bufnr, lc) dbg('parsed %d hunks in buffer %d (tick %d)', #hunks, bufnr, tick) + if config.highlights.context.enabled then + compute_hunk_context(hunks, config.highlights.context.lines) + end local carried = entry and not entry.pending_clear and carry_forward_highlighted(entry, hunks) hunk_cache[bufnr] = { hunks = hunks, @@ -941,6 
+1007,7 @@ M._test = { hunks_eq = hunks_eq, process_pending_clear = process_pending_clear, ft_retry_pending = ft_retry_pending, + compute_hunk_context = compute_hunk_context, } return M diff --git a/lua/diffs/parser.lua b/lua/diffs/parser.lua index 3df3d4c..ccbd46a 100644 --- a/lua/diffs/parser.lua +++ b/lua/diffs/parser.lua @@ -15,6 +15,8 @@ ---@field prefix_width integer ---@field quote_width integer ---@field repo_root string? +---@field context_before string[]? +---@field context_after string[]? local M = {} diff --git a/spec/context_spec.lua b/spec/context_spec.lua new file mode 100644 index 0000000..237b7f3 --- /dev/null +++ b/spec/context_spec.lua @@ -0,0 +1,382 @@ +require('spec.helpers') +local diffs = require('diffs') +local highlight = require('diffs.highlight') +local compute_hunk_context = diffs._test.compute_hunk_context + +describe('context', function() + describe('compute_hunk_context', function() + local tmpdir + + before_each(function() + tmpdir = vim.fn.tempname() + vim.fn.mkdir(tmpdir, 'p') + end) + + after_each(function() + vim.fn.delete(tmpdir, 'rf') + end) + + local function write_file(filename, lines) + local path = vim.fs.joinpath(tmpdir, filename) + local dir = vim.fn.fnamemodify(path, ':h') + if vim.fn.isdirectory(dir) == 0 then + vim.fn.mkdir(dir, 'p') + end + local f = io.open(path, 'w') + f:write(table.concat(lines, '\n') .. 
'\n') + f:close() + end + + local function make_hunk(filename, opts) + return { + filename = filename, + ft = 'lua', + lang = 'lua', + start_line = opts.start_line or 1, + lines = opts.lines, + prefix_width = opts.prefix_width or 1, + quote_width = 0, + repo_root = tmpdir, + file_new_start = opts.file_new_start, + file_new_count = opts.file_new_count, + } + end + + it('reads context_before from file lines preceding the hunk', function() + write_file('a.lua', { + 'local M = {}', + 'function M.foo()', + ' local x = 1', + ' local y = 2', + 'end', + 'return M', + }) + + local hunks = { + make_hunk('a.lua', { + file_new_start = 3, + file_new_count = 3, + lines = { ' local x = 1', '+local new = true', ' local y = 2' }, + }), + } + compute_hunk_context(hunks, 25) + + assert.same({ 'local M = {}', 'function M.foo()' }, hunks[1].context_before) + end) + + it('reads context_after from file lines following the hunk', function() + write_file('a.lua', { + 'local M = {}', + 'function M.foo()', + ' local x = 1', + 'end', + 'return M', + }) + + local hunks = { + make_hunk('a.lua', { + file_new_start = 2, + file_new_count = 2, + lines = { ' function M.foo()', '+ local x = 1' }, + }), + } + compute_hunk_context(hunks, 25) + + assert.same({ 'end', 'return M' }, hunks[1].context_after) + end) + + it('caps context_before to max_lines', function() + write_file('a.lua', { + 'line1', + 'line2', + 'line3', + 'line4', + 'line5', + 'target', + }) + + local hunks = { + make_hunk('a.lua', { + file_new_start = 6, + file_new_count = 1, + lines = { '+target' }, + }), + } + compute_hunk_context(hunks, 2) + + assert.same({ 'line4', 'line5' }, hunks[1].context_before) + end) + + it('caps context_after to max_lines', function() + write_file('a.lua', { + 'target', + 'after1', + 'after2', + 'after3', + 'after4', + }) + + local hunks = { + make_hunk('a.lua', { + file_new_start = 1, + file_new_count = 1, + lines = { '+target' }, + }), + } + compute_hunk_context(hunks, 2) + + assert.same({ 'after1', 
'after2' }, hunks[1].context_after) + end) + + it('skips hunks without file_new_start', function() + write_file('a.lua', { 'line1', 'line2' }) + + local hunks = { + make_hunk('a.lua', { + file_new_start = nil, + file_new_count = nil, + lines = { '+something' }, + }), + } + compute_hunk_context(hunks, 25) + + assert.is_nil(hunks[1].context_before) + assert.is_nil(hunks[1].context_after) + end) + + it('skips hunks without repo_root', function() + local hunks = { + { + filename = 'a.lua', + ft = 'lua', + lang = 'lua', + start_line = 1, + lines = { '+x' }, + prefix_width = 1, + quote_width = 0, + repo_root = nil, + file_new_start = 1, + file_new_count = 1, + }, + } + compute_hunk_context(hunks, 25) + + assert.is_nil(hunks[1].context_before) + assert.is_nil(hunks[1].context_after) + end) + + it('skips when file does not exist on disk', function() + local hunks = { + make_hunk('nonexistent.lua', { + file_new_start = 1, + file_new_count = 1, + lines = { '+x' }, + }), + } + compute_hunk_context(hunks, 25) + + assert.is_nil(hunks[1].context_before) + assert.is_nil(hunks[1].context_after) + end) + + it('returns nil context_before for hunk at line 1', function() + write_file('a.lua', { 'first', 'second' }) + + local hunks = { + make_hunk('a.lua', { + file_new_start = 1, + file_new_count = 1, + lines = { '+first' }, + }), + } + compute_hunk_context(hunks, 25) + + assert.is_nil(hunks[1].context_before) + end) + + it('returns nil context_after for hunk at end of file', function() + write_file('a.lua', { 'first', 'last' }) + + local hunks = { + make_hunk('a.lua', { + file_new_start = 1, + file_new_count = 2, + lines = { ' first', '+last' }, + }), + } + compute_hunk_context(hunks, 25) + + assert.is_nil(hunks[1].context_after) + end) + + it('reads file once for multiple hunks in same file', function() + write_file('a.lua', { + 'local M = {}', + 'function M.foo()', + ' return 1', + 'end', + 'function M.bar()', + ' return 2', + 'end', + 'return M', + }) + + local hunks = { + 
make_hunk('a.lua', { + file_new_start = 2, + file_new_count = 3, + lines = { ' function M.foo()', '+ return 1', ' end' }, + }), + make_hunk('a.lua', { + file_new_start = 5, + file_new_count = 3, + lines = { ' function M.bar()', '+ return 2', ' end' }, + }), + } + compute_hunk_context(hunks, 25) + + assert.same({ 'local M = {}' }, hunks[1].context_before) + assert.same({ 'function M.bar()', ' return 2', 'end', 'return M' }, hunks[1].context_after) + assert.same({ + 'local M = {}', + 'function M.foo()', + ' return 1', + 'end', + }, hunks[2].context_before) + assert.same({ 'return M' }, hunks[2].context_after) + end) + end) + + describe('highlight_treesitter with context', function() + local ns + + before_each(function() + ns = vim.api.nvim_create_namespace('diffs_context_test') + local normal = vim.api.nvim_get_hl(0, { name = 'Normal' }) + vim.api.nvim_set_hl(0, 'DiffsClear', { fg = normal.fg or 0xc0c0c0 }) + end) + + local function create_buffer(lines) + local bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) + return bufnr + end + + local function delete_buffer(bufnr) + if vim.api.nvim_buf_is_valid(bufnr) then + vim.api.nvim_buf_delete(bufnr, { force = true }) + end + end + + local function get_extmarks(bufnr) + return vim.api.nvim_buf_get_extmarks(bufnr, ns, 0, -1, { details = true }) + end + + local function default_opts(overrides) + local opts = { + hide_prefix = false, + highlights = { + background = false, + gutter = false, + context = { enabled = true, lines = 25 }, + treesitter = { enabled = true, max_lines = 500 }, + vim = { enabled = false, max_lines = 200 }, + intra = { enabled = false, algorithm = 'default', max_lines = 500 }, + priorities = { clear = 198, syntax = 199, line_bg = 200, char_bg = 201 }, + }, + } + if overrides then + if overrides.highlights then + opts.highlights = vim.tbl_deep_extend('force', opts.highlights, overrides.highlights) + end + end + return opts + end + + it('applies extmarks 
only to hunk lines, not context lines', function() + local bufnr = create_buffer({ + '@@ -1,2 +1,3 @@', + ' local x = 1', + ' local y = 2', + '+local z = 3', + }) + + local hunk = { + filename = 'test.lua', + lang = 'lua', + start_line = 2, + lines = { ' local x = 1', ' local y = 2', '+local z = 3' }, + prefix_width = 1, + quote_width = 0, + context_before = { 'local function foo()' }, + context_after = { 'end' }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + for _, mark in ipairs(extmarks) do + local row = mark[2] + assert.is_true(row >= 1 and row <= 3, 'extmark row ' .. row .. ' outside hunk range') + end + assert.is_true(#extmarks > 0) + delete_buffer(bufnr) + end) + + it('does not pass context when context.enabled = false', function() + local bufnr = create_buffer({ + '@@ -1,1 +1,2 @@', + ' local x = 1', + '+local y = 2', + }) + + local hunk = { + filename = 'test.lua', + lang = 'lua', + start_line = 2, + lines = { ' local x = 1', '+local y = 2' }, + prefix_width = 1, + quote_width = 0, + context_before = { 'local function foo()' }, + context_after = { 'end' }, + } + + local opts_enabled = default_opts({ highlights = { context = { enabled = true } } }) + highlight.highlight_hunk(bufnr, ns, hunk, opts_enabled) + local extmarks_with = get_extmarks(bufnr) + + vim.api.nvim_buf_clear_namespace(bufnr, ns, 0, -1) + + local opts_disabled = default_opts({ highlights = { context = { enabled = false } } }) + highlight.highlight_hunk(bufnr, ns, hunk, opts_disabled) + local extmarks_without = get_extmarks(bufnr) + + assert.is_true(#extmarks_with > 0) + assert.is_true(#extmarks_without > 0) + delete_buffer(bufnr) + end) + + it('skips context fields that are nil', function() + local bufnr = create_buffer({ + '@@ -1,1 +1,2 @@', + ' local x = 1', + '+local y = 2', + }) + + local hunk = { + filename = 'test.lua', + lang = 'lua', + start_line = 2, + lines = { ' local x = 1', '+local y = 2' }, + prefix_width = 1, + 
quote_width = 0, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + assert.is_true(#extmarks > 0) + delete_buffer(bufnr) + end) + end) +end)