Merge pull request #78 from barrettruth/fix/treesitter-split-parsing

fix(highlight): split old/new treesitter parsing
This commit is contained in:
Barrett Ruth 2026-02-07 00:54:48 -05:00 committed by GitHub
commit 12eaac4727
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 186 additions and 101 deletions

View file

@ -18,14 +18,16 @@ function M.dump()
end_col = details.end_col,
hl_group = details.hl_group,
priority = details.priority,
hl_eol = details.hl_eol,
line_hl_group = details.line_hl_group,
number_hl_group = details.number_hl_group,
virt_text = details.virt_text,
}
if not by_line[row] then
by_line[row] = { text = lines[row + 1] or '', marks = {} }
local key = tostring(row)
if not by_line[key] then
by_line[key] = { text = lines[row + 1] or '', marks = {} }
end
table.insert(by_line[row].marks, entry)
table.insert(by_line[key].marks, entry)
end
local all_ns_marks = vim.api.nvim_buf_get_extmarks(bufnr, -1, 0, -1, { details = true })

View file

@ -137,7 +137,12 @@ local function char_diff_pair(old_line, new_line, del_idx, add_idx, diff_opts)
local old_text = table.concat(old_bytes, '\n') .. '\n'
local new_text = table.concat(new_bytes, '\n') .. '\n'
local char_hunks = byte_diff(old_text, new_text, diff_opts)
local char_opts = diff_opts
if diff_opts and diff_opts.linematch then
char_opts = { algorithm = diff_opts.algorithm }
end
local char_hunks = byte_diff(old_text, new_text, char_opts)
for _, ch in ipairs(char_hunks) do
if ch.old_count > 0 then

View file

