fix: use a partial write capable case folding writer in Buffer.find_all_ranges

This fixes case-insensitive search. Previously, case folding would fail on
input slices that contain partial UTF-8 sequences, which occur normally in the
buffer write process design. Now these partial UTF-8 sequences are not consumed;
instead they are pushed to the next write call, where they will be completed from the
main buffer contents.
This commit is contained in:
CJ van den Berg 2025-11-26 09:56:39 +01:00
parent 68b17301cd
commit 99f9f95dbc
Signed by: neurocyte
GPG key ID: 8EB1E1BB660E3FB9
2 changed files with 71 additions and 7 deletions

View file

@ -989,9 +989,9 @@ const Node = union(enum) {
.case_folded => { .case_folded => {
const input_consume_size = @min(ctx.buf.len - ctx.rest.len, input.len); const input_consume_size = @min(ctx.buf.len - ctx.rest.len, input.len);
var writer = std.Io.Writer.fixed(ctx.buf[ctx.rest.len..]); var writer = std.Io.Writer.fixed(ctx.buf[ctx.rest.len..]);
unicode.case_folded_write(&writer, input[0..input_consume_size]) catch return error.WriteFailed; const folded = unicode.case_folded_write_partial(&writer, input[0..input_consume_size]) catch return error.WriteFailed;
ctx.rest = ctx.buf[0 .. ctx.rest.len + writer.end]; ctx.rest = ctx.buf[0 .. ctx.rest.len + folded.len];
input = input[input_consume_size..]; input = input[folded.len..];
}, },
} }

View file

