fix(highlight): support combined diff format for unmerged files

Problem: fugitive shows combined diffs (@@@ headers, 2-char prefixes)
for unmerged (UU) files. The parser and highlight pipeline assumed
unified diff format (@@, 1-char prefix), causing broken prefix
concealment, missing background colors on ` +`/`+ ` lines, and no
treesitter highlights due to garbage prefix chars in code arrays.

Solution: detect prefix width from the number of leading @ signs in
hunk headers. Propagate prefix_width through the parser (new field on
diffs.Hunk) and highlight pipeline (prefix stripping, col_offset,
concealment, line classification). Add U to filename pattern for
unmerged file detection. Skip intra-line diffing for combined diffs
since the 2-char prefix semantics don't produce meaningful change
groups.
This commit is contained in:
Barrett Ruth 2026-02-09 17:03:17 -05:00
parent 59fcf14817
commit a6dd0503b3
4 changed files with 232 additions and 29 deletions

View file

@ -239,13 +239,14 @@ local function highlight_vim_syntax(
pcall(vim.api.nvim_buf_delete, scratch, { force = true }) pcall(vim.api.nvim_buf_delete, scratch, { force = true })
local hunk_line_count = #hunk.lines local hunk_line_count = #hunk.lines
local col_off = (hunk.prefix_width or 1) - 1
local extmark_count = 0 local extmark_count = 0
for _, span in ipairs(spans) do for _, span in ipairs(spans) do
local adj = span.line - leading_offset local adj = span.line - leading_offset
if adj >= 1 and adj <= hunk_line_count then if adj >= 1 and adj <= hunk_line_count then
local buf_line = hunk.start_line + adj - 1 local buf_line = hunk.start_line + adj - 1
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, span.col_start, { pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, span.col_start + col_off, {
end_col = span.col_end, end_col = span.col_end + col_off,
hl_group = span.hl_name, hl_group = span.hl_name,
priority = priorities.syntax, priority = priorities.syntax,
}) })
@ -265,6 +266,7 @@ end
---@param opts diffs.HunkOpts ---@param opts diffs.HunkOpts
function M.highlight_hunk(bufnr, ns, hunk, opts) function M.highlight_hunk(bufnr, ns, hunk, opts)
local p = opts.highlights.priorities local p = opts.highlights.priorities
local pw = hunk.prefix_width or 1
local use_ts = hunk.lang and opts.highlights.treesitter.enabled local use_ts = hunk.lang and opts.highlights.treesitter.enabled
local use_vim = not use_ts and hunk.ft and opts.highlights.vim.enabled local use_vim = not use_ts and hunk.ft and opts.highlights.vim.enabled
@ -296,14 +298,16 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
local old_map = {} local old_map = {}
for i, line in ipairs(hunk.lines) do for i, line in ipairs(hunk.lines) do
local prefix = line:sub(1, 1) local prefix = line:sub(1, pw)
local stripped = line:sub(2) local stripped = line:sub(pw + 1)
local buf_line = hunk.start_line + i - 1 local buf_line = hunk.start_line + i - 1
local has_add = prefix:find('+', 1, true) ~= nil
local has_del = prefix:find('-', 1, true) ~= nil
if prefix == '+' then if has_add and not has_del then
new_map[#new_code] = buf_line new_map[#new_code] = buf_line
table.insert(new_code, stripped) table.insert(new_code, stripped)
elseif prefix == '-' then elseif has_del and not has_add then
old_map[#old_code] = buf_line old_map[#old_code] = buf_line
table.insert(old_code, stripped) table.insert(old_code, stripped)
else else
@ -314,9 +318,9 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
end end
extmark_count = extmark_count =
highlight_treesitter(bufnr, ns, new_code, hunk.lang, new_map, 1, covered_lines, p) highlight_treesitter(bufnr, ns, new_code, hunk.lang, new_map, pw, covered_lines, p)
extmark_count = extmark_count extmark_count = extmark_count
+ highlight_treesitter(bufnr, ns, old_code, hunk.lang, old_map, 1, covered_lines, p) + highlight_treesitter(bufnr, ns, old_code, hunk.lang, old_map, pw, covered_lines, p)
if hunk.header_context and hunk.header_context_col then if hunk.header_context and hunk.header_context_col then
local header_line = hunk.start_line - 1 local header_line = hunk.start_line - 1
@ -344,7 +348,7 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
---@type string[] ---@type string[]
local code_lines = {} local code_lines = {}
for _, line in ipairs(hunk.lines) do for _, line in ipairs(hunk.lines) do
table.insert(code_lines, line:sub(2)) table.insert(code_lines, line:sub(pw + 1))
end end
extmark_count = highlight_vim_syntax(bufnr, ns, hunk, code_lines, covered_lines, 0, p) extmark_count = highlight_vim_syntax(bufnr, ns, hunk, code_lines, covered_lines, 0, p)
end end
@ -367,7 +371,7 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
---@type diffs.IntraChanges? ---@type diffs.IntraChanges?
local intra = nil local intra = nil
local intra_cfg = opts.highlights.intra local intra_cfg = opts.highlights.intra
if intra_cfg and intra_cfg.enabled and #hunk.lines <= intra_cfg.max_lines then if intra_cfg and intra_cfg.enabled and pw == 1 and #hunk.lines <= intra_cfg.max_lines then
dbg('computing intra for hunk %s:%d (%d lines)', hunk.filename, hunk.start_line, #hunk.lines) dbg('computing intra for hunk %s:%d (%d lines)', hunk.filename, hunk.start_line, #hunk.lines)
intra = diff.compute_intra_hunks(hunk.lines, intra_cfg.algorithm) intra = diff.compute_intra_hunks(hunk.lines, intra_cfg.algorithm)
if intra then if intra then
@ -401,22 +405,23 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
for i, line in ipairs(hunk.lines) do for i, line in ipairs(hunk.lines) do
local buf_line = hunk.start_line + i - 1 local buf_line = hunk.start_line + i - 1
local line_len = #line local line_len = #line
local prefix = line:sub(1, 1) local prefix = line:sub(1, pw)
local has_add = prefix:find('+', 1, true) ~= nil
local is_diff_line = prefix == '+' or prefix == '-' local has_del = prefix:find('-', 1, true) ~= nil
local line_hl = is_diff_line and (prefix == '+' and 'DiffsAdd' or 'DiffsDelete') or nil local is_diff_line = has_add or has_del
local number_hl = is_diff_line and (prefix == '+' and 'DiffsAddNr' or 'DiffsDeleteNr') or nil local line_hl = is_diff_line and (has_add and 'DiffsAdd' or 'DiffsDelete') or nil
local number_hl = is_diff_line and (has_add and 'DiffsAddNr' or 'DiffsDeleteNr') or nil
if opts.hide_prefix then if opts.hide_prefix then
local virt_hl = (opts.highlights.background and line_hl) or nil local virt_hl = (opts.highlights.background and line_hl) or nil
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, 0, { pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, 0, {
virt_text = { { ' ', virt_hl } }, virt_text = { { string.rep(' ', pw), virt_hl } },
virt_text_pos = 'overlay', virt_text_pos = 'overlay',
}) })
end end
if line_len > 1 and covered_lines[buf_line] then if line_len > pw and covered_lines[buf_line] then
pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, 1, { pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, pw, {
end_col = line_len, end_col = line_len,
hl_group = 'DiffsClear', hl_group = 'DiffsClear',
priority = p.clear, priority = p.clear,
@ -439,7 +444,7 @@ function M.highlight_hunk(bufnr, ns, hunk, opts)
end end
if char_spans_by_line[i] then if char_spans_by_line[i] then
local char_hl = prefix == '+' and 'DiffsAddText' or 'DiffsDeleteText' local char_hl = has_add and 'DiffsAddText' or 'DiffsDeleteText'
for _, span in ipairs(char_spans_by_line[i]) do for _, span in ipairs(char_spans_by_line[i]) do
dbg( dbg(
'char extmark: line=%d buf_line=%d col=%d..%d hl=%s text="%s"', 'char extmark: line=%d buf_line=%d col=%d..%d hl=%s text="%s"',

View file

@ -12,6 +12,7 @@
---@field file_old_count integer? ---@field file_old_count integer?
---@field file_new_start integer? ---@field file_new_start integer?
---@field file_new_count integer? ---@field file_new_count integer?
---@field prefix_width integer
---@field repo_root string? ---@field repo_root string?
local M = {} local M = {}
@ -133,6 +134,8 @@ function M.parse_buffer(bufnr)
local hunk_lines = {} local hunk_lines = {}
---@type integer? ---@type integer?
local hunk_count = nil local hunk_count = nil
---@type integer
local hunk_prefix_width = 1
---@type integer? ---@type integer?
local header_start = nil local header_start = nil
---@type string[] ---@type string[]
@ -156,6 +159,7 @@ function M.parse_buffer(bufnr)
header_context = hunk_header_context, header_context = hunk_header_context,
header_context_col = hunk_header_context_col, header_context_col = hunk_header_context_col,
lines = hunk_lines, lines = hunk_lines,
prefix_width = hunk_prefix_width,
file_old_start = file_old_start, file_old_start = file_old_start,
file_old_count = file_old_count, file_old_count = file_old_count,
file_new_start = file_new_start, file_new_start = file_new_start,
@ -179,7 +183,7 @@ function M.parse_buffer(bufnr)
end end
for i, line in ipairs(lines) do for i, line in ipairs(lines) do
local filename = line:match('^[MADRC%?!]%s+(.+)$') or line:match('^diff %-%-git a/.+ b/(.+)$') local filename = line:match('^[MADRCU%?!]%s+(.+)$') or line:match('^diff %-%-git a/.+ b/(.+)$')
if filename then if filename then
flush_hunk() flush_hunk()
current_filename = filename current_filename = filename
@ -191,11 +195,15 @@ function M.parse_buffer(bufnr)
dbg('file: %s -> ft: %s (no ts parser)', filename, current_ft) dbg('file: %s -> ft: %s (no ts parser)', filename, current_ft)
end end
hunk_count = 0 hunk_count = 0
hunk_prefix_width = 1
header_start = i header_start = i
header_lines = {} header_lines = {}
elseif line:match('^@@.-@@') then elseif line:match('^@@+') then
flush_hunk() flush_hunk()
hunk_start = i hunk_start = i
local at_prefix = line:match('^(@@+)')
hunk_prefix_width = #at_prefix - 1
if #at_prefix == 2 then
local hs, hc, hs2, hc2 = line:match('^@@ %-(%d+),?(%d*) %+(%d+),?(%d*) @@') local hs, hc, hs2, hc2 = line:match('^@@ %-(%d+),?(%d*) %+(%d+),?(%d*) @@')
if hs then if hs then
file_old_start = tonumber(hs) file_old_start = tonumber(hs)
@ -203,10 +211,17 @@ function M.parse_buffer(bufnr)
file_new_start = tonumber(hs2) file_new_start = tonumber(hs2)
file_new_count = tonumber(hc2) or 1 file_new_count = tonumber(hc2) or 1
end end
local prefix, context = line:match('^(@@.-@@%s*)(.*)') else
local hs2, hc2 = line:match('%+(%d+),?(%d*) @@')
if hs2 then
file_new_start = tonumber(hs2)
file_new_count = tonumber(hc2) or 1
end
end
local at_end, context = line:match('^(@@+.-@@+%s*)(.*)')
if context and context ~= '' then if context and context ~= '' then
hunk_header_context = context hunk_header_context = context
hunk_header_context_col = #prefix hunk_header_context_col = #at_end
end end
if hunk_count then if hunk_count then
hunk_count = hunk_count + 1 hunk_count = hunk_count + 1

View file

@ -830,6 +830,114 @@ describe('highlight', function()
delete_buffer(bufnr) delete_buffer(bufnr)
end) end)
it('applies DiffsAdd background to combined diff lines with + in prefix', function()
local bufnr = create_buffer({
'@@@ -1,3 -1,5 +1,9 @@@',
' local M = {}',
'++<<<<<<< HEAD',
' + return 1',
'++=======',
'+ return 2',
'++>>>>>>> feature',
})
local hunk = {
filename = 'test.lua',
lang = 'lua',
start_line = 1,
prefix_width = 2,
lines = {
' local M = {}',
'++<<<<<<< HEAD',
' + return 1',
'++=======',
'+ return 2',
'++>>>>>>> feature',
},
}
highlight.highlight_hunk(
bufnr,
ns,
hunk,
default_opts({ highlights = { background = true } })
)
local extmarks = get_extmarks(bufnr)
local add_lines = {}
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].hl_group == 'DiffsAdd' then
add_lines[mark[2]] = true
end
end
assert.is_nil(add_lines[0])
assert.is_nil(add_lines[1])
assert.is_true(add_lines[2] ~= nil)
assert.is_true(add_lines[3] ~= nil)
assert.is_true(add_lines[4] ~= nil)
assert.is_true(add_lines[5] ~= nil)
assert.is_true(add_lines[6] ~= nil)
delete_buffer(bufnr)
end)
it('conceals 2-char prefix for combined diffs', function()
local bufnr = create_buffer({
'@@@ -1,2 -1,2 +1,3 @@@',
' local M = {}',
'++<<<<<<< HEAD',
})
local hunk = {
filename = 'test.lua',
lang = 'lua',
start_line = 1,
prefix_width = 2,
lines = { ' local M = {}', '++<<<<<<< HEAD' },
}
highlight.highlight_hunk(bufnr, ns, hunk, default_opts({ hide_prefix = true }))
local extmarks = get_extmarks(bufnr)
local overlay_count = 0
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].virt_text_pos == 'overlay' then
overlay_count = overlay_count + 1
assert.are.equal(' ', mark[4].virt_text[1][1])
end
end
assert.are.equal(2, overlay_count)
delete_buffer(bufnr)
end)
it('produces treesitter captures on combined diff content lines', function()
local bufnr = create_buffer({
'@@@ -1,2 -1,2 +1,3 @@@',
' local M = {}',
' +local x = 1',
})
local hunk = {
filename = 'test.lua',
lang = 'lua',
start_line = 1,
prefix_width = 2,
lines = { ' local M = {}', ' +local x = 1' },
}
highlight.highlight_hunk(bufnr, ns, hunk, default_opts())
local extmarks = get_extmarks(bufnr)
local has_ts = false
for _, mark in ipairs(extmarks) do
if mark[4] and mark[4].hl_group and mark[4].hl_group:match('^@.*%.lua$') then
has_ts = true
break
end
end
assert.is_true(has_ts)
delete_buffer(bufnr)
end)
it('filters @spell and @nospell captures from injections', function() it('filters @spell and @nospell captures from injections', function()
local bufnr = create_buffer({ local bufnr = create_buffer({
'@@ -1,1 +1,2 @@', '@@ -1,1 +1,2 @@',

View file

@ -425,6 +425,81 @@ describe('parser', function()
delete_buffer(bufnr) delete_buffer(bufnr)
end) end)
it('recognizes U prefix for unmerged files', function()
local bufnr = create_buffer({
'U merge_me.lua',
'@@@ -1,3 -1,5 +1,9 @@@',
' local M = {}',
'++<<<<<<< HEAD',
' + return 1',
'++=======',
'+ return 2',
'++>>>>>>> feature',
})
local hunks = parser.parse_buffer(bufnr)
assert.are.equal(1, #hunks)
assert.are.equal('merge_me.lua', hunks[1].filename)
assert.are.equal('lua', hunks[1].ft)
delete_buffer(bufnr)
end)
it('sets prefix_width from @@@ combined diff header', function()
local bufnr = create_buffer({
'U test.lua',
'@@@ -1,3 -1,5 +1,9 @@@',
' local M = {}',
'++<<<<<<< HEAD',
' + return 1',
})
local hunks = parser.parse_buffer(bufnr)
assert.are.equal(1, #hunks)
assert.are.equal(2, hunks[1].prefix_width)
delete_buffer(bufnr)
end)
it('sets prefix_width 1 for standard @@ unified diff', function()
local bufnr = create_buffer({
'M test.lua',
'@@ -1,2 +1,3 @@',
' local x = 1',
'+local y = 2',
})
local hunks = parser.parse_buffer(bufnr)
assert.are.equal(1, #hunks)
assert.are.equal(1, hunks[1].prefix_width)
delete_buffer(bufnr)
end)
it('extracts new range from combined diff header', function()
local bufnr = create_buffer({
'U test.lua',
'@@@ -1,3 -1,5 +1,9 @@@',
' local M = {}',
})
local hunks = parser.parse_buffer(bufnr)
assert.are.equal(1, #hunks)
assert.are.equal(1, hunks[1].file_new_start)
assert.are.equal(9, hunks[1].file_new_count)
delete_buffer(bufnr)
end)
it('extracts header context from combined diff header', function()
local bufnr = create_buffer({
'U test.lua',
'@@@ -1,3 -1,5 +1,9 @@@ function M.greet()',
' local M = {}',
})
local hunks = parser.parse_buffer(bufnr)
assert.are.equal(1, #hunks)
assert.are.equal('function M.greet()', hunks[1].header_context)
delete_buffer(bufnr)
end)
it('stores repo_root on hunk when available', function() it('stores repo_root on hunk when available', function()
local bufnr = create_buffer({ local bufnr = create_buffer({
'M lua/test.lua', 'M lua/test.lua',