diff --git a/src/syntax/src/QueryCache.zig b/src/syntax/src/QueryCache.zig index 21870c8..5843e4a 100644 --- a/src/syntax/src/QueryCache.zig +++ b/src/syntax/src/QueryCache.zig @@ -13,16 +13,23 @@ const Query = treez.Query; allocator: std.mem.Allocator, mutex: ?std.Thread.Mutex, -highlights: std.StringHashMapUnmanaged(*Query) = .{}, -injections: std.StringHashMapUnmanaged(*Query) = .{}, +highlights: std.StringHashMapUnmanaged(CacheEntry) = .{}, +injections: std.StringHashMapUnmanaged(CacheEntry) = .{}, ref_count: usize = 1, +const CacheEntry = struct { + mutex: ?std.Thread.Mutex, + query: ?*Query, + file_type: *const FileType, + query_type: QueryType, +}; + pub const QueryType = enum { highlights, injections, }; -pub const QueryParseError = error{ +const QueryParseError = error{ InvalidSyntax, InvalidNodeType, InvalidField, @@ -31,10 +38,12 @@ pub const QueryParseError = error{ InvalidLanguage, }; -pub const Error = (error{ +const CacheError = error{ NotFound, OutOfMemory, -} || QueryParseError); +}; + +pub const Error = CacheError || QueryParseError; pub fn create(allocator: std.mem.Allocator, opts: struct { lock: bool = false }) !*Self { const self = try allocator.create(Self); @@ -65,18 +74,62 @@ fn release_ref_unlocked_and_maybe_destroy(self: *Self) void { var iter_highlights = self.highlights.iterator(); while (iter_highlights.next()) |p| { self.allocator.free(p.key_ptr.*); - p.value_ptr.*.destroy(); + if (p.value_ptr.*.query) |q| q.destroy(); } var iter_injections = self.injections.iterator(); while (iter_injections.next()) |p| { self.allocator.free(p.key_ptr.*); - p.value_ptr.*.destroy(); + if (p.value_ptr.*.query) |q| q.destroy(); } self.highlights.deinit(self.allocator); self.injections.deinit(self.allocator); self.allocator.destroy(self); } +fn get_cache_entry(self: *Self, file_type: *const FileType, comptime query_type: QueryType) CacheError!*CacheEntry { + if (self.mutex) |*mtx| mtx.lock(); + defer if (self.mutex) |*mtx| mtx.unlock(); + + const hash = switch (query_type) { + .highlights => &self.highlights, + .injections => &self.injections, + }; + + return if (hash.getPtr(file_type.name)) |entry| entry else blk: { + const entry_ = try hash.getOrPut(self.allocator, try self.allocator.dupe(u8, file_type.name)); + entry_.value_ptr.* = .{ + .query = null, + .mutex = if (self.mutex) |_| .{} else null, + .file_type = file_type, + .query_type = query_type, + }; + break :blk entry_.value_ptr; + }; +} + +fn get_cached_query(_: *Self, entry: *CacheEntry) QueryParseError!?*Query { + if (entry.mutex) |*mtx| mtx.lock(); + defer if (entry.mutex) |*mtx| mtx.unlock(); + return if (entry.query) |query| query else blk: { + const lang = entry.file_type.lang_fn() orelse std.debug.panic("tree-sitter parser function failed for language: {s}", .{entry.file_type.name}); + entry.query = try Query.create(lang, switch (entry.query_type) { + .highlights => entry.file_type.highlights, + .injections => if (entry.file_type.injections) |injections| injections else return null, + }); + break :blk entry.query.?; + }; +} + +fn pre_load_internal(self: *Self, file_type: *const FileType, comptime query_type: QueryType) Error!void { + _ = try self.get_cached_query(try self.get_cache_entry(file_type, query_type)); +} + +pub fn pre_load(self: *Self, lang_name: []const u8) Error!void { + const file_type = FileType.get_by_name(lang_name) orelse return; + _ = try self.pre_load_internal(file_type, .highlights); + _ = try self.pre_load_internal(file_type, .injections); +} + fn ReturnType(comptime query_type: QueryType) type { return switch (query_type) { .highlights => *Query, @@ -84,38 +137,13 @@ fn ReturnType(comptime query_type: QueryType) type { }; } -fn get_or_add_internal(self: *Self, file_type: *const FileType, comptime query_type: QueryType) Error!ReturnType(query_type) { - const hash = switch (query_type) { - .highlights => &self.highlights, - .injections => &self.injections, - }; - - return if (hash.get(file_type.name)) |query| query else blk: { - const lang = file_type.lang_fn() orelse std.debug.panic("tree-sitter parser function failed for language: {s}", .{file_type.name}); - const query = try Query.create(lang, switch (query_type) { - .highlights => file_type.highlights, - .injections => if (file_type.injections) |injections| injections else return null, - }); - errdefer query.destroy(); - try hash.put(self.allocator, try self.allocator.dupe(u8, file_type.name), query); - break :blk query; - }; -} - -pub fn pre_load(self: *Self, lang_name: []const u8) Error!void { - if (self.mutex) |*mtx| mtx.lock(); - defer if (self.mutex) |*mtx| mtx.unlock(); - const file_type = FileType.get_by_name(lang_name) orelse return; - _ = try self.get_or_add_internal(file_type, .highlights); - _ = try self.get_or_add_internal(file_type, .injections); -} - pub fn get(self: *Self, file_type: *const FileType, comptime query_type: QueryType) Error!ReturnType(query_type) { - if (self.mutex) |*mtx| mtx.lock(); - defer if (self.mutex) |*mtx| mtx.unlock(); - const query = try self.get_or_add_internal(file_type, query_type); + const query = try self.get_cached_query(try self.get_cache_entry(file_type, query_type)); self.add_ref_locked(); - return query; + return switch (@typeInfo(ReturnType(query_type))) { + .optional => |_| query, + else => query.?, + }; } pub fn release(self: *Self, query: *Query, comptime query_type: QueryType) void {