feat(highlight): wire highlights.context into treesitter pipeline (#151)

## Problem

`highlights.context.enabled` and `highlights.context.lines` were
defined, validated, and range-checked but never read during
highlighting. Hunks inside incomplete constructs (e.g., a table literal
or function body whose opening is beyond the hunk's own context lines)
parsed incorrectly because treesitter had no surrounding code.

## Solution

`compute_hunk_context` in `init.lua` reads the working tree file using
the hunk's `@@ +start,count @@` line numbers to collect up to `lines`
(default 25) surrounding code lines in each direction. Files are read
once via `io.open` and cached across hunks in the same file.
`highlight_treesitter` in `highlight.lua` accepts an optional context
parameter that prepends/appends context lines to the parse string and
offsets capture rows by the prefix count, so extmarks only land on
actual hunk lines. Wired through `highlight_hunk` for the two
code-language treesitter calls (not headers, not `highlight_text`, not
vim syntax).

Closes #148.
This commit is contained in:
Barrett Ruth 2026-03-05 11:14:31 -05:00 committed by GitHub
parent 29e624d9f0
commit 6e766c83b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 534 additions and 17 deletions

View file

@ -225,16 +225,20 @@ Configuration is done via `vim.g.diffs`. Set this before the plugin loads:
*diffs.ContextConfig* *diffs.ContextConfig*
Context config fields: ~ Context config fields: ~
{enabled} (boolean, default: true) {enabled} (boolean, default: true)
Read lines from disk before and after each hunk Read surrounding code from the working tree
to provide surrounding syntax context. Improves file and feed it into the treesitter string
accuracy at hunk boundaries where incomplete parser. Uses the hunk's `@@ +start,count @@`
constructs (e.g., a function definition with no line numbers to read lines before and after
body) would otherwise confuse the parser. the hunk from disk. Improves syntax accuracy
when the hunk is inside an incomplete construct
(e.g., a table literal or function body whose
opening is not visible in the hunk's own
context lines).
{lines} (integer, default: 25) {lines} (integer, default: 25)
Number of context lines to read in each Max context lines to read in each direction.
direction. Lines are read with early exit — Files are read once per parse and cached across
cost scales with this value, not file size. hunks in the same file.
*diffs.PrioritiesConfig* *diffs.PrioritiesConfig*
Priorities config fields: ~ Priorities config fields: ~
@ -695,10 +699,14 @@ KNOWN LIMITATIONS *diffs-limitations*
Incomplete Syntax Context ~ Incomplete Syntax Context ~
*diffs-syntax-context* *diffs-syntax-context*
Treesitter parses each diff hunk in isolation. Context lines within the hunk Treesitter parses each diff hunk in isolation. When `highlights.context` is
(lines with a ` ` prefix) provide syntactic context for the parser. In rare enabled (the default), surrounding code is read from the working tree file
cases, hunks that start or end mid-expression may produce imperfect highlights and fed into the parser to improve accuracy at hunk boundaries. This helps
due to treesitter error recovery. when a hunk is inside a table, function body, or loop whose opening is
beyond the hunk's own context lines. Requires `repo_root`, `filename`, and
`file_new_start` to be available on the hunk (true for standard unified
diffs), and the file to be readable in the working tree. In rare cases,
hunks that start or end mid-expression may still produce imperfect
highlights due to treesitter error recovery.
Syntax Highlighting Flash ~ Syntax Highlighting Flash ~
*diffs-flash* *diffs-flash*

View file

@ -67,6 +67,10 @@ end
---@field defer_vim_syntax? boolean ---@field defer_vim_syntax? boolean
---@field syntax_only? boolean ---@field syntax_only? boolean
---@class diffs.TSContext
---@field before string[]?
---@field after string[]?
---@param bufnr integer ---@param bufnr integer
---@param ns integer ---@param ns integer
---@param code_lines string[] ---@param code_lines string[]
@ -76,6 +80,7 @@ end
---@param covered_lines? table<integer, true> ---@param covered_lines? table<integer, true>
---@param priorities diffs.PrioritiesConfig ---@param priorities diffs.PrioritiesConfig
---@param force_high_priority? boolean ---@param force_high_priority? boolean
---@param context? diffs.TSContext
---@return integer ---@return integer
local function highlight_treesitter( local function highlight_treesitter(
bufnr, bufnr,
@ -86,9 +91,34 @@ local function highlight_treesitter(
col_offset, col_offset,
covered_lines, covered_lines,
priorities, priorities,
force_high_priority force_high_priority,
context
) )
local code = table.concat(code_lines, '\n') local prefix_count = 0
local parse_lines = code_lines
if context then
local before = context.before
local after = context.after
if (before and #before > 0) or (after and #after > 0) then
parse_lines = {}
if before then
prefix_count = #before
for _, l in ipairs(before) do
parse_lines[#parse_lines + 1] = l
end
end
for _, l in ipairs(code_lines) do
parse_lines[#parse_lines + 1] = l
end
if after then
for _, l in ipairs(after) do
parse_lines[#parse_lines + 1] = l
end
end
end
end
local code = table.concat(parse_lines, '\n')
if code == '' then if code == '' then
return 0 return 0
end end
@ -118,6 +148,8 @@ local function highlight_treesitter(
if capture ~= 'spell' and capture ~= 'nospell' then if capture ~= 'spell' and capture ~= 'nospell' then
local capture_name = '@' .. capture .. '.' .. tree_lang local capture_name = '@' .. capture .. '.' .. tree_lang
local sr, sc, er, ec = node:range() local sr, sc, er, ec = node:range()
sr = sr - prefix_count
er = er - prefix_count
local buf_sr = line_map[sr] local buf_sr = line_map[sr]
if buf_sr then if buf_sr then
@ -329,10 +361,36 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
end end
end end
extmark_count = local ts_context = nil
highlight_treesitter(bufnr, ns, new_code, hunk.lang, new_map, pw + qw, covered_lines, p) if opts.highlights.context.enabled and (hunk.context_before or hunk.context_after) then
ts_context = { before = hunk.context_before, after = hunk.context_after }
end
extmark_count = highlight_treesitter(
bufnr,
ns,
new_code,
hunk.lang,
new_map,
pw + qw,
covered_lines,
p,
nil,
ts_context
)
extmark_count = extmark_count extmark_count = extmark_count
+ highlight_treesitter(bufnr, ns, old_code, hunk.lang, old_map, pw + qw, covered_lines, p) + highlight_treesitter(
bufnr,
ns,
old_code,
hunk.lang,
old_map,
pw + qw,
covered_lines,
p,
nil,
ts_context
)
if hunk.header_context and hunk.header_context_col then if hunk.header_context and hunk.header_context_col then
local header_extmarks = highlight_text( local header_extmarks = highlight_text(

View file

@ -297,6 +297,69 @@ local function carry_forward_highlighted(old_entry, new_hunks)
return highlighted return highlighted
end end
---Read every line of a file into a list.
---
---Lines are pulled one at a time with `f:read('*l')` instead of `f:lines()`.
---The `lines()` iterator raises a Lua error when the underlying read fails
---(for example when `path` names a directory, which `io.open` can open
---successfully on POSIX systems), whereas `read` reports the failure through
---its return values. That lets the handle always be closed and lets callers
---treat an unreadable path exactly like a missing one.
---@param path string
---@return string[]? lines file content without line terminators, or nil when the file cannot be opened or read
local function read_file_lines(path)
  local f = io.open(path, 'r')
  if not f then
    return nil
  end
  local lines = {}
  local read_err
  while true do
    -- At end of file `read` returns a lone nil; on failure it returns nil
    -- plus an error message, which is what distinguishes the two cases.
    local line, err = f:read('*l')
    if line == nil then
      read_err = err
      break
    end
    lines[#lines + 1] = line
  end
  f:close()
  if read_err then
    return nil
  end
  return lines
end
---Attach surrounding working-tree lines to each hunk for use as parser context.
---
---For every hunk that carries `repo_root`, `filename` and `file_new_start`,
---the file at `repo_root/filename` is read (once per distinct path per call)
---and up to `max_lines` lines immediately before and after the hunk's new-side
---range are stored on `hunk.context_before` / `hunk.context_after`. A field is
---left untouched when there is nothing to collect on that side, and hunks whose
---file cannot be read are skipped entirely.
---@param hunks diffs.Hunk[]
---@param max_lines integer maximum number of lines collected in each direction
local function compute_hunk_context(hunks, max_lines)
  if max_lines < 1 then
    -- Nothing can be collected; avoid touching the disk at all.
    return
  end
  ---@type table<string, string[]|false>
  local file_cache = {}
  for _, hunk in ipairs(hunks) do
    local new_start = hunk.file_new_start
    if hunk.repo_root and hunk.filename and new_start then
      local path = vim.fs.joinpath(hunk.repo_root, hunk.filename)
      local file_lines = file_cache[path]
      if file_lines == nil then
        -- `false` records a failed read so the path is not retried per hunk.
        file_lines = read_file_lines(path) or false
        file_cache[path] = file_lines
      end
      if file_lines then
        local total = #file_lines
        -- In unified diff a range with its count omitted spans exactly one
        -- line, so a missing count must not be treated as an empty range.
        local new_count = hunk.file_new_count or 1

        -- Both ends are clamped to the file so that a hunk header pointing
        -- outside the file on disk never yields holes or empty tables.
        local before_first = math.max(1, new_start - max_lines)
        local before_last = math.min(total, new_start - 1)
        if before_first <= before_last then
          local before = {}
          for i = before_first, before_last do
            before[#before + 1] = file_lines[i]
          end
          hunk.context_before = before
        end

        local after_first = math.max(1, new_start + new_count)
        local after_last = math.min(total, after_first + max_lines - 1)
        if after_first <= after_last then
          local after = {}
          for i = after_first, after_last do
            after[#after + 1] = file_lines[i]
          end
          hunk.context_after = after
        end
      end
    end
  end
end
---@param bufnr integer ---@param bufnr integer
local function ensure_cache(bufnr) local function ensure_cache(bufnr)
if not vim.api.nvim_buf_is_valid(bufnr) then if not vim.api.nvim_buf_is_valid(bufnr) then
@ -321,6 +384,9 @@ local function ensure_cache(bufnr)
local lc = vim.api.nvim_buf_line_count(bufnr) local lc = vim.api.nvim_buf_line_count(bufnr)
local bc = vim.api.nvim_buf_get_offset(bufnr, lc) local bc = vim.api.nvim_buf_get_offset(bufnr, lc)
dbg('parsed %d hunks in buffer %d (tick %d)', #hunks, bufnr, tick) dbg('parsed %d hunks in buffer %d (tick %d)', #hunks, bufnr, tick)
if config.highlights.context.enabled then
compute_hunk_context(hunks, config.highlights.context.lines)
end
local carried = entry and not entry.pending_clear and carry_forward_highlighted(entry, hunks) local carried = entry and not entry.pending_clear and carry_forward_highlighted(entry, hunks)
hunk_cache[bufnr] = { hunk_cache[bufnr] = {
hunks = hunks, hunks = hunks,
@ -941,6 +1007,7 @@ M._test = {
hunks_eq = hunks_eq, hunks_eq = hunks_eq,
process_pending_clear = process_pending_clear, process_pending_clear = process_pending_clear,
ft_retry_pending = ft_retry_pending, ft_retry_pending = ft_retry_pending,
compute_hunk_context = compute_hunk_context,
} }
return M return M

View file

@ -15,6 +15,8 @@
---@field prefix_width integer ---@field prefix_width integer
---@field quote_width integer ---@field quote_width integer
---@field repo_root string? ---@field repo_root string?
---@field context_before string[]?
---@field context_after string[]?
local M = {} local M = {}

382
spec/context_spec.lua Normal file
View file

@ -0,0 +1,382 @@
require('spec.helpers')
local diffs = require('diffs')
local highlight = require('diffs.highlight')
local compute_hunk_context = diffs._test.compute_hunk_context
describe('context', function()
describe('compute_hunk_context', function()
  -- Per-test scratch directory that plays the role of the repository root.
  local tmpdir

  before_each(function()
    tmpdir = vim.fn.tempname()
    vim.fn.mkdir(tmpdir, 'p')
  end)

  after_each(function()
    vim.fn.delete(tmpdir, 'rf')
  end)

  ---Write `lines` (newline-terminated) to `filename` below the scratch
  ---directory, creating parent directories as needed.
  ---@param filename string
  ---@param lines string[]
  local function write_file(filename, lines)
    local path = vim.fs.joinpath(tmpdir, filename)
    local dir = vim.fn.fnamemodify(path, ':h')
    if vim.fn.isdirectory(dir) == 0 then
      vim.fn.mkdir(dir, 'p')
    end
    -- Fail with the OS error message instead of indexing a nil handle.
    local f = assert(io.open(path, 'w'))
    f:write(table.concat(lines, '\n') .. '\n')
    f:close()
  end

  ---Build a minimal Lua hunk rooted at the scratch directory.
  ---@param filename string
  ---@param opts table
  local function make_hunk(filename, opts)
    return {
      filename = filename,
      ft = 'lua',
      lang = 'lua',
      start_line = opts.start_line or 1,
      lines = opts.lines,
      prefix_width = opts.prefix_width or 1,
      quote_width = 0,
      repo_root = tmpdir,
      file_new_start = opts.file_new_start,
      file_new_count = opts.file_new_count,
    }
  end

  it('reads context_before from file lines preceding the hunk', function()
    write_file('a.lua', {
      'local M = {}',
      'function M.foo()',
      ' local x = 1',
      ' local y = 2',
      'end',
      'return M',
    })
    local hunks = {
      make_hunk('a.lua', {
        file_new_start = 3,
        file_new_count = 3,
        lines = { ' local x = 1', '+local new = true', ' local y = 2' },
      }),
    }
    compute_hunk_context(hunks, 25)
    assert.same({ 'local M = {}', 'function M.foo()' }, hunks[1].context_before)
  end)

  it('reads context_after from file lines following the hunk', function()
    write_file('a.lua', {
      'local M = {}',
      'function M.foo()',
      ' local x = 1',
      'end',
      'return M',
    })
    local hunks = {
      make_hunk('a.lua', {
        file_new_start = 2,
        file_new_count = 2,
        lines = { ' function M.foo()', '+ local x = 1' },
      }),
    }
    compute_hunk_context(hunks, 25)
    assert.same({ 'end', 'return M' }, hunks[1].context_after)
  end)

  it('caps context_before to max_lines', function()
    write_file('a.lua', {
      'line1',
      'line2',
      'line3',
      'line4',
      'line5',
      'target',
    })
    local hunks = {
      make_hunk('a.lua', {
        file_new_start = 6,
        file_new_count = 1,
        lines = { '+target' },
      }),
    }
    compute_hunk_context(hunks, 2)
    assert.same({ 'line4', 'line5' }, hunks[1].context_before)
  end)

  it('caps context_after to max_lines', function()
    write_file('a.lua', {
      'target',
      'after1',
      'after2',
      'after3',
      'after4',
    })
    local hunks = {
      make_hunk('a.lua', {
        file_new_start = 1,
        file_new_count = 1,
        lines = { '+target' },
      }),
    }
    compute_hunk_context(hunks, 2)
    assert.same({ 'after1', 'after2' }, hunks[1].context_after)
  end)

  it('skips hunks without file_new_start', function()
    write_file('a.lua', { 'line1', 'line2' })
    local hunks = {
      make_hunk('a.lua', {
        file_new_start = nil,
        file_new_count = nil,
        lines = { '+something' },
      }),
    }
    compute_hunk_context(hunks, 25)
    assert.is_nil(hunks[1].context_before)
    assert.is_nil(hunks[1].context_after)
  end)

  it('skips hunks without repo_root', function()
    local hunks = {
      {
        filename = 'a.lua',
        ft = 'lua',
        lang = 'lua',
        start_line = 1,
        lines = { '+x' },
        prefix_width = 1,
        quote_width = 0,
        repo_root = nil,
        file_new_start = 1,
        file_new_count = 1,
      },
    }
    compute_hunk_context(hunks, 25)
    assert.is_nil(hunks[1].context_before)
    assert.is_nil(hunks[1].context_after)
  end)

  it('skips when file does not exist on disk', function()
    local hunks = {
      make_hunk('nonexistent.lua', {
        file_new_start = 1,
        file_new_count = 1,
        lines = { '+x' },
      }),
    }
    compute_hunk_context(hunks, 25)
    assert.is_nil(hunks[1].context_before)
    assert.is_nil(hunks[1].context_after)
  end)

  it('returns nil context_before for hunk at line 1', function()
    write_file('a.lua', { 'first', 'second' })
    local hunks = {
      make_hunk('a.lua', {
        file_new_start = 1,
        file_new_count = 1,
        lines = { '+first' },
      }),
    }
    compute_hunk_context(hunks, 25)
    assert.is_nil(hunks[1].context_before)
  end)

  it('returns nil context_after for hunk at end of file', function()
    write_file('a.lua', { 'first', 'last' })
    local hunks = {
      make_hunk('a.lua', {
        file_new_start = 1,
        file_new_count = 2,
        lines = { ' first', '+last' },
      }),
    }
    compute_hunk_context(hunks, 25)
    assert.is_nil(hunks[1].context_after)
  end)

  it('reads file once for multiple hunks in same file', function()
    write_file('a.lua', {
      'local M = {}',
      'function M.foo()',
      ' return 1',
      'end',
      'function M.bar()',
      ' return 2',
      'end',
      'return M',
    })
    local hunks = {
      make_hunk('a.lua', {
        file_new_start = 2,
        file_new_count = 3,
        lines = { ' function M.foo()', '+ return 1', ' end' },
      }),
      make_hunk('a.lua', {
        file_new_start = 5,
        file_new_count = 3,
        lines = { ' function M.bar()', '+ return 2', ' end' },
      }),
    }

    -- Count opens of the target file only, so unrelated I/O happening during
    -- the call cannot skew the result. `io.open` is restored before any
    -- assertion runs, so a failure cannot leak the wrapper into other tests.
    local target = vim.fs.joinpath(tmpdir, 'a.lua')
    local real_open = io.open
    local opens = 0
    rawset(io, 'open', function(path, ...)
      if path == target then
        opens = opens + 1
      end
      return real_open(path, ...)
    end)
    local ok, err = pcall(compute_hunk_context, hunks, 25)
    rawset(io, 'open', real_open)

    assert.is_true(ok, tostring(err))
    assert.are.equal(1, opens)
    assert.same({ 'local M = {}' }, hunks[1].context_before)
    assert.same({ 'function M.bar()', ' return 2', 'end', 'return M' }, hunks[1].context_after)
    assert.same({
      'local M = {}',
      'function M.foo()',
      ' return 1',
      'end',
    }, hunks[2].context_before)
    assert.same({ 'return M' }, hunks[2].context_after)
  end)
end)
describe('highlight_treesitter with context', function()
  local ns
  -- Scratch buffers created by the running test; wiped in `after_each` so a
  -- failing assertion cannot leave buffers behind for later tests.
  local buffers

  before_each(function()
    ns = vim.api.nvim_create_namespace('diffs_context_test')
    local normal = vim.api.nvim_get_hl(0, { name = 'Normal' })
    vim.api.nvim_set_hl(0, 'DiffsClear', { fg = normal.fg or 0xc0c0c0 })
    buffers = {}
  end)

  local function delete_buffer(bufnr)
    if vim.api.nvim_buf_is_valid(bufnr) then
      vim.api.nvim_buf_delete(bufnr, { force = true })
    end
  end

  after_each(function()
    for _, bufnr in ipairs(buffers) do
      delete_buffer(bufnr)
    end
  end)

  ---Create a scratch buffer holding `lines` and register it for cleanup.
  local function create_buffer(lines)
    local bufnr = vim.api.nvim_create_buf(false, true)
    vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines)
    buffers[#buffers + 1] = bufnr
    return bufnr
  end

  local function get_extmarks(bufnr)
    return vim.api.nvim_buf_get_extmarks(bufnr, ns, 0, -1, { details = true })
  end

  ---Reduce the extmarks in `bufnr` to the fields that describe a highlight,
  ---dropping ids so that two separate highlighting runs can be compared.
  local function mark_signatures(bufnr)
    local out = {}
    for _, mark in ipairs(get_extmarks(bufnr)) do
      local details = mark[4]
      out[#out + 1] = {
        row = mark[2],
        col = mark[3],
        end_row = details.end_row,
        end_col = details.end_col,
        hl_group = details.hl_group,
        priority = details.priority,
      }
    end
    return out
  end

  ---Options for `highlight_hunk` with only treesitter highlighting active;
  ---`overrides.highlights` is deep-merged over the defaults.
  local function default_opts(overrides)
    local opts = {
      hide_prefix = false,
      highlights = {
        background = false,
        gutter = false,
        context = { enabled = true, lines = 25 },
        treesitter = { enabled = true, max_lines = 500 },
        vim = { enabled = false, max_lines = 200 },
        intra = { enabled = false, algorithm = 'default', max_lines = 500 },
        priorities = { clear = 198, syntax = 199, line_bg = 200, char_bg = 201 },
      },
    }
    if overrides and overrides.highlights then
      opts.highlights = vim.tbl_deep_extend('force', opts.highlights, overrides.highlights)
    end
    return opts
  end

  it('applies extmarks only to hunk lines, not context lines', function()
    local bufnr = create_buffer({
      '@@ -1,2 +1,3 @@',
      ' local x = 1',
      ' local y = 2',
      '+local z = 3',
    })
    local hunk = {
      filename = 'test.lua',
      lang = 'lua',
      start_line = 2,
      lines = { ' local x = 1', ' local y = 2', '+local z = 3' },
      prefix_width = 1,
      quote_width = 0,
      context_before = { 'local function foo()' },
      context_after = { 'end' },
    }
    highlight.highlight_hunk(bufnr, ns, hunk, default_opts())
    local extmarks = get_extmarks(bufnr)
    for _, mark in ipairs(extmarks) do
      local row = mark[2]
      assert.is_true(row >= 1 and row <= 3, 'extmark row ' .. row .. ' outside hunk range')
    end
    assert.is_true(#extmarks > 0)
  end)

  it('does not pass context when context.enabled = false', function()
    local bufnr = create_buffer({
      '@@ -1,1 +1,2 @@',
      ' local x = 1',
      '+local y = 2',
    })
    -- A fresh table per run keeps the runs independent of each other.
    local function make_hunk(with_context)
      return {
        filename = 'test.lua',
        lang = 'lua',
        start_line = 2,
        lines = { ' local x = 1', '+local y = 2' },
        prefix_width = 1,
        quote_width = 0,
        context_before = with_context and { 'local function foo()' } or nil,
        context_after = with_context and { 'end' } or nil,
      }
    end

    local opts_enabled = default_opts({ highlights = { context = { enabled = true } } })
    highlight.highlight_hunk(bufnr, ns, make_hunk(true), opts_enabled)
    local extmarks_with = get_extmarks(bufnr)
    vim.api.nvim_buf_clear_namespace(bufnr, ns, 0, -1)

    local opts_disabled = default_opts({ highlights = { context = { enabled = false } } })
    highlight.highlight_hunk(bufnr, ns, make_hunk(true), opts_disabled)
    local extmarks_without = get_extmarks(bufnr)
    local disabled_marks = mark_signatures(bufnr)
    vim.api.nvim_buf_clear_namespace(bufnr, ns, 0, -1)

    -- Baseline: the same hunk carrying no context at all. With the option
    -- disabled, context present on the hunk must not influence the result,
    -- so both runs have to yield the same highlights.
    highlight.highlight_hunk(bufnr, ns, make_hunk(false), default_opts())
    local baseline_marks = mark_signatures(bufnr)

    assert.is_true(#extmarks_with > 0)
    assert.is_true(#extmarks_without > 0)
    assert.same(baseline_marks, disabled_marks)
  end)

  it('skips context fields that are nil', function()
    local bufnr = create_buffer({
      '@@ -1,1 +1,2 @@',
      ' local x = 1',
      '+local y = 2',
    })
    local hunk = {
      filename = 'test.lua',
      lang = 'lua',
      start_line = 2,
      lines = { ' local x = 1', '+local y = 2' },
      prefix_width = 1,
      quote_width = 0,
    }
    highlight.highlight_hunk(bufnr, ns, hunk, default_opts())
    local extmarks = get_extmarks(bufnr)
    assert.is_true(#extmarks > 0)
  end)
end)
end)