@ -3,6 +3,11 @@ local M = {}
local dbg = require('diffs.log').dbg
local diff = require('diffs.diff')
local PRIORITY_CLEAR = 198
local PRIORITY_SYNTAX = 199
local PRIORITY_LINE_BG = 200
local PRIORITY_CHAR_BG = 201
---@param bufnr integer
---@param ns integer
---@param hunk diffs.Hunk
@ -38,7 +43,7 @@ local function highlight_text(bufnr, ns, hunk, col_offset, text, lang)
local buf_sc = col_offset + sc
local buf_ec = col_offset + ec
local priority = lang == 'diff' and (tonumber(metadata.priority) or 100) or 200
local priority = lang == 'diff' and (tonumber(metadata.priority) or 100) or PRIORITY_SYNTAX
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_sr, buf_sc, {
end_row = buf_er,
@ -58,16 +63,21 @@ end
---@param bufnr integer
---@param ns integer
---@param hunk diffs.Hunk
---@param code_lines string[]
---@param col_offset integer?
---@param lang string
---@param line_map table<integer, integer>
---@param col_offset integer
---@param covered_lines? table<integer, true>
---@return integer
local function highlight_treesitter(bufnr, ns, hunk, code_lines, col_offset)
local lang = hunk.lang
if not lang then
return 0
end
local function highlight_treesitter(
bufnr,
ns,
code_lines,
lang,
line_map,
col_offset,
covered_lines
)
local code = table.concat(code_lines, '\n')
if code == '' then
return 0
@ -91,41 +101,31 @@ local function highlight_treesitter(bufnr, ns, hunk, code_lines, col_offset)
return 0
end
if hunk.header_context and hunk.header_context_col then
local header_line = hunk.start_line - 1
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, header_line, hunk.header_context_col, {
end_col = hunk.header_context_col + #hunk.header_context,
hl_group = 'Normal',
priority = 199,
})
local header_extmarks =
highlight_text(bufnr, ns, hunk, hunk.header_context_col, hunk.header_context, lang)
if header_extmarks > 0 then
dbg('header %s:%d applied %d extmarks', hunk.filename, hunk.start_line, header_extmarks)
end
end
col_offset = col_offset or 1
local extmark_count = 0
for id, node, metadata in query:iter_captures(trees[1]:root(), code) do
local capture_name = '@' .. query.captures[id] .. '.' .. lang
local sr, sc, er, ec = node:range()
local buf_sr = hunk.start_line + sr
local buf_er = hunk.start_line + er
local buf_sc = sc + col_offset
local buf_ec = ec + col_offset
local buf_sr = line_map[sr]
if buf_sr then
local buf_er = line_map[er] or buf_sr
local priority = lang == 'diff' and (tonumber(metadata.priority) or 100) or 200
local buf_sc = sc + col_offset
local buf_ec = ec + col_offset
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_sr, buf_sc, {
end_row = buf_er,
end_col = buf_ec,
hl_group = capture_name,
priority = priority,
})
extmark_count = extmark_count + 1
local priority = lang == 'diff' and (tonumber(metadata.priority) or 100) or PRIORITY_SYNTAX
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_sr, buf_sc, {
end_row = buf_er,
end_col = buf_ec,
hl_group = capture_name,
priority = priority,
})
extmark_count = extmark_count + 1
if covered_lines then
covered_lines[buf_sr] = true
end
end
end
return extmark_count
@ -176,8 +176,9 @@ end
---@param ns integer
---@param hunk diffs.Hunk
---@param code_lines string[]
---@param covered_lines? table<integer, true>
---@return integer
local function highlight_vim_syntax(bufnr, ns, hunk, code_lines)
local function highlight_vim_syntax(bufnr, ns, hunk, code_lines, covered_lines)
local ft = hunk.ft
if not ft then
return 0
@ -219,9 +220,12 @@ local function highlight_vim_syntax(bufnr, ns, hunk, code_lines)
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, span.col_start, {
end_col = span.col_end,
hl_group = span.hl_name,
priority = 200,
priority = PRIORITY_SYNTAX,
})
extmark_count = extmark_count + 1
if covered_lines then
covered_lines[buf_line] = true
end
end
return extmark_count
@ -248,21 +252,63 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
use_vim = false
end
local apply_syntax = use_ts or use_vim
---@type string[]
local code_lines = {}
if apply_syntax then
for _, line in ipairs(hunk.lines) do
table.insert(code_lines, line:sub(2))
end
end
---@type table<integer, true>
local covered_lines = {}
local extmark_count = 0
if use_ts then
extmark_count = highlight_treesitter(bufnr, ns, hunk, code_lines)
---@type string[]
local new_code = {}
---@type table<integer, integer>
local new_map = {}
---@type string[]
local old_code = {}
---@type table<integer, integer>
local old_map = {}
for i, line in ipairs(hunk.lines) do
local prefix = line:sub(1, 1)
local stripped = line:sub(2)
local buf_line = hunk.start_line + i - 1
if prefix == '+' then
new_map[#new_code] = buf_line
table.insert(new_code, stripped)
elseif prefix == '-' then
old_map[#old_code] = buf_line
table.insert(old_code, stripped)
else
new_map[#new_code] = buf_line
table.insert(new_code, stripped)
table.insert(old_code, stripped)
end
end
extmark_count = highlight_treesitter(bufnr, ns, new_code, hunk.lang, new_map, 1, covered_lines)
extmark_count = extmark_count
+ highlight_treesitter(bufnr, ns, old_code, hunk.lang, old_map, 1, covered_lines)
if hunk.header_context and hunk.header_context_col then
local header_line = hunk.start_line - 1
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, header_line, hunk.header_context_col, {
end_col = hunk.header_context_col + #hunk.header_context,
hl_group = 'DiffsClear',
priority = PRIORITY_CLEAR,
})
local header_extmarks =
highlight_text(bufnr, ns, hunk, hunk.header_context_col, hunk.header_context, hunk.lang)
if header_extmarks > 0 then
dbg('header %s:%d applied %d extmarks', hunk.filename, hunk.start_line, header_extmarks)
end
extmark_count = extmark_count + header_extmarks
end
elseif use_vim then
extmark_count = highlight_vim_syntax(bufnr, ns, hunk, code_lines)
---@type string[]
local code_lines = {}
for _, line in ipairs(hunk.lines) do
table.insert(code_lines, line:sub(2))
end
extmark_count = highlight_vim_syntax(bufnr, ns, hunk, code_lines, covered_lines)
end
if
@ -271,18 +317,15 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
and #hunk.header_lines > 0
and opts.highlights.treesitter.enabled
then
---@type table<integer, integer>
local header_map = {}
for i = 0, #hunk.header_lines - 1 do
header_map[i] = hunk.header_start_line - 1 + i
end
extmark_count = extmark_count
+ highlight_treesitter(bufnr, ns, {
filename = hunk.filename,
start_line = hunk.header_start_line - 1,
lang = 'diff',
lines = hunk.header_lines,
header_lines = {},
}, hunk.header_lines, 0)
+ highlight_treesitter(bufnr, ns, hunk.header_lines, 'diff', header_map, 0)
end
local syntax_applied = extmark_count > 0
---@type diffs.IntraChanges?
local intra = nil
local intra_cfg = opts.highlights.intra
@ -334,11 +377,11 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
})
end
if line_len > 1 and syntax_applied then
if line_len > 1 and covered_lines[buf_line] then
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, 1, {
end_col = line_len,
hl_group = 'Normal',
priority = 198,
hl_group = 'DiffsClear',
priority = PRIORITY_CLEAR,
})
end
@ -348,7 +391,7 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
hl_group = line_hl,
hl_eol = true,
number_hl_group = opts.highlights.gutter and number_hl or nil,
priority = 199,
priority = PRIORITY_LINE_BG,
})
end
@ -367,7 +410,7 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
local ok, err = pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, span.col_start, {
end_col = span.col_end,
hl_group = char_hl,
priority = 201,
priority = PRIORITY_CHAR_BG,
})
if not ok then
dbg('char extmark FAILED: %s', err)

