From f35287bdfebbef335301a08aed2fbae970547d03 Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Wed, 4 Feb 2026 14:49:27 -0500 Subject: [PATCH] feat: treesitter highlighting for diff headers Apply treesitter highlighting to diff metadata lines (diff --git, index, ---, +++) using the diff language parser. Header info is attached only to the first hunk of each file to avoid duplicate highlighting. Based on PR #52 by @phanen with fixes: - header_lines now only contains diff metadata, not hunk content - header info attached only to first hunk per file - removed arbitrary hunk count restriction --- lua/diffs/highlight.lua | 25 +++++- lua/diffs/parser.lua | 27 ++++++- spec/highlight_spec.lua | 170 ++++++++++++++++++++++++++++++++++++++++ spec/parser_spec.lua | 70 +++++++++++++++++ 4 files changed, 287 insertions(+), 5 deletions(-) diff --git a/lua/diffs/highlight.lua b/lua/diffs/highlight.lua index e34931b..e814dc0 100644 --- a/lua/diffs/highlight.lua +++ b/lua/diffs/highlight.lua @@ -57,8 +57,9 @@ end ---@param ns integer ---@param hunk diffs.Hunk ---@param code_lines string[] +---@param col_offset integer? ---@return integer -local function highlight_treesitter(bufnr, ns, hunk, code_lines) +local function highlight_treesitter(bufnr, ns, hunk, code_lines, col_offset) local lang = hunk.lang if not lang then return 0 @@ -101,6 +102,8 @@ local function highlight_treesitter(bufnr, ns, hunk, code_lines) end end + col_offset = col_offset or 1 + local extmark_count = 0 for id, node, _ in query:iter_captures(trees[1]:root(), code) do local capture_name = '@' .. query.captures[id] @@ -108,8 +111,8 @@ local function highlight_treesitter(bufnr, ns, hunk, code_lines) local buf_sr = hunk.start_line + sr local buf_er = hunk.start_line + er - local buf_sc = sc + 1 - local buf_ec = ec + 1 + local buf_sc = sc + col_offset + local buf_ec = ec + col_offset pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_sr, buf_sc, { end_row = buf_er, @@ -257,6 +260,22 @@ function M.highlight_hunk(bufnr, ns, hunk, opts) extmark_count = highlight_vim_syntax(bufnr, ns, hunk, code_lines) end + if + hunk.header_start_line + and hunk.header_lines + and #hunk.header_lines > 0 + and opts.highlights.treesitter.enabled + then + extmark_count = extmark_count + + highlight_treesitter(bufnr, ns, { + filename = hunk.filename, + start_line = hunk.header_start_line - 1, + lang = 'diff', + lines = hunk.header_lines, + header_lines = {}, + }, hunk.header_lines, 0) + end + local syntax_applied = extmark_count > 0 for i, line in ipairs(hunk.lines) do diff --git a/lua/diffs/parser.lua b/lua/diffs/parser.lua index ca34ec3..bbb8ab5 100644 --- a/lua/diffs/parser.lua +++ b/lua/diffs/parser.lua @@ -6,6 +6,8 @@ ---@field header_context string? ---@field header_context_col integer? ---@field lines string[] +---@field header_start_line integer? +---@field header_lines string[]? local M = {} @@ -58,10 +60,16 @@ function M.parse_buffer(bufnr) local hunk_header_context_col = nil ---@type string[] local hunk_lines = {} + ---@type integer? + local hunk_count = nil + ---@type integer? + local header_start = nil + ---@type string[] + local header_lines = {} local function flush_hunk() if hunk_start and #hunk_lines > 0 and (current_lang or current_ft) then - table.insert(hunks, { + local hunk = { filename = current_filename, ft = current_ft, lang = current_lang, @@ -69,7 +77,12 @@ function M.parse_buffer(bufnr) header_context = hunk_header_context, header_context_col = hunk_header_context_col, lines = hunk_lines, - }) + } + if hunk_count == 1 and header_start and #header_lines > 0 then + hunk.header_start_line = header_start + hunk.header_lines = header_lines + end + table.insert(hunks, hunk) end hunk_start = nil hunk_header_context = nil @@ -89,6 +102,9 @@ function M.parse_buffer(bufnr) elseif current_ft then dbg('file: %s -> ft: %s (no ts parser)', filename, current_ft) end + hunk_count = 0 + header_start = i + header_lines = {} elseif line:match('^@@.-@@') then flush_hunk() hunk_start = i @@ -97,6 +113,9 @@ function M.parse_buffer(bufnr) hunk_header_context = context hunk_header_context_col = #prefix end + if hunk_count then + hunk_count = hunk_count + 1 + end elseif hunk_start then local prefix = line:sub(1, 1) if prefix == ' ' or prefix == '+' or prefix == '-' then @@ -112,8 +131,12 @@ function M.parse_buffer(bufnr) current_filename = nil current_ft = nil current_lang = nil + header_start = nil end end + if header_start and not hunk_start then + table.insert(header_lines, line) + end end flush_hunk() diff --git a/spec/highlight_spec.lua b/spec/highlight_spec.lua index 3d6df08..f55d917 100644 --- a/spec/highlight_spec.lua +++ b/spec/highlight_spec.lua @@ -729,6 +729,176 @@ describe('highlight', function() end) end) + describe('diff header highlighting', function() + local ns + + before_each(function() + ns = vim.api.nvim_create_namespace('diffs_test_header') + end) + + local function create_buffer(lines) + local bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) + return bufnr + end + + local function delete_buffer(bufnr) + if vim.api.nvim_buf_is_valid(bufnr) then + vim.api.nvim_buf_delete(bufnr, { force = true }) + end + end + + local function get_extmarks(bufnr) + return vim.api.nvim_buf_get_extmarks(bufnr, ns, 0, -1, { details = true }) + end + + local function default_opts() + return { + hide_prefix = false, + highlights = { + background = false, + gutter = false, + treesitter = { enabled = true, max_lines = 500 }, + vim = { enabled = false, max_lines = 200 }, + }, + } + end + + local function has_diff_parser() + return pcall(vim.treesitter.language.inspect, 'diff') + end + + it('applies treesitter extmarks to diff header lines', function() + if not has_diff_parser() then + pending('diff treesitter parser not installed') + return + end + + local bufnr = create_buffer({ + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + '@@ -1,2 +1,3 @@', + ' local M = {}', + '+local x = 1', + }) + + local hunk = { + filename = 'parser.lua', + lang = 'lua', + start_line = 5, + lines = { ' local M = {}', '+local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local header_extmarks = {} + for _, mark in ipairs(extmarks) do + if mark[2] < 4 and mark[4] and mark[4].hl_group then + table.insert(header_extmarks, mark) + end + end + + assert.is_true(#header_extmarks > 0) + + local has_function_hl = false + local has_keyword_hl = false + for _, mark in ipairs(header_extmarks) do + local hl = mark[4].hl_group + if hl == '@function' or hl == '@function.diff' then + has_function_hl = true + end + if hl == '@keyword' or hl == '@keyword.diff' then + has_keyword_hl = true + end + end + assert.is_true(has_function_hl or has_keyword_hl) + delete_buffer(bufnr) + end) + + it('does not apply header highlights when header_lines missing', function() + local bufnr = create_buffer({ + '@@ -1,2 +1,3 @@', + ' local M = {}', + '+local x = 1', + }) + + local hunk = { + filename = 'parser.lua', + lang = 'lua', + start_line = 1, + lines = { ' local M = {}', '+local x = 1' }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local header_extmarks = 0 + for _, mark in ipairs(extmarks) do + if mark[2] < 0 and mark[4] and mark[4].hl_group then + header_extmarks = header_extmarks + 1 + end + end + assert.are.equal(0, header_extmarks) + delete_buffer(bufnr) + end) + + it('does not apply header highlights when treesitter disabled', function() + if not has_diff_parser() then + pending('diff treesitter parser not installed') + return + end + + local bufnr = create_buffer({ + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + '@@ -1,2 +1,3 @@', + ' local M = {}', + '+local x = 1', + }) + + local hunk = { + filename = 'parser.lua', + lang = 'lua', + start_line = 5, + lines = { ' local M = {}', '+local x = 1' }, + header_start_line = 1, + header_lines = { + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + }, + } + + local opts = default_opts() + opts.highlights.treesitter.enabled = false + + highlight.highlight_hunk(bufnr, ns, hunk, opts) + + local extmarks = get_extmarks(bufnr) + local header_extmarks = 0 + for _, mark in ipairs(extmarks) do + if mark[2] < 4 and mark[4] and mark[4].hl_group and mark[4].hl_group:match('^@') then + header_extmarks = header_extmarks + 1 + end + end + assert.are.equal(0, header_extmarks) + delete_buffer(bufnr) + end) + end) + describe('coalesce_syntax_spans', function() it('coalesces adjacent chars with same hl group', function() local function query_fn(_line, _col) diff --git a/spec/parser_spec.lua b/spec/parser_spec.lua index c2c1c3a..9d2c01b 100644 --- a/spec/parser_spec.lua +++ b/spec/parser_spec.lua @@ -215,5 +215,75 @@ describe('parser', function() assert.are.equal(1, #hunks[2].lines) delete_buffer(bufnr) end) + + it('attaches header_lines to first hunk only', function() + local bufnr = create_buffer({ + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + '@@ -1,2 +1,3 @@', + ' local M = {}', + '+local x = 1', + '@@ -10,2 +11,3 @@', + ' function M.foo()', + '+ return true', + ' end', + }) + local hunks = parser.parse_buffer(bufnr) + + assert.are.equal(2, #hunks) + assert.is_not_nil(hunks[1].header_start_line) + assert.is_not_nil(hunks[1].header_lines) + assert.are.equal(1, hunks[1].header_start_line) + assert.is_nil(hunks[2].header_start_line) + assert.is_nil(hunks[2].header_lines) + delete_buffer(bufnr) + end) + + it('header_lines contains only diff metadata, not hunk content', function() + local bufnr = create_buffer({ + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + '@@ -1,2 +1,3 @@', + ' local M = {}', + '+local x = 1', + }) + local hunks = parser.parse_buffer(bufnr) + + assert.are.equal(1, #hunks) + assert.are.equal(4, #hunks[1].header_lines) + assert.are.equal('diff --git a/parser.lua b/parser.lua', hunks[1].header_lines[1]) + assert.are.equal('index 3e8afa0..018159c 100644', hunks[1].header_lines[2]) + assert.are.equal('--- a/parser.lua', hunks[1].header_lines[3]) + assert.are.equal('+++ b/parser.lua', hunks[1].header_lines[4]) + delete_buffer(bufnr) + end) + + it('handles fugitive status format with diff headers', function() + local bufnr = create_buffer({ + 'Head: main', + 'Push: origin/main', + '', + 'Unstaged (1)', + 'M parser.lua', + 'diff --git a/parser.lua b/parser.lua', + 'index 3e8afa0..018159c 100644', + '--- a/parser.lua', + '+++ b/parser.lua', + '@@ -1,2 +1,3 @@', + ' local M = {}', + '+local x = 1', + }) + local hunks = parser.parse_buffer(bufnr) + + assert.are.equal(1, #hunks) + assert.are.equal(6, hunks[1].header_start_line) + assert.are.equal(4, #hunks[1].header_lines) + assert.are.equal('diff --git a/parser.lua b/parser.lua', hunks[1].header_lines[1]) + delete_buffer(bufnr) + end) end) end)