From 7f184309324351c971a0e2ce4bbacc94f9ff84e3 Mon Sep 17 00:00:00 2001
From: CJ van den Berg
Date: Sat, 21 Feb 2026 20:43:18 +0100
Subject: [PATCH] feat: add initial injection rendering support

---
 src/syntax.zig      | 281 +++++++++++++++++++++++++++++++++++++++++++-
 src/treez_dummy.zig |   1 +
 2 files changed, 277 insertions(+), 5 deletions(-)

diff --git a/src/syntax.zig b/src/syntax.zig
index ec15373..b7dc98d 100644
--- a/src/syntax.zig
+++ b/src/syntax.zig
@@ -20,12 +20,30 @@ const Query = treez.Query;
 pub const Node = treez.Node;
 
 allocator: std.mem.Allocator,
+query_cache: *QueryCache,
 lang: *const Language,
 parser: *Parser,
 query: *Query,
 errors_query: *Query,
 injections: ?*Query,
 tree: ?*treez.Tree = null,
+injection_list: std.ArrayListUnmanaged(Injection) = .{},
+
+/// A region of the document handed to a child syntax for highlighting.
+/// Owns `lang_name` and `syntax`; `deinit` releases both.
+pub const Injection = struct {
+    lang_name: []const u8,
+    /// Start of the injected content in the parent document.
+    start_point: Point,
+    /// Last row of the injected content in the parent document.
+    end_row: u32,
+    syntax: *Self,
+
+    fn deinit(self: *Injection, allocator: std.mem.Allocator) void {
+        self.syntax.destroy();
+        allocator.free(self.lang_name);
+    }
+};
 
 pub fn create(file_type: FileType, allocator: std.mem.Allocator, query_cache: *QueryCache) !*Self {
     const query = try query_cache.get(file_type, .highlights);
@@ -36,10 +54,13 @@ pub fn create(file_type: FileType, allocator: std.mem.Allocator, query_cache: *Q
     errdefer if (injections) |injections_| query_cache.release(injections_, .injections);
     const self = try allocator.create(Self);
     errdefer allocator.destroy(self);
+    const parser = try Parser.create();
+    errdefer parser.destroy();
     self.* = .{
         .allocator = allocator,
+        .query_cache = query_cache,
         .lang = file_type.lang_fn() orelse std.debug.panic("tree-sitter parser function failed for language: {s}", .{file_type.name}),
-        .parser = try Parser.create(),
+        .parser = parser,
         .query = query,
         .errors_query = errors_query,
         .injections = injections,
@@ -58,25 +79,38 @@ pub fn create_guess_file_type_static(allocator: std.mem.Allocator, content: []co
     return create(file_type, allocator, query_cache);
 }
 
-pub fn destroy(self: *Self, query_cache: *QueryCache) void {
+/// Tear down injections, tree, cached queries and parser, then free `self`.
+pub fn destroy(self: *Self) void {
+    self.clear_injections();
+    self.injection_list.deinit(self.allocator);
     if (self.tree) |tree| tree.destroy();
-    query_cache.release(self.query, .highlights);
-    query_cache.release(self.errors_query, .highlights);
-    if (self.injections) |injections| query_cache.release(injections, .injections);
+    self.query_cache.release(self.query, .highlights);
+    self.query_cache.release(self.errors_query, .highlights);
+    if (self.injections) |injections| self.query_cache.release(injections, .injections);
     self.parser.destroy();
     self.allocator.destroy(self);
 }
 
 pub fn reset(self: *Self) void {
+    self.clear_injections();
     if (self.tree) |tree| {
         tree.destroy();
         self.tree = null;
     }
 }
 
+/// Destroy all injected child syntaxes; the list keeps its allocated capacity.
+fn clear_injections(self: *Self) void {
+    for (self.injection_list.items) |*inj| inj.deinit(self.allocator);
+    self.injection_list.clearRetainingCapacity();
+}
+
 pub fn refresh_full(self: *Self, content: []const u8) !void {
+    // reset() drops stale injections and nulls self.tree before the fallible
+    // parse, so an error here cannot leave a destroyed tree behind to be freed again.
     self.reset();
     self.tree = try self.parser.parseString(null, content);
+    try self.refresh_injections(content);
 }
 
 pub fn edit(self: *Self, ed: Edit) void {
@@ -140,6 +174,178 @@ pub fn refresh_from_string(self: *Self, content: [:0]const u8) !void {
         .encoding = .utf_8,
     };
     self.tree = try self.parser.parse(old_tree, input);
+    try self.refresh_injections(content);
+}
+
+/// Rebuild `injection_list` from the injections query run over the current tree.
+/// Each match with an `injection.content` capture and a resolvable language gets
+/// its own child syntax, parsed from that byte range of `content`. Matches whose
+/// language is missing or unknown, or whose range is empty or out of bounds, are
+/// skipped. Only clears the list when there is no injections query or no tree.
+pub fn refresh_injections(self: *Self, content: []const u8) !void {
+    self.clear_injections();
+
+    const injections_query = self.injections orelse return;
+    const tree = self.tree orelse return;
+
+    const cursor = try Query.Cursor.create();
+    defer cursor.destroy();
+    cursor.execute(injections_query, tree.getRootNode());
+
+    while (cursor.nextMatch()) |match| {
+        var lang_range: ?Range = null;
+        var content_range: ?Range = null;
+
+        for (match.captures()) |capture| {
+            const name = injections_query.getCaptureNameForId(capture.id);
+            if (std.mem.eql(u8, name, "injection.language")) {
+                lang_range = capture.node.getRange();
+            } else if (std.mem.eql(u8, name, "injection.content")) {
+                content_range = capture.node.getRange();
+            }
+        }
+
+        const crange = content_range orelse continue;
+
+        // Determine language name: dynamic @injection.language capture takes priority,
+        // then fall back to a static #set! injection.language predicate.
+        const lang_name: []const u8 = if (lang_range) |lr|
+            extract_node_text(content, lr) orelse continue
+        else
+            get_static_injection_language(injections_query, match.pattern_index) orelse continue;
+
+        if (lang_name.len == 0) continue;
+
+        const file_type = FileType.get_by_name_static(lang_name) orelse
+            FileType.get_by_name_static(normalize_lang_name(lang_name)) orelse
+            continue;
+
+        const start_byte = crange.start_byte;
+        const end_byte = crange.end_byte;
+        if (start_byte >= end_byte or end_byte > content.len) continue;
+        const child_content = content[start_byte..end_byte];
+
+        const child = try Self.create(file_type, self.allocator, self.query_cache);
+        errdefer child.destroy();
+        // reset() nulls the tree pointer, so the errdefer above stays safe if the parse fails
+        child.reset();
+        child.tree = try child.parser.parseString(null, child_content);
+
+        const lang_name_owned = try self.allocator.dupe(u8, lang_name);
+        errdefer self.allocator.free(lang_name_owned);
+
+        try self.injection_list.append(self.allocator, .{
+            .lang_name = lang_name_owned,
+            .start_point = crange.start_point,
+            .end_row = crange.end_point.row,
+            .syntax = child,
+        });
+    }
+}
+
+/// Return the whitespace-trimmed text of `range` within `content`, or null when
+/// the range is empty or extends past the end of `content`.
+fn extract_node_text(content: []const u8, range: Range) ?[]const u8 {
+    const s = range.start_byte;
+    const e = range.end_byte;
+    if (s >= e or e > content.len) return null;
+    return std.mem.trim(u8, content[s..e], &std.ascii.whitespace);
+}
+
+/// Normalize common language name aliases found in markdown
+/// This should probably be in file_types
+fn normalize_lang_name(name: []const u8) []const u8 {
+    const aliases = .{
+        .{ "js", "javascript" },
+        .{ "ts", "typescript" },
+        .{ "py", "python" },
+        .{ "rb", "ruby" },
+        .{ "sh", "bash" },
+        .{ "shell", "bash" },
+        .{ "zsh", "bash" },
+        .{ "c++", "cpp" },
+        .{ "cs", "c-sharp" },
+        .{ "csharp", "c-sharp" },
+        .{ "yml", "yaml" },
+        .{ "md", "markdown" },
+        .{ "rs", "rust" },
+    };
+    inline for (aliases) |alias| {
+        if (std.ascii.eqlIgnoreCase(name, alias[0])) return alias[1];
+    }
+    return name;
+}
+
+/// Read a static `#set! injection.language "name"` predicate from the query's
+/// internal predicate table for the given pattern index, returning the language
+/// name string if found or null otherwise.
+///
+/// This accesses TSQuery internals via the same cast used in ts_bin_query_gen.zig
+fn get_static_injection_language(query: *Query, pattern_idx: u16) ?[]const u8 {
+    const tss = @import("ts_serializer.zig");
+    const ts_query: *tss.TSQuery = @ptrCast(@alignCast(query));
+
+    const patterns = ts_query.patterns;
+    if (patterns.contents == null or @as(u32, pattern_idx) >= patterns.size) return null;
+    const pattern_arr: [*]tss.QueryPattern = @ptrCast(patterns.contents.?);
+    const pattern = pattern_arr[pattern_idx];
+
+    const pred_steps = ts_query.predicate_steps;
+    if (pred_steps.contents == null or pred_steps.size == 0) return null;
+    const steps_arr: [*]tss.PredicateStep = @ptrCast(pred_steps.contents.?);
+
+    const pred_values = ts_query.predicate_values;
+    if (pred_values.slices.contents == null or pred_values.characters.contents == null) return null;
+    const slices_arr: [*]tss.Slice = @ptrCast(pred_values.slices.contents.?);
+    const chars: [*]u8 = @ptrCast(pred_values.characters.contents.?);
+
+    // Walk the predicate steps for this pattern looking for the sequence:
+    // string("set!") string("injection.language") string(NAME) done
+    const step_start = pattern.predicate_steps.offset;
+    const step_end = step_start + pattern.predicate_steps.length;
+
+    var i = step_start;
+    while (i < step_end) {
+        const s = steps_arr[i];
+        if (s.type == .done) {
+            i += 1;
+            continue;
+        }
+
+        // We need at least 4 steps: 3 strings + done.
+        if (i + 3 >= step_end) break;
+
+        const s0 = steps_arr[i];
+        const s1 = steps_arr[i + 1];
+        const s2 = steps_arr[i + 2];
+        const s3 = steps_arr[i + 3];
+
+        if (s0.type == .string and s1.type == .string and
+            s2.type == .string and s3.type == .done)
+        {
+            if (s0.value_id < pred_values.slices.size and
+                s1.value_id < pred_values.slices.size and
+                s2.value_id < pred_values.slices.size)
+            {
+                const sl0 = slices_arr[s0.value_id];
+                const sl1 = slices_arr[s1.value_id];
+                const sl2 = slices_arr[s2.value_id];
+                const n0 = chars[sl0.offset .. sl0.offset + sl0.length];
+                const n1 = chars[sl1.offset .. sl1.offset + sl1.length];
+                const n2 = chars[sl2.offset .. sl2.offset + sl2.length];
+                if (std.mem.eql(u8, n0, "set!") and
+                    std.mem.eql(u8, n1, "injection.language"))
+                {
+                    return n2;
+                }
+            }
+        }
+
+        // Advance past this predicate group to the next .done boundary.
+        while (i < step_end and steps_arr[i].type != .done) i += 1;
+        if (i < step_end) i += 1;
+    }
+    return null;
 }
 
 fn find_line_begin(s: []const u8, line: usize) ?usize {
@@ -160,6 +366,71 @@ fn CallBack(comptime T: type) type {
 }
 
+/// Render this syntax's highlights, then those of every injection overlapping
+/// `range` (all injections when `range` is null), translating each injected
+/// range from child-local to document rows and columns before calling `cb`.
 pub fn render(self: *const Self, ctx: anytype, comptime cb: CallBack(@TypeOf(ctx)), range: ?Range) !void {
+    try self.render_highlights_only(ctx, cb, range);
+
+    for (self.injection_list.items) |*inj| {
+        if (range) |r| {
+            if (inj.end_row < r.start_point.row) continue;
+            if (inj.start_point.row > r.end_point.row) continue;
+        }
+
+        const child_range: ?Range = if (range) |r| blk: {
+            const child_start_row: u32 = if (r.start_point.row > inj.start_point.row)
+                r.start_point.row - inj.start_point.row
+            else
+                0;
+            const child_end_row: u32 = r.end_point.row - inj.start_point.row;
+            break :blk .{
+                .start_point = .{ .row = child_start_row, .column = 0 },
+                .end_point = .{ .row = child_end_row, .column = 0 },
+                .start_byte = 0,
+                .end_byte = 0,
+            };
+        } else null;
+
+        // Wrap the context so we can translate local ranges to absolute
+        // document coordinates before forwarding to the callback
+        const InjCtx = struct {
+            parent_ctx: @TypeOf(ctx),
+            inj: *const Injection,
+
+            fn translated_cb(
+                self_: *const @This(),
+                child_sel: Range,
+                scope: []const u8,
+                id: u32,
+                capture_idx: usize,
+                node: *const Node,
+            ) error{Stop}!void {
+                const start_row = child_sel.start_point.row + self_.inj.start_point.row;
+                const end_row = child_sel.end_point.row + self_.inj.start_point.row;
+                // Column offset only applies on the very first line of the injection
+                const start_col = child_sel.start_point.column +
+                    if (child_sel.start_point.row == 0) self_.inj.start_point.column else 0;
+                const end_col = child_sel.end_point.column +
+                    if (child_sel.end_point.row == 0) self_.inj.start_point.column else 0;
+
+                // NOTE(review): start_byte/end_byte are forwarded child-relative; only rows
+                // and columns are translated. Confirm no callback relies on absolute bytes.
+                const doc_range: Range = .{
+                    .start_point = .{ .row = start_row, .column = start_col },
+                    .end_point = .{ .row = end_row, .column = end_col },
+                    .start_byte = child_sel.start_byte,
+                    .end_byte = child_sel.end_byte,
+                };
+                try cb(self_.parent_ctx, doc_range, scope, id, capture_idx, node);
+            }
+        };
+
+        var inj_ctx: InjCtx = .{ .parent_ctx = ctx, .inj = inj };
+        try inj.syntax.render_highlights_only(&inj_ctx, InjCtx.translated_cb, child_range);
+    }
+}
+
+fn render_highlights_only(self: *const Self, ctx: anytype, comptime cb: CallBack(@TypeOf(ctx)), range: ?Range) !void {
     const cursor = try Query.Cursor.create();
     defer cursor.destroy();
     const tree = self.tree orelse return;
diff --git a/src/treez_dummy.zig b/src/treez_dummy.zig
index e66f06b..58d032a 100644
--- a/src/treez_dummy.zig
+++ b/src/treez_dummy.zig
@@ -64,6 +64,7 @@ pub const Query = struct {
     pub fn destroy(_: *@This()) void {}
 
     pub const Match = struct {
+        pattern_index: u16 = 0,
         pub fn captures(_: *@This()) []Capture {
             return &[_]Capture{};
         }