@ -89,7 +89,7 @@ fn raw_byte_to_utf8(cp: u8, buf: []u8) ![]const u8 {
var utf16le: [1]u16 = undefined; var utf16le: [1]u16 = undefined;
const utf16le_as_bytes = std.mem.sliceAsBytes(utf16le[0..]); const utf16le_as_bytes = std.mem.sliceAsBytes(utf16le[0..]);
std.mem.writeInt(u16, utf16le_as_bytes[0..2], cp, .little); std.mem.writeInt(u16, utf16le_as_bytes[0..2], cp, .little);
return buf[0..try std.unicode.utf16LeToUtf8(buf, &utf16le)]; return buf[0..try utf16LeToUtf8(buf, &utf16le)];
} }
pub fn utf8_sanitize(allocator: std.mem.Allocator, input: []const u8) error{ pub fn utf8_sanitize(allocator: std.mem.Allocator, input: []const u8) error{
@ -113,7 +113,7 @@ pub const TransformError = error{
}; };
fn utf8_write_transform(comptime field: uucode.FieldEnum, writer: *std.Io.Writer, text: []const u8) TransformError!void { fn utf8_write_transform(comptime field: uucode.FieldEnum, writer: *std.Io.Writer, text: []const u8) TransformError!void {
const view: std.unicode.Utf8View = .initUnchecked(text); const view: Utf8View = .initUnchecked(text);
var it = view.iterator(); var it = view.iterator();
while (it.nextCodepoint()) |cp| { while (it.nextCodepoint()) |cp| {
const cp_ = switch (field) { const cp_ = switch (field) {
@ -122,11 +122,27 @@ fn utf8_write_transform(comptime field: uucode.FieldEnum, writer: *std.Io.Writer
else => @compileError(@tagName(field) ++ " is not a unicode transformation"), else => @compileError(@tagName(field) ++ " is not a unicode transformation"),
}; };
var utf8_buf: [6]u8 = undefined; var utf8_buf: [6]u8 = undefined;
const size = try std.unicode.utf8Encode(cp_, &utf8_buf); const size = try utf8Encode(cp_, &utf8_buf);
try writer.writeAll(utf8_buf[0..size]); try writer.writeAll(utf8_buf[0..size]);
} }
} }
/// Apply the unicode transformation selected by `field` to every complete
/// UTF-8 sequence at the front of `text`, writing the re-encoded result to
/// `writer`.
///
/// Iteration stops at the end of `text` or at a trailing sequence whose
/// bytes are not all present yet; that incomplete tail is left untouched.
/// Returns the prefix of `text` that was actually consumed, so the caller
/// can resubmit the remainder once more bytes are available.
///
/// `text` is not validated (see `Utf8PartialView.initUnchecked`).
/// Failures from `utf8Encode` and `writer.writeAll` are propagated.
fn utf8_partial_write_transform(comptime field: uucode.FieldEnum, writer: *std.Io.Writer, text: []const u8) TransformError![]const u8 {
    var it = Utf8PartialView.initUnchecked(text).iterator();
    while (true) {
        const cp = it.nextCodepoint() orelse break;
        const mapped = switch (field) {
            .simple_uppercase_mapping, .simple_lowercase_mapping => uucode.get(field, cp) orelse cp,
            .case_folding_simple => uucode.get(field, cp),
            else => @compileError(@tagName(field) ++ " is not a unicode transformation"),
        };
        var encoded: [6]u8 = undefined;
        const encoded_len = try utf8Encode(mapped, &encoded);
        try writer.writeAll(encoded[0..encoded_len]);
    }
    // `it.end` is the offset just past the last complete sequence consumed.
    return text[0..it.end];
}
fn utf8_transform(comptime field: uucode.FieldEnum, allocator: std.mem.Allocator, text: []const u8) TransformError![]u8 { fn utf8_transform(comptime field: uucode.FieldEnum, allocator: std.mem.Allocator, text: []const u8) TransformError![]u8 {
var result: std.Io.Writer.Allocating = .init(allocator); var result: std.Io.Writer.Allocating = .init(allocator);
defer result.deinit(); defer result.deinit();
@ -135,7 +151,7 @@ fn utf8_transform(comptime field: uucode.FieldEnum, allocator: std.mem.Allocator
} }
fn utf8_predicate(comptime field: uucode.FieldEnum, text: []const u8) bool { fn utf8_predicate(comptime field: uucode.FieldEnum, text: []const u8) bool {
const view: std.unicode.Utf8View = .initUnchecked(text); const view: Utf8View = .initUnchecked(text);
var it = view.iterator(); var it = view.iterator();
while (it.nextCodepoint()) |cp| { while (it.nextCodepoint()) |cp| {
const result = switch (field) { const result = switch (field) {
@ -163,6 +179,10 @@ pub fn case_folded_write(writer: *std.Io.Writer, text: []const u8) TransformErro
return utf8_write_transform(.case_folding_simple, writer, text); return utf8_write_transform(.case_folding_simple, writer, text);
} }
/// Write the simple case folding of `text` to `writer`, tolerating a
/// truncated UTF-8 sequence at the end of `text`.
///
/// Delegates to `utf8_partial_write_transform` with `.case_folding_simple`
/// and returns its result: the prefix of `text` that was consumed. Any bytes
/// after that prefix were not processed.
pub fn case_folded_write_partial(writer: *std.Io.Writer, text: []const u8) TransformError![]const u8 {
    const consumed = try utf8_partial_write_transform(.case_folding_simple, writer, text);
    return consumed;
}
pub fn switch_case(allocator: std.mem.Allocator, text: []const u8) TransformError![]u8 { pub fn switch_case(allocator: std.mem.Allocator, text: []const u8) TransformError![]u8 {
return if (utf8_predicate(.is_lowercase, text)) return if (utf8_predicate(.is_lowercase, text))
to_upper(allocator, text) to_upper(allocator, text)
@ -176,3 +196,47 @@ pub fn is_lowercase(text: []const u8) bool {
const std = @import("std"); const std = @import("std");
const uucode = @import("vaxis").uucode; const uucode = @import("vaxis").uucode;
const utf16LeToUtf8 = std.unicode.utf16LeToUtf8;
const utf8ByteSequenceLength = std.unicode.utf8ByteSequenceLength;
const utf8Decode = std.unicode.utf8Decode;
const utf8Encode = std.unicode.utf8Encode;
const Utf8View = std.unicode.Utf8View;
/// Codepoint iterator that stops in front of a truncated trailing UTF-8
/// sequence instead of reading past the end of the slice.
///
/// `end` is the offset just past the last sequence handed out, i.e. the
/// number of bytes consumed so far. The bytes are assumed to be valid UTF-8
/// apart from a possibly incomplete final sequence; an invalid lead byte or
/// an undecodable sequence hits `unreachable`.
const Utf8PartialIterator = struct {
    bytes: []const u8,
    end: usize,

    /// Return the bytes of the next complete sequence and advance past them,
    /// or null when the input is exhausted or the next sequence is cut off.
    fn nextCodepointSlice(it: *Utf8PartialIterator) ?[]const u8 {
        const start = it.end;
        if (start >= it.bytes.len) return null;

        const seq_len = utf8ByteSequenceLength(it.bytes[start]) catch unreachable;
        const stop = start + seq_len;
        // Sequence runs past the available bytes: leave it unconsumed.
        if (stop > it.bytes.len) return null;

        it.end = stop;
        return it.bytes[start..stop];
    }

    /// Decode and return the next complete codepoint, or null under the same
    /// conditions as `nextCodepointSlice`.
    fn nextCodepoint(it: *Utf8PartialIterator) ?u21 {
        if (it.nextCodepointSlice()) |seq| {
            return utf8Decode(seq) catch unreachable;
        }
        return null;
    }
};
/// Thin view over a byte slice that hands out a `Utf8PartialIterator`.
/// The bytes are stored as given; no validation is performed.
const Utf8PartialView = struct {
    bytes: []const u8,

    /// Wrap `s` without checking that it is valid UTF-8.
    fn initUnchecked(s: []const u8) Utf8PartialView {
        return .{ .bytes = s };
    }

    /// Create an iterator positioned at the first byte of the view.
    fn iterator(s: Utf8PartialView) Utf8PartialIterator {
        return .{ .bytes = s.bytes, .end = 0 };
    }
};