diff --git a/lua/diffs/highlight.lua b/lua/diffs/highlight.lua index bf9063d..306b0e5 100644 --- a/lua/diffs/highlight.lua +++ b/lua/diffs/highlight.lua @@ -41,9 +41,15 @@ local PRIORITY_CHAR_BG = 201 ---@param col_offset integer ---@param text string ---@param lang string +---@param context_lines? string[] ---@return integer -local function highlight_text(bufnr, ns, hunk, col_offset, text, lang) - local ok, parser_obj = pcall(vim.treesitter.get_string_parser, text, lang) +local function highlight_text(bufnr, ns, hunk, col_offset, text, lang, context_lines) + local parse_text = text + if context_lines and #context_lines > 0 then + parse_text = text .. '\n' .. table.concat(context_lines, '\n') + end + + local ok, parser_obj = pcall(vim.treesitter.get_string_parser, parse_text, lang) if not ok or not parser_obj then return 0 end @@ -61,24 +67,26 @@ local function highlight_text(bufnr, ns, hunk, col_offset, text, lang) local extmark_count = 0 local header_line = hunk.start_line - 1 - for id, node, metadata in query:iter_captures(trees[1]:root(), text) do - local capture_name = '@' .. query.captures[id] .. '.' .. lang - local sr, sc, er, ec = node:range() + for id, node, metadata in query:iter_captures(trees[1]:root(), parse_text) do + local sr, sc, _, ec = node:range() + if sr == 0 then + local capture_name = '@' .. query.captures[id] .. '.' .. lang - local buf_sr = header_line + sr - local buf_er = header_line + er - local buf_sc = col_offset + sc - local buf_ec = col_offset + ec + local buf_sr = header_line + local buf_er = header_line + local buf_sc = col_offset + sc + local buf_ec = col_offset + ec - local priority = lang == 'diff' and (tonumber(metadata.priority) or 100) or PRIORITY_SYNTAX + local priority = lang == 'diff' and (tonumber(metadata.priority) or 100) or PRIORITY_SYNTAX - pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_sr, buf_sc, { - end_row = buf_er, - end_col = buf_ec, - hl_group = capture_name, - priority = priority, - }) - extmark_count = extmark_count + 1 + pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_sr, buf_sc, { + end_row = buf_er, + end_col = buf_ec, + hl_group = capture_name, + priority = priority, + }) + extmark_count = extmark_count + 1 + end end return extmark_count @@ -360,8 +368,15 @@ function M.highlight_hunk(bufnr, ns, hunk, opts) hl_group = 'DiffsClear', priority = PRIORITY_CLEAR, }) - local header_extmarks = - highlight_text(bufnr, ns, hunk, hunk.header_context_col, hunk.header_context, hunk.lang) + local header_extmarks = highlight_text( + bufnr, + ns, + hunk, + hunk.header_context_col, + hunk.header_context, + hunk.lang, + new_code + ) if header_extmarks > 0 then dbg('header %s:%d applied %d extmarks', hunk.filename, hunk.start_line, header_extmarks) end diff --git a/spec/highlight_spec.lua b/spec/highlight_spec.lua index 2f95d02..0b9051e 100644 --- a/spec/highlight_spec.lua +++ b/spec/highlight_spec.lua @@ -220,6 +220,40 @@ describe('highlight', function() delete_buffer(bufnr) end) + it('highlights function keyword in header context', function() + local bufnr = create_buffer({ + '@@ -5,3 +5,4 @@ function M.setup()', + ' local x = 1', + '+local y = 2', + ' return x', + }) + + local hunk = { + filename = 'test.lua', + lang = 'lua', + start_line = 1, + header_context = 'function M.setup()', + header_context_col = 18, + lines = { ' local x = 1', '+local y = 2', ' return x' }, + } + + highlight.highlight_hunk(bufnr, ns, hunk, default_opts()) + + local extmarks = get_extmarks(bufnr) + local has_keyword_function = false + for _, mark in ipairs(extmarks) do + if mark[2] == 0 and mark[4] and mark[4].hl_group then + local hl = mark[4].hl_group + if hl == '@keyword.function.lua' or hl == '@keyword.lua' then + has_keyword_function = true + break + end + end + end + assert.is_true(has_keyword_function) + delete_buffer(bufnr) + end) + it('does not highlight header when no header_context', function() local bufnr = create_buffer({ '@@ -10,3 +10,4 @@',