diff --git a/lua/diffs/highlight.lua b/lua/diffs/highlight.lua index 98943d2..2a2f8b1 100644 --- a/lua/diffs/highlight.lua +++ b/lua/diffs/highlight.lua @@ -74,6 +74,7 @@ end ---@param col_offset integer ---@param covered_lines? table ---@param priorities diffs.PrioritiesConfig +---@param force_high_priority? boolean ---@return integer local function highlight_treesitter( bufnr, @@ -83,7 +84,8 @@ local function highlight_treesitter( line_map, col_offset, covered_lines, - priorities + priorities, + force_high_priority ) local code = table.concat(code_lines, '\n') if code == '' then @@ -123,8 +125,9 @@ local function highlight_treesitter( local buf_sc = sc + col_offset local buf_ec = ec + col_offset + local meta_prio = tonumber(metadata.priority) or 100 local priority = tree_lang == 'diff' - and (col_offset > 0 and priorities.syntax or (tonumber(metadata.priority) or 100)) + and ((col_offset > 0 or force_high_priority) and (priorities.syntax + meta_prio - 100) or meta_prio) or priorities.syntax pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_sr, buf_sc, { @@ -367,14 +370,14 @@ function M.highlight_hunk(bufnr, ns, hunk, opts) header_map[i] = hunk.header_start_line - 1 + i end extmark_count = extmark_count - + highlight_treesitter(bufnr, ns, hunk.header_lines, 'diff', header_map, qw, nil, p) + + highlight_treesitter(bufnr, ns, hunk.header_lines, 'diff', header_map, qw, nil, p, qw > 0 or pw > 1) end local at_raw_line - if qw > 0 and opts.highlights.treesitter.enabled then + if (qw > 0 or pw > 1) and opts.highlights.treesitter.enabled then local at_buf_line = hunk.start_line - 1 at_raw_line = vim.api.nvim_buf_get_lines(bufnr, at_buf_line, at_buf_line + 1, false)[1] - if at_raw_line then + if qw > 0 and at_raw_line then local at_logical = at_raw_line:sub(qw + 1) local at_map = { [0] = at_buf_line } extmark_count = extmark_count @@ -417,7 +420,7 @@ function M.highlight_hunk(bufnr, ns, hunk, opts) end if - qw > 0 + (qw > 0 or pw > 1) and hunk.header_start_line and hunk.header_lines and #hunk.header_lines > 0 @@ -430,16 +433,46 @@ function M.highlight_hunk(bufnr, ns, hunk, opts) hl_group = 'DiffsClear', priority = p.clear, }) + + if pw > 1 then + local hline = hunk.header_lines[i + 1] + if hline:match('^index ') then + pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, qw, { + end_col = qw + 5, + hl_group = '@keyword.diff', + priority = p.syntax, + }) + local dot_pos = hline:find('%.%.', 1, false) + if dot_pos then + local rest = hline:sub(dot_pos + 2) + local hash = rest:match('^(%x+)') + if hash then + pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, qw + dot_pos + 1, { + end_col = qw + dot_pos + 1 + #hash, + hl_group = '@constant.diff', + priority = p.syntax, + }) + end + end + end + end end end - if qw > 0 and at_raw_line then + if (qw > 0 or pw > 1) and at_raw_line then local at_buf_line = hunk.start_line - 1 pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, at_buf_line, 0, { end_col = #at_raw_line, hl_group = 'DiffsClear', priority = p.clear, }) + if pw > 1 and opts.highlights.treesitter.enabled then + pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, at_buf_line, qw, { + end_col = #at_raw_line, + hl_group = '@attribute.diff', + priority = p.syntax, + }) + end end if use_ts and hunk.header_context and hunk.header_context_col then @@ -484,6 +517,25 @@ function M.highlight_hunk(bufnr, ns, hunk, opts) hl_group = 'DiffsClear', priority = p.clear, }) + elseif pw > 1 then + pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, 0, { + end_col = pw, + hl_group = 'DiffsClear', + priority = p.clear, + }) + end + + if pw > 1 then + for ci = 0, pw - 1 do + local ch = line:sub(ci + 1, ci + 1) + if ch == '+' or ch == '-' then + pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, ci + qw, { + end_col = ci + qw + 1, + hl_group = ch == '+' and '@diff.plus' or '@diff.minus', + priority = p.syntax, + }) + end + end end if line_len > pw and covered_lines[buf_line] then diff --git a/spec/highlight_spec.lua b/spec/highlight_spec.lua index ebd075f..4695d66 100644 --- a/spec/highlight_spec.lua +++ b/spec/highlight_spec.lua @@ -1064,11 +1064,16 @@ describe('highlight', function() highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) local extmarks = get_extmarks(bufnr) + local content_clear_count = 0 for _, mark in ipairs(extmarks) do if mark[4] and mark[4].hl_group == 'DiffsClear' then - assert.are.equal(2, mark[3]) + assert.is_true(mark[3] == 0 or mark[3] == 2, 'DiffsClear at unexpected col ' .. mark[3]) + if mark[3] == 2 then + content_clear_count = content_clear_count + 1 + end end end + assert.are.equal(2, content_clear_count) delete_buffer(bufnr) end) @@ -1363,6 +1368,449 @@ describe('highlight', function() assert.are.equal(0, header_extmarks) delete_buffer(bufnr) end) + + it('does not apply DiffsClear to header lines for non-quoted diffs', function() + local bufnr = create_buffer({ + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + '@@ -1,2 +1,3 @@', + ' local M = {}', + '+local x = 1', + }) + + local hunk = { + filename = 'parser.lua', + lang = 'lua', + start_line = 5, + lines = { ' local M = {}', '+local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + for _, mark in ipairs(extmarks) do + local d = mark[4] + if d and d.hl_group == 'DiffsClear' and mark[3] == 0 and mark[2] < 4 then + error('unexpected DiffsClear on header row ' .. mark[2] .. ' for non-quoted diff') + end + end + delete_buffer(bufnr) + end) + + it('preserves diff grammar treesitter on headers for non-quoted diffs', function() + local bufnr = create_buffer({ + 'diff --git a/parser.lua b/parser.lua', + '--- a/parser.lua', + '+++ b/parser.lua', + '@@ -1,2 +1,3 @@', + ' local M = {}', + '+local x = 1', + }) + + local hunk = { + filename = 'parser.lua', + lang = 'lua', + start_line = 4, + lines = { ' local M = {}', '+local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --git a/parser.lua b/parser.lua', + '--- a/parser.lua', + '+++ b/parser.lua', + }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local header_ts_count = 0 + for _, mark in ipairs(extmarks) do + local d = mark[4] + if mark[2] < 3 and d and d.hl_group and d.hl_group:match('^@.*%.diff$') then + header_ts_count = header_ts_count + 1 + end + end + assert.is_true(header_ts_count > 0, 'expected diff grammar treesitter on header lines') + delete_buffer(bufnr) + end) + + it('applies syntax extmarks to combined diff body lines', function() + local bufnr = create_buffer({ + '@@@ -1,2 -1,2 +1,3 @@@', + ' local M = {}', + '+ local x = 1', + ' -local y = 2', + }) + + local hunk = { + filename = 'test.lua', + lang = 'lua', + prefix_width = 2, + start_line = 1, + lines = { ' local M = {}', '+ local x = 1', ' -local y = 2' }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local syntax_on_body = 0 + for _, mark in ipairs(extmarks) do + local d = mark[4] + if mark[2] >= 1 and d and d.hl_group and d.hl_group:match('^@.*%.lua$') then + syntax_on_body = syntax_on_body + 1 + end + end + assert.is_true(syntax_on_body > 0, 'expected lua treesitter syntax on combined diff body') + delete_buffer(bufnr) + end) + + it('applies DiffsClear and per-char diff fg to combined diff body prefixes', function() + local bufnr = create_buffer({ + '@@@', + ' unchanged', + '+ added', + ' -removed', + '++both', + }) + + local hunk = { + filename = 'test.lua', + lang = 'lua', + prefix_width = 2, + start_line = 1, + lines = { ' unchanged', '+ added', ' -removed', '++both' }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local prefix_clears = {} + local plus_marks = {} + local minus_marks = {} + for _, mark in ipairs(extmarks) do + local d = mark[4] + if mark[2] >= 1 and d then + if d.hl_group == 'DiffsClear' and mark[3] == 0 and d.end_col == 2 then + prefix_clears[mark[2]] = true + end + if d.hl_group == '@diff.plus' and d.priority == 199 then + if not plus_marks[mark[2]] then + plus_marks[mark[2]] = {} + end + table.insert(plus_marks[mark[2]], mark[3]) + end + if d.hl_group == '@diff.minus' and d.priority == 199 then + if not minus_marks[mark[2]] then + minus_marks[mark[2]] = {} + end + table.insert(minus_marks[mark[2]], mark[3]) + end + end + end + + assert.is_true(prefix_clears[1] ~= nil, 'DiffsClear on context prefix') + assert.is_true(prefix_clears[2] ~= nil, 'DiffsClear on add prefix') + assert.is_true(prefix_clears[3] ~= nil, 'DiffsClear on del prefix') + assert.is_true(prefix_clears[4] ~= nil, 'DiffsClear on both-add prefix') + + assert.is_true(plus_marks[2] ~= nil, '@diff.plus on + in "+ added"') + assert.are.equal(0, plus_marks[2][1]) + + assert.is_true(minus_marks[3] ~= nil, '@diff.minus on - in " -removed"') + assert.are.equal(1, minus_marks[3][1]) + + assert.is_true(plus_marks[4] ~= nil, '@diff.plus on ++ in "++both"') + assert.are.equal(2, #plus_marks[4]) + + assert.is_nil(plus_marks[1], 'no @diff.plus on context " unchanged"') + assert.is_nil(minus_marks[1], 'no @diff.minus on context " unchanged"') + delete_buffer(bufnr) + end) + + it('applies DiffsClear to headers for combined diffs', function() + local bufnr = create_buffer({ + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + '@@@ -1,2 -1,2 +1,3 @@@', + ' local M = {}', + '+ local x = 1', + }) + + local hunk = { + filename = 'lua/merge/target.lua', + lang = 'lua', + prefix_width = 2, + start_line = 5, + lines = { ' local M = {}', '+ local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local clear_lines = {} + for _, mark in ipairs(extmarks) do + local d = mark[4] + if d and d.hl_group == 'DiffsClear' and mark[3] == 0 and mark[2] < 4 then + clear_lines[mark[2]] = true + end + end + assert.is_true(clear_lines[0] ~= nil, 'DiffsClear on diff --combined line') + assert.is_true(clear_lines[1] ~= nil, 'DiffsClear on index line') + assert.is_true(clear_lines[2] ~= nil, 'DiffsClear on --- line') + assert.is_true(clear_lines[3] ~= nil, 'DiffsClear on +++ line') + delete_buffer(bufnr) + end) + + it('applies @attribute.diff at syntax priority to @@@ line for combined diffs', function() + local bufnr = create_buffer({ + '@@@ -1,2 -1,2 +1,3 @@@', + ' local M = {}', + '+ local x = 1', + }) + + local hunk = { + filename = 'test.lua', + lang = 'lua', + prefix_width = 2, + start_line = 1, + lines = { ' local M = {}', '+ local x = 1' }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local has_attr = false + for _, mark in ipairs(extmarks) do + local d = mark[4] + if mark[2] == 0 and d and d.hl_group == '@attribute.diff' and (d.priority or 0) >= 199 then + has_attr = true + end + end + assert.is_true(has_attr, '@attribute.diff at p>=199 on @@@ line') + delete_buffer(bufnr) + end) + + it('applies DiffsClear to @@@ line for combined diffs', function() + local bufnr = create_buffer({ + '@@@ -1,2 -1,2 +1,3 @@@', + ' local M = {}', + '+ local x = 1', + }) + + local hunk = { + filename = 'test.lua', + lang = 'lua', + prefix_width = 2, + start_line = 1, + lines = { ' local M = {}', '+ local x = 1' }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local has_at_clear = false + for _, mark in ipairs(extmarks) do + local d = mark[4] + if mark[2] == 0 and d and d.hl_group == 'DiffsClear' and mark[3] == 0 then + has_at_clear = true + end + end + assert.is_true(has_at_clear, 'DiffsClear on @@@ line') + delete_buffer(bufnr) + end) + + it('applies header diff grammar at syntax priority for combined diffs', function() + local bufnr = create_buffer({ + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + '@@@ -1,2 -1,2 +1,3 @@@', + ' local M = {}', + '+ local x = 1', + }) + + local hunk = { + filename = 'lua/merge/target.lua', + lang = 'lua', + prefix_width = 2, + start_line = 5, + lines = { ' local M = {}', '+ local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local high_prio_diff = {} + for _, mark in ipairs(extmarks) do + local d = mark[4] + if mark[2] < 4 and d and d.hl_group and d.hl_group:match('^@.*%.diff$') and (d.priority or 0) >= 199 then + high_prio_diff[mark[2]] = true + end + end + assert.is_true(high_prio_diff[2] ~= nil, 'diff grammar at p>=199 on --- line') + assert.is_true(high_prio_diff[3] ~= nil, 'diff grammar at p>=199 on +++ line') + delete_buffer(bufnr) + end) + + it('@diff.minus wins over @punctuation.special on combined diff headers', function() + local bufnr = create_buffer({ + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + '@@@ -1,2 -1,2 +1,3 @@@', + ' local M = {}', + '+ local x = 1', + }) + + local hunk = { + filename = 'lua/merge/target.lua', + lang = 'lua', + prefix_width = 2, + start_line = 5, + lines = { ' local M = {}', '+ local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local minus_prio, punct_prio_minus = 0, 0 + local plus_prio, punct_prio_plus = 0, 0 + for _, mark in ipairs(extmarks) do + local d = mark[4] + if d and d.hl_group then + if mark[2] == 2 then + if d.hl_group == '@diff.minus.diff' then + minus_prio = math.max(minus_prio, d.priority or 0) + elseif d.hl_group == '@punctuation.special.diff' then + punct_prio_minus = math.max(punct_prio_minus, d.priority or 0) + end + elseif mark[2] == 3 then + if d.hl_group == '@diff.plus.diff' then + plus_prio = math.max(plus_prio, d.priority or 0) + elseif d.hl_group == '@punctuation.special.diff' then + punct_prio_plus = math.max(punct_prio_plus, d.priority or 0) + end + end + end + end + assert.is_true(minus_prio > punct_prio_minus, '@diff.minus.diff should beat @punctuation.special.diff on --- line') + assert.is_true(plus_prio > punct_prio_plus, '@diff.plus.diff should beat @punctuation.special.diff on +++ line') + delete_buffer(bufnr) + end) + + it('applies @keyword.diff on index word for combined diffs', function() + local bufnr = create_buffer({ + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + '@@@ -1,2 -1,2 +1,3 @@@', + ' local M = {}', + '+ local x = 1', + }) + + local hunk = { + filename = 'lua/merge/target.lua', + lang = 'lua', + prefix_width = 2, + start_line = 5, + lines = { ' local M = {}', '+ local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local has_keyword = false + for _, mark in ipairs(extmarks) do + local d = mark[4] + if mark[2] == 1 and d and d.hl_group == '@keyword.diff' and mark[3] == 0 and (d.end_col or 0) == 5 then + has_keyword = true + end + end + assert.is_true(has_keyword, '@keyword.diff at row 1, cols 0-5') + delete_buffer(bufnr) + end) + + it('applies @constant.diff on result hash for combined diffs', function() + local bufnr = create_buffer({ + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + '@@@ -1,2 -1,2 +1,3 @@@', + ' local M = {}', + '+ local x = 1', + }) + + local hunk = { + filename = 'lua/merge/target.lua', + lang = 'lua', + prefix_width = 2, + start_line = 5, + lines = { ' local M = {}', '+ local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --combined lua/merge/target.lua', + 'index abc1234,def5678..ghi9012', + '--- a/lua/merge/target.lua', + '+++ b/lua/merge/target.lua', + }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local has_constant = false + for _, mark in ipairs(extmarks) do + local d = mark[4] + if mark[2] == 1 and d and d.hl_group == '@constant.diff' and (d.priority or 0) >= 199 then + has_constant = true + end + end + assert.is_true(has_constant, '@constant.diff on result hash') + delete_buffer(bufnr) + end) end) describe('extmark priority', function()