feat: add initial injection rendering support

This commit is contained in:
CJ van den Berg 2026-02-21 20:43:18 +01:00
parent f7af9e1c0f
commit 7f18430932
Signed by: neurocyte
GPG key ID: 8EB1E1BB660E3FB9
2 changed files with 258 additions and 6 deletions

View file

@@ -20,12 +20,26 @@ const Query = treez.Query;
pub const Node = treez.Node;
allocator: std.mem.Allocator,
query_cache: *QueryCache,
lang: *const Language,
parser: *Parser,
query: *Query,
errors_query: *Query,
injections: ?*Query,
tree: ?*treez.Tree = null,
injection_list: std.ArrayListUnmanaged(Injection) = .{},
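/// One language injection: a region of the document, starting at `start_point`
/// and ending on `end_row`, that is parsed and highlighted by a child Syntax
/// instance for the injected language. `lang_name` is an owned copy of the
/// language name taken from the query match.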
pub const Injection = struct {
lang_name: []const u8,
start_point: Point,
end_row: u32,
syntax: *Self,
fn deinit(self: *Injection, allocator: std.mem.Allocator) void {
self.syntax.destroy();
allocator.free(self.lang_name);
}
};
pub fn create(file_type: FileType, allocator: std.mem.Allocator, query_cache: *QueryCache) !*Self {
const query = try query_cache.get(file_type, .highlights);
@@ -36,10 +50,13 @@ pub fn create(file_type: FileType, allocator: std.mem.Allocator, query_cache: *Q
errdefer if (injections) |injections_| query_cache.release(injections_, .injections);
const self = try allocator.create(Self);
errdefer allocator.destroy(self);
const parser = try Parser.create();
errdefer parser.destroy();
self.* = .{
.allocator = allocator,
.query_cache = query_cache,
.lang = file_type.lang_fn() orelse std.debug.panic("tree-sitter parser function failed for language: {s}", .{file_type.name}),
.parser = try Parser.create(),
.parser = parser,
.query = query,
.errors_query = errors_query,
.injections = injections,
@@ -58,25 +75,35 @@ pub fn create_guess_file_type_static(allocator: std.mem.Allocator, content: []co
return create(file_type, allocator, query_cache);
}
pub fn destroy(self: *Self, query_cache: *QueryCache) void {
pub fn destroy(self: *Self) void {
self.clear_injections();
self.injection_list.deinit(self.allocator);
if (self.tree) |tree| tree.destroy();
query_cache.release(self.query, .highlights);
query_cache.release(self.errors_query, .highlights);
if (self.injections) |injections| query_cache.release(injections, .injections);
self.query_cache.release(self.query, .highlights);
self.query_cache.release(self.errors_query, .highlights);
if (self.injections) |injections| self.query_cache.release(injections, .injections);
self.parser.destroy();
self.allocator.destroy(self);
}
pub fn reset(self: *Self) void {
self.clear_injections();
if (self.tree) |tree| {
tree.destroy();
self.tree = null;
}
}
fn clear_injections(self: *Self) void {
for (self.injection_list.items) |*inj| inj.deinit(self.allocator);
self.injection_list.clearRetainingCapacity();
}
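/// Reparse `content` from scratch, discarding any previous tree, and rebuild
/// the injection list for the new tree.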
pub fn refresh_full(self: *Self, content: []const u8) !void {
self.reset();
self.clear_injections();
if (self.tree) |tree| tree.destroy();
self.tree = try self.parser.parseString(null, content);
try self.refresh_injections(content);
}
pub fn edit(self: *Self, ed: Edit) void {
@@ -140,6 +167,170 @@ pub fn refresh_from_string(self: *Self, content: [:0]const u8) !void {
.encoding = .utf_8,
};
self.tree = try self.parser.parse(old_tree, input);
try self.refresh_injections(content);
}
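/// Scan the current tree with the language's injections query and rebuild the
/// injection list: for every `@injection.content` capture, create a child
/// Syntax for the injected language and parse the captured byte range with it.
/// The injected language comes either from an `@injection.language` capture or
/// from a static `#set! injection.language "name"` predicate on the pattern.
///
/// For reference, the two query shapes handled here look roughly like the
/// following (a sketch of typical injections.scm patterns, not copied from
/// this repository):
///
///   (fenced_code_block
///     (info_string (language) @injection.language)
///     (code_fence_content) @injection.content)
///
///   ((comment) @injection.content
///     (#set! injection.language "comment"))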
pub fn refresh_injections(self: *Self, content: []const u8) !void {
self.clear_injections();
const injections_query = self.injections orelse return;
const tree = self.tree orelse return;
const cursor = try Query.Cursor.create();
defer cursor.destroy();
cursor.execute(injections_query, tree.getRootNode());
while (cursor.nextMatch()) |match| {
var lang_range: ?Range = null;
var content_range: ?Range = null;
for (match.captures()) |capture| {
const name = injections_query.getCaptureNameForId(capture.id);
if (std.mem.eql(u8, name, "injection.language")) {
lang_range = capture.node.getRange();
} else if (std.mem.eql(u8, name, "injection.content")) {
content_range = capture.node.getRange();
}
}
const crange = content_range orelse continue;
// Determine the language name: a dynamic @injection.language capture takes priority,
// falling back to a static #set! injection.language predicate.
const lang_name: []const u8 = if (lang_range) |lr|
extract_node_text(content, lr) orelse continue
else
get_static_injection_language(injections_query, match.pattern_index) orelse continue;
if (lang_name.len == 0) continue;
const file_type = FileType.get_by_name_static(lang_name) orelse
FileType.get_by_name_static(normalize_lang_name(lang_name)) orelse
continue;
const start_byte = crange.start_byte;
const end_byte = crange.end_byte;
if (start_byte >= end_byte or end_byte > content.len) continue;
const child_content = content[start_byte..end_byte];
const child = try Self.create(file_type, self.allocator, self.query_cache);
errdefer child.destroy();
if (child.tree) |t| t.destroy();
child.tree = try child.parser.parseString(null, child_content);
const lang_name_owned = try self.allocator.dupe(u8, lang_name);
errdefer self.allocator.free(lang_name_owned);
try self.injection_list.append(self.allocator, .{
.lang_name = lang_name_owned,
.start_point = crange.start_point,
.end_row = crange.end_point.row,
.syntax = child,
});
}
}
fn extract_node_text(content: []const u8, range: Range) ?[]const u8 {
const s = range.start_byte;
const e = range.end_byte;
if (s >= e or e > content.len) return null;
return std.mem.trim(u8, content[s..e], &std.ascii.whitespace);
}
/// Normalize common language name aliases found in markdown
/// This should probably be in file_types
fn normalize_lang_name(name: []const u8) []const u8 {
const aliases = .{
.{ "js", "javascript" },
.{ "ts", "typescript" },
.{ "py", "python" },
.{ "rb", "ruby" },
.{ "sh", "bash" },
.{ "shell", "bash" },
.{ "zsh", "bash" },
.{ "c++", "cpp" },
.{ "cs", "c-sharp" },
.{ "csharp", "c-sharp" },
.{ "yml", "yaml" },
.{ "md", "markdown" },
.{ "rs", "rust" },
};
inline for (aliases) |alias| {
if (std.ascii.eqlIgnoreCase(name, alias[0])) return alias[1];
}
return name;
}
/// Read a static `#set! injection.language "name"` predicate from the query's
/// internal predicate table for the given pattern index, returning the language
/// name string if found or null otherwise.
///
/// This accesses TSQuery internals via the same cast used in ts_bin_query_gen.zig
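/// Note: the same information could presumably be read through tree-sitter's
/// public ts_query_predicates_for_pattern() C API; this helper inspects the
/// TSQuery struct directly instead.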
fn get_static_injection_language(query: *Query, pattern_idx: u16) ?[]const u8 {
const tss = @import("ts_serializer.zig");
const ts_query: *tss.TSQuery = @ptrCast(@alignCast(query));
const patterns = ts_query.patterns;
if (patterns.contents == null or @as(u32, pattern_idx) >= patterns.size) return null;
const pattern_arr: [*]tss.QueryPattern = @ptrCast(patterns.contents.?);
const pattern = pattern_arr[pattern_idx];
const pred_steps = ts_query.predicate_steps;
if (pred_steps.contents == null or pred_steps.size == 0) return null;
const steps_arr: [*]tss.PredicateStep = @ptrCast(pred_steps.contents.?);
const pred_values = ts_query.predicate_values;
if (pred_values.slices.contents == null or pred_values.characters.contents == null) return null;
const slices_arr: [*]tss.Slice = @ptrCast(pred_values.slices.contents.?);
const chars: [*]u8 = @ptrCast(pred_values.characters.contents.?);
// Walk the predicate steps for this pattern looking for the sequence:
// string("set!") string("injection.language") string("<name>") done
const step_start = pattern.predicate_steps.offset;
const step_end = step_start + pattern.predicate_steps.length;
var i = step_start;
while (i < step_end) {
const s = steps_arr[i];
if (s.type == .done) {
i += 1;
continue;
}
// We need at least 4 steps: 3 strings + done.
if (i + 3 >= step_end) break;
const s0 = steps_arr[i];
const s1 = steps_arr[i + 1];
const s2 = steps_arr[i + 2];
const s3 = steps_arr[i + 3];
if (s0.type == .string and s1.type == .string and
s2.type == .string and s3.type == .done)
{
if (s0.value_id < pred_values.slices.size and
s1.value_id < pred_values.slices.size and
s2.value_id < pred_values.slices.size)
{
const sl0 = slices_arr[s0.value_id];
const sl1 = slices_arr[s1.value_id];
const sl2 = slices_arr[s2.value_id];
const n0 = chars[sl0.offset .. sl0.offset + sl0.length];
const n1 = chars[sl1.offset .. sl1.offset + sl1.length];
const n2 = chars[sl2.offset .. sl2.offset + sl2.length];
if (std.mem.eql(u8, n0, "set!") and
std.mem.eql(u8, n1, "injection.language"))
{
return n2;
}
}
}
// Advance past this predicate group to the next .done boundary.
while (i < step_end and steps_arr[i].type != .done) i += 1;
if (i < step_end) i += 1;
}
return null;
}
fn find_line_begin(s: []const u8, line: usize) ?usize {
@@ -160,6 +351,66 @@ fn CallBack(comptime T: type) type {
}
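/// Render highlights for the whole document: first the root tree's own
/// captures, then each injection's captures translated from injection-local
/// coordinates into absolute document coordinates. When a render range is
/// given, injections that do not intersect it are skipped.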
pub fn render(self: *const Self, ctx: anytype, comptime cb: CallBack(@TypeOf(ctx)), range: ?Range) !void {
try self.render_highlights_only(ctx, cb, range);
for (self.injection_list.items) |*inj| {
if (range) |r| {
if (inj.end_row < r.start_point.row) continue;
if (inj.start_point.row > r.end_point.row) continue;
}
const child_range: ?Range = if (range) |r| blk: {
const child_start_row: u32 = if (r.start_point.row > inj.start_point.row)
r.start_point.row - inj.start_point.row
else
0;
const child_end_row: u32 = r.end_point.row - inj.start_point.row;
break :blk .{
.start_point = .{ .row = child_start_row, .column = 0 },
.end_point = .{ .row = child_end_row, .column = 0 },
.start_byte = 0,
.end_byte = 0,
};
} else null;
// Wrap the context so we can translate local ranges to absolute
// document coordinates before forwarding to the callback
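// Example (hypothetical numbers): for an injection whose content starts at
// document row 10, column 4, a child capture spanning row 0, columns 0..5
// maps to document row 10, columns 4..9, while a capture on child row 2
// maps to document row 12 with its columns unchanged.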
const InjCtx = struct {
parent_ctx: @TypeOf(ctx),
inj: *const Injection,
fn translated_cb(
self_: *const @This(),
child_sel: Range,
scope: []const u8,
id: u32,
capture_idx: usize,
node: *const Node,
) error{Stop}!void {
const start_row = child_sel.start_point.row + self_.inj.start_point.row;
const end_row = child_sel.end_point.row + self_.inj.start_point.row;
// Column offset only applies on the very first line of the injection
const start_col = child_sel.start_point.column +
if (child_sel.start_point.row == 0) self_.inj.start_point.column else 0;
const end_col = child_sel.end_point.column +
if (child_sel.end_point.row == 0) self_.inj.start_point.column else 0;
const doc_range: Range = .{
.start_point = .{ .row = start_row, .column = start_col },
.end_point = .{ .row = end_row, .column = end_col },
.start_byte = child_sel.start_byte,
.end_byte = child_sel.end_byte,
};
try cb(self_.parent_ctx, doc_range, scope, id, capture_idx, node);
}
};
var inj_ctx: InjCtx = .{ .parent_ctx = ctx, .inj = inj };
try inj.syntax.render_highlights_only(&inj_ctx, InjCtx.translated_cb, child_range);
}
}
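/// Render highlight captures for this tree only, without descending into
/// injections; `render` calls this for the root tree and for each child.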
fn render_highlights_only(self: *const Self, ctx: anytype, comptime cb: CallBack(@TypeOf(ctx)), range: ?Range) !void {
const cursor = try Query.Cursor.create();
defer cursor.destroy();
const tree = self.tree orelse return;

View file

@@ -64,6 +64,7 @@ pub const Query = struct {
pub fn destroy(_: *@This()) void {}
pub const Match = struct {
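// These stub members appear to mirror the real treez Match API: pattern_index
// and an empty captures() slice are enough for the injection matching code to
// compile against this dummy Query implementation.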
pattern_index: u16 = 0,
pub fn captures(_: *@This()) []Capture {
return &[_]Capture{};
}