View file

@ -186,6 +186,7 @@ local function compute_highlight_groups()
local blended_add_text = blend_color(add_fg, bg, 0.7)
local blended_del_text = blend_color(del_fg, bg, 0.7)
vim.api.nvim_set_hl(0, 'DiffsClear', { default = true, fg = normal.fg or 0xc0c0c0 })
vim.api.nvim_set_hl(0, 'DiffsAdd', { default = true, bg = blended_add })
vim.api.nvim_set_hl(0, 'DiffsDelete', { default = true, bg = blended_del })
vim.api.nvim_set_hl(0, 'DiffsAddNr', { default = true, fg = add_fg, bg = blended_add })

View file

@ -7,8 +7,10 @@ describe('highlight', function()
before_each(function()
ns = vim.api.nvim_create_namespace('diffs_test')
local normal = vim.api.nvim_get_hl(0, { name = 'Normal' })
local diff_add = vim.api.nvim_get_hl(0, { name = 'DiffAdd' })
local diff_delete = vim.api.nvim_get_hl(0, { name = 'DiffDelete' })
vim.api.nvim_set_hl(0, 'DiffsClear', { fg = normal.fg or 0xc0c0c0 })
vim.api.nvim_set_hl(0, 'DiffsAdd', { bg = diff_add.bg })
vim.api.nvim_set_hl(0, 'DiffsDelete', { bg = diff_delete.bg })
end)
@ -82,7 +84,7 @@ describe('highlight', function()
delete_buffer(bufnr)
end)
it('applies Normal extmarks to clear diff colors', function()
it('applies DiffsClear extmarks to clear diff colors', function()
local bufnr = create_buffer({
'@@ -1,1 +1,2 @@',
' local x = 1',
@ -99,14 +101,46 @@ describe('highlight', function()
highlight.highlight_hunk(bufnr, ns, hunk, default_opts())
local extmarks = get_extmarks(bufnr)
local has_normal = false
local has_clear = false
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].hl_group == 'Normal' then
has_normal = true
if mark[4] and mark[4].hl_group == 'DiffsClear' then
has_clear = true
break
end
end
assert.is_true(has_normal)
assert.is_true(has_clear)
delete_buffer(bufnr)
end)
it('produces treesitter captures on all lines with split parsing', function()
local bufnr = create_buffer({
'@@ -1,3 +1,3 @@',
' local x = 1',
'-local y = 2',
'+local y = 3',
' return x',
})
local hunk = {
filename = 'test.lua',
lang = 'lua',
start_line = 1,
lines = { ' local x = 1', '-local y = 2', '+local y = 3', ' return x' },
}
highlight.highlight_hunk(bufnr, ns, hunk, default_opts())
local extmarks = get_extmarks(bufnr)
local lines_with_ts = {}
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].hl_group and mark[4].hl_group:match('^@.*%.lua$') then
lines_with_ts[mark[2]] = true
end
end
assert.is_true(lines_with_ts[1] ~= nil)
assert.is_true(lines_with_ts[2] ~= nil)
assert.is_true(lines_with_ts[3] ~= nil)
assert.is_true(lines_with_ts[4] ~= nil)
delete_buffer(bufnr)
end)
@ -576,7 +610,7 @@ describe('highlight', function()
local extmarks = get_extmarks(bufnr)
local has_syntax_hl = false
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].hl_group and mark[4].hl_group ~= 'Normal' then
if mark[4] and mark[4].hl_group and mark[4].hl_group ~= 'DiffsClear' then
has_syntax_hl = true
break
end
@ -610,7 +644,7 @@ describe('highlight', function()
local extmarks = get_extmarks(bufnr)
local has_syntax_hl = false
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].hl_group and mark[4].hl_group ~= 'Normal' then
if mark[4] and mark[4].hl_group and mark[4].hl_group ~= 'DiffsClear' then
has_syntax_hl = true
break
end
@ -682,7 +716,7 @@ describe('highlight', function()
delete_buffer(bufnr)
end)
it('applies Normal blanking for vim fallback hunks', function()
it('applies DiffsClear blanking for vim fallback hunks', function()
local orig_synID = vim.fn.synID
local orig_synIDtrans = vim.fn.synIDtrans
local orig_synIDattr = vim.fn.synIDattr
@ -722,14 +756,14 @@ describe('highlight', function()
vim.fn.synIDattr = orig_synIDattr
local extmarks = get_extmarks(bufnr)
local has_normal = false
local has_clear = false
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].hl_group == 'Normal' then
has_normal = true
if mark[4] and mark[4].hl_group == 'DiffsClear' then
has_clear = true
break
end
end
assert.is_true(has_normal)
assert.is_true(has_clear)
delete_buffer(bufnr)
end)
@ -765,7 +799,7 @@ describe('highlight', function()
delete_buffer(bufnr)
end)
it('line bg priority > Normal priority', function()
it('line bg priority > DiffsClear priority', function()
local bufnr = create_buffer({
'@@ -1,2 +1,1 @@',
'-local x = 1',
@ -787,20 +821,20 @@ describe('highlight', function()
)
local extmarks = get_extmarks(bufnr)
local normal_priority = nil
local clear_priority = nil
local line_bg_priority = nil
for _, mark in ipairs(extmarks) do
local d = mark[4]
if d and d.hl_group == 'Normal' then
normal_priority = d.priority
if d and d.hl_group == 'DiffsClear' then
clear_priority = d.priority
end
if d and (d.hl_group == 'DiffsAdd' or d.hl_group == 'DiffsDelete') then
line_bg_priority = d.priority
end
end
assert.is_not_nil(normal_priority)
assert.is_not_nil(clear_priority)
assert.is_not_nil(line_bg_priority)
assert.is_true(line_bg_priority > normal_priority)
assert.is_true(line_bg_priority > clear_priority)
delete_buffer(bufnr)
end)
@ -960,7 +994,7 @@ describe('highlight', function()
delete_buffer(bufnr)
end)
it('enforces priority order: Normal < line bg < syntax < char bg', function()
it('enforces priority order: DiffsClear < syntax < line bg < char bg', function()
vim.api.nvim_set_hl(0, 'DiffsAddText', { bg = 0x00FF00 })
vim.api.nvim_set_hl(0, 'DiffsDeleteText', { bg = 0xFF0000 })
@ -990,12 +1024,12 @@ describe('highlight', function()
)
local extmarks = get_extmarks(bufnr)
local priorities = { normal = {}, line_bg = {}, syntax = {}, char_bg = {} }
local priorities = { clear = {}, line_bg = {}, syntax = {}, char_bg = {} }
for _, mark in ipairs(extmarks) do
local d = mark[4]
if d then
if d.hl_group == 'Normal' then
table.insert(priorities.normal, d.priority)
if d.hl_group == 'DiffsClear' then
table.insert(priorities.clear, d.priority)
elseif d.hl_group == 'DiffsAdd' or d.hl_group == 'DiffsDelete' then
table.insert(priorities.line_bg, d.priority)
elseif d.hl_group == 'DiffsAddText' or d.hl_group == 'DiffsDeleteText' then
@ -1006,19 +1040,19 @@ describe('highlight', function()
end
end
assert.is_true(#priorities.normal > 0)
assert.is_true(#priorities.clear > 0)
assert.is_true(#priorities.line_bg > 0)
assert.is_true(#priorities.syntax > 0)
assert.is_true(#priorities.char_bg > 0)
local max_normal = math.max(unpack(priorities.normal))
local max_clear = math.max(unpack(priorities.clear))
local min_line_bg = math.min(unpack(priorities.line_bg))
local min_syntax = math.min(unpack(priorities.syntax))
local min_char_bg = math.min(unpack(priorities.char_bg))
assert.is_true(max_normal < min_line_bg)
assert.is_true(min_line_bg < min_syntax)
assert.is_true(min_syntax < min_char_bg)
assert.is_true(max_clear < min_syntax)
assert.is_true(min_syntax < min_line_bg)
assert.is_true(min_line_bg < min_char_bg)
delete_buffer(bufnr)
end)
end)
@ -1214,7 +1248,7 @@ describe('highlight', function()
}
end
it('uses priority 200 for code languages', function()
it('uses priority 199 for code languages', function()
local bufnr = create_buffer({
'@@ -1,1 +1,2 @@',
' local x = 1',
@ -1231,16 +1265,16 @@ describe('highlight', function()
highlight.highlight_hunk(bufnr, ns, hunk, default_opts())
local extmarks = get_extmarks(bufnr)
local has_priority_200 = false
local has_priority_199 = false
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].hl_group and mark[4].hl_group:match('^@.*%.lua$') then
if mark[4].priority == 200 then
has_priority_200 = true
if mark[4].priority == 199 then
has_priority_199 = true
break
end
end
end
assert.is_true(has_priority_200)
assert.is_true(has_priority_199)
delete_buffer(bufnr)
end)
@ -1278,7 +1312,7 @@ describe('highlight', function()
end
assert.is_true(#diff_extmark_priorities > 0)
for _, priority in ipairs(diff_extmark_priorities) do
assert.is_true(priority < 200)
assert.is_true(priority < 199)
end
delete_buffer(bufnr)
end)