diff --git a/lua/fugitive-ts/highlight.lua b/lua/fugitive-ts/highlight.lua index 26b2548..e4aaed0 100644 --- a/lua/fugitive-ts/highlight.lua +++ b/lua/fugitive-ts/highlight.lua @@ -125,6 +125,47 @@ local function highlight_treesitter(bufnr, ns, hunk, code_lines) return extmark_count end +---@alias fugitive-ts.SyntaxQueryFn fun(line: integer, col: integer): integer, string + +---@param query_fn fugitive-ts.SyntaxQueryFn +---@param code_lines string[] +---@return {line: integer, col_start: integer, col_end: integer, hl_name: string}[] +function M.coalesce_syntax_spans(query_fn, code_lines) + local spans = {} + for i, line in ipairs(code_lines) do + local col = 1 + local line_len = #line + + while col <= line_len do + local syn_id, hl_name = query_fn(i, col) + if syn_id == 0 then + col = col + 1 + else + local span_start = col + + col = col + 1 + while col <= line_len do + local next_id, next_name = query_fn(i, col) + if next_id == 0 or next_name ~= hl_name then + break + end + col = col + 1 + end + + if hl_name ~= '' then + table.insert(spans, { + line = i, + col_start = span_start, + col_end = col, + hl_name = hl_name, + }) + end + end + end + end + return spans +end + ---@param bufnr integer ---@param ns integer ---@param hunk fugitive-ts.Hunk @@ -144,53 +185,39 @@ local function highlight_vim_syntax(bufnr, ns, hunk, code_lines) vim.api.nvim_buf_set_lines(scratch, 0, -1, false, code_lines) vim.api.nvim_set_option_value('bufhidden', 'wipe', { buf = scratch }) - local extmark_count = 0 + local spans = {} vim.api.nvim_buf_call(scratch, function() vim.cmd('setlocal syntax=' .. ft) vim.cmd('redraw') - for i, line in ipairs(code_lines) do - local col = 1 - local line_len = #line - - while col <= line_len do - local syn_id = vim.fn.synID(i, col, 1) - if syn_id == 0 then - col = col + 1 - else - local hl_name = vim.fn.synIDattr(vim.fn.synIDtrans(syn_id), 'name') - local span_start = col - - col = col + 1 - while col <= line_len do - local next_id = vim.fn.synID(i, col, 1) - if next_id == 0 then - break - end - local next_name = vim.fn.synIDattr(vim.fn.synIDtrans(next_id), 'name') - if next_name ~= hl_name then - break - end - col = col + 1 - end - - if hl_name ~= '' then - local buf_line = hunk.start_line + i - 1 - pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, span_start, { - end_col = col, - hl_group = hl_name, - priority = 200, - }) - extmark_count = extmark_count + 1 - end - end + ---@param line integer + ---@param col integer + ---@return integer, string + local function query_fn(line, col) + local syn_id = vim.fn.synID(line, col, 1) + if syn_id == 0 then + return 0, '' end + return syn_id, vim.fn.synIDattr(vim.fn.synIDtrans(syn_id), 'name') end + + spans = M.coalesce_syntax_spans(query_fn, code_lines) end) vim.api.nvim_buf_delete(scratch, { force = true }) + local extmark_count = 0 + for _, span in ipairs(spans) do + local buf_line = hunk.start_line + span.line - 1 + pcall(vim.api.nvim_buf_set_extmark, bufnr, ns, buf_line, span.col_start, { + end_col = span.col_end, + hl_group = span.hl_name, + priority = 200, + }) + extmark_count = extmark_count + 1 + end + return extmark_count end diff --git a/spec/highlight_spec.lua b/spec/highlight_spec.lua index a0f7816..b4862e6 100644 --- a/spec/highlight_spec.lua +++ b/spec/highlight_spec.lua @@ -533,6 +533,19 @@ describe('highlight', function() end) it('applies vim syntax extmarks when vim.enabled and no TS parser', function() + local orig_synID = vim.fn.synID + local orig_synIDtrans = vim.fn.synIDtrans + local orig_synIDattr = vim.fn.synIDattr + vim.fn.synID = function(_line, _col, _trans) + return 1 + end + vim.fn.synIDtrans = function(id) + return id + end + vim.fn.synIDattr = function(_id, _what) + return 'Identifier' + end + local bufnr = create_buffer({ '@@ -1,1 +1,2 @@', ' local x = 1', @@ -549,6 +562,10 @@ describe('highlight', function() highlight.highlight_hunk(bufnr, ns, hunk, default_opts({ vim = { enabled = true } })) + vim.fn.synID = orig_synID + vim.fn.synIDtrans = orig_synIDtrans + vim.fn.synIDattr = orig_synIDattr + local extmarks = get_extmarks(bufnr) local has_syntax_hl = false for _, mark in ipairs(extmarks) do @@ -654,6 +671,19 @@ describe('highlight', function() end) it('applies Normal blanking for vim fallback hunks', function() + local orig_synID = vim.fn.synID + local orig_synIDtrans = vim.fn.synIDtrans + local orig_synIDattr = vim.fn.synIDattr + vim.fn.synID = function(_line, _col, _trans) + return 1 + end + vim.fn.synIDtrans = function(id) + return id + end + vim.fn.synIDattr = function(_id, _what) + return 'Identifier' + end + local bufnr = create_buffer({ '@@ -1,1 +1,2 @@', ' local x = 1', @@ -670,6 +700,10 @@ describe('highlight', function() highlight.highlight_hunk(bufnr, ns, hunk, default_opts({ vim = { enabled = true } })) + vim.fn.synID = orig_synID + vim.fn.synIDtrans = orig_synIDtrans + vim.fn.synIDattr = orig_synIDattr + local extmarks = get_extmarks(bufnr) local has_normal = false for _, mark in ipairs(extmarks) do @@ -682,4 +716,57 @@ describe('highlight', function() delete_buffer(bufnr) end) end) + + describe('coalesce_syntax_spans', function() + it('coalesces adjacent chars with same hl group', function() + local function query_fn(_line, _col) + return 1, 'Keyword' + end + local spans = highlight.coalesce_syntax_spans(query_fn, { 'hello' }) + assert.are.equal(1, #spans) + assert.are.equal(1, spans[1].col_start) + assert.are.equal(6, spans[1].col_end) + assert.are.equal('Keyword', spans[1].hl_name) + end) + + it('splits spans at hl group boundaries', function() + local function query_fn(_line, col) + if col <= 3 then + return 1, 'Keyword' + end + return 2, 'String' + end + local spans = highlight.coalesce_syntax_spans(query_fn, { 'abcdef' }) + assert.are.equal(2, #spans) + assert.are.equal('Keyword', spans[1].hl_name) + assert.are.equal(1, spans[1].col_start) + assert.are.equal(4, spans[1].col_end) + assert.are.equal('String', spans[2].hl_name) + assert.are.equal(4, spans[2].col_start) + assert.are.equal(7, spans[2].col_end) + end) + + it('skips syn_id 0 gaps', function() + local function query_fn(_line, col) + if col == 2 or col == 3 then + return 0, '' + end + return 1, 'Identifier' + end + local spans = highlight.coalesce_syntax_spans(query_fn, { 'abcd' }) + assert.are.equal(2, #spans) + assert.are.equal(1, spans[1].col_start) + assert.are.equal(2, spans[1].col_end) + assert.are.equal(4, spans[2].col_start) + assert.are.equal(5, spans[2].col_end) + end) + + it('skips empty hl_name spans', function() + local function query_fn(_line, _col) + return 1, '' + end + local spans = highlight.coalesce_syntax_spans(query_fn, { 'abc' }) + assert.are.equal(0, #spans) + end) + end) end) diff --git a/spec/parser_spec.lua b/spec/parser_spec.lua index 6532b25..d4c1d2b 100644 --- a/spec/parser_spec.lua +++ b/spec/parser_spec.lua @@ -76,6 +76,25 @@ describe('parser', function() end) it('detects hunks across multiple files', function() + local orig_get_lang = vim.treesitter.language.get_lang + local orig_inspect = vim.treesitter.language.inspect + vim.treesitter.language.get_lang = function(ft) + local result = orig_get_lang(ft) + if result then + return result + end + if ft == 'python' then + return 'python' + end + return nil + end + vim.treesitter.language.inspect = function(lang) + if lang == 'python' then + return {} + end + return orig_inspect(lang) + end + local bufnr = create_buffer({ 'M lua/foo.lua', '@@ -1,1 +1,2 @@', @@ -88,6 +107,9 @@ describe('parser', function() }) local hunks = parser.parse_buffer(bufnr) + vim.treesitter.language.get_lang = orig_get_lang + vim.treesitter.language.inspect = orig_inspect + assert.are.equal(2, #hunks) assert.are.equal('lua/foo.lua', hunks[1].filename) assert.are.equal('lua', hunks[1].lang)