Added struct and union cases in extract and extractAlloc and match (except for structs). Added tests for enums, unions and structs.

This commit is contained in:
Lumor Sunil 2025-07-20 22:40:08 +02:00 committed by CJ van den Berg
parent 6eccce0b98
commit 3bdf25183e
2 changed files with 437 additions and 22 deletions

View file

@ -18,7 +18,9 @@ pub const Error = error{
OutOfMemory,
InvalidFloatType,
InvalidArrayType,
InvalidMapType,
InvalidPIntType,
InvalidUnion,
JsonIncompatibleType,
NotAnObject,
BadArrayAllocExtract,
@ -212,6 +214,44 @@ fn writeErrorset(writer: anytype, err: anyerror) @TypeOf(writer).Error!void {
return writeString(writer, stream.getWritten());
}
fn writeEnum(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
const T = @TypeOf(value);
if (std.meta.hasFn(T, "cborEncode")) {
return value.cborEncode(writer);
}
return writeString(writer, @tagName(value));
}
fn writeUnion(writer: anytype, value: anytype, info: std.builtin.Type.Union) @TypeOf(writer).Error!void {
const T = @TypeOf(value);
if (std.meta.hasFn(T, "cborEncode")) {
return value.cborEncode(writer);
}
if (info.tag_type) |TagType| {
inline for (info.fields) |u_field| {
const t = @field(TagType, u_field.name);
if (value == t) {
const Payload = std.meta.TagPayload(T, t);
if (Payload != void) {
try writeArrayHeader(writer, 2);
try writeEnum(writer, value);
return try writeValue(writer, @field(value, u_field.name));
} else {
try writeArrayHeader(writer, 1);
try writeEnum(writer, value);
}
return;
}
} else unreachable;
} else {
@compileError("cannot write untagged union '" ++ @typeName(T) ++ "' to cbor stream");
}
}
pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
const T = @TypeOf(value);
switch (@typeInfo(T)) {
@ -220,25 +260,7 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
.optional => return if (value) |v| writeValue(writer, v) else writeNull(writer),
.error_union => return if (value) |v| writeValue(writer, v) else |err| writeValue(writer, err),
.error_set => return writeErrorset(writer, value),
.@"union" => |info| {
if (std.meta.hasFn(T, "cborEncode")) {
return value.cborEncode(writer);
}
if (info.tag_type) |TagType| {
comptime var v = void;
inline for (info.fields) |u_field| {
if (value == @field(TagType, u_field.name))
v = @field(value, u_field.name);
}
try writeArray(writer, .{
@typeName(T),
@tagName(@as(TagType, value)),
v,
});
} else {
try writeArray(writer, .{@typeName(T)});
}
},
.@"union" => |info| return writeUnion(writer, value, info),
.@"struct" => |info| {
if (std.meta.hasFn(T, "cborEncode")) {
return value.cborEncode(writer);
@ -501,6 +523,17 @@ pub fn matchIntValue(comptime T: type, iter: *[]const u8, val: T) Error!bool {
return if (try matchInt(T, iter, &v)) v == val else false;
}
pub fn matchNull(iter_: *[]const u8) Error!bool {
var iter = iter_.*;
if (iter.len > 0 and iter[0] == cbor_magic_null) {
iter_.* = iter[1..];
return true;
}
return false;
}
pub fn matchBool(iter_: *[]const u8, v: *bool) Error!bool {
var iter = iter_.*;
const t = try decodeType(&iter);
@ -555,6 +588,191 @@ fn matchEnumValue(comptime T: type, iter: *[]const u8, val: T) Error!bool {
return matchStringValue(iter, @tagName(val));
}
fn matchUnionScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool {
var iter = iter_.*;
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
error.InvalidArrayType => return false,
error.InvalidPIntType => return e,
error.TooShort => return e,
};
if (n == 0) return false;
const TagType = std.meta.Tag(T);
var unionTag: TagType = undefined;
if (!try matchEnum(TagType, &iter, &unionTag)) return false;
inline for (comptime std.meta.tags(TagType)) |t_| {
if (t_ == unionTag) {
const Payload = std.meta.TagPayload(T, t_);
if (Payload == void) {
if (n != 1) return false;
val_.* = t_;
iter_.* = iter;
return true;
} else {
if (n != 2) return false;
var val: Payload = undefined;
if (try matchValue(&iter, extract(&val))) {
val_.* = @unionInit(T, @tagName(t_), val);
iter_.* = iter;
return true;
}
}
}
}
return false;
}
fn matchUnionAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool {
var iter = iter_.*;
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
error.InvalidArrayType => return false,
error.InvalidPIntType => return e,
error.TooShort => return e,
};
if (n == 0) return false;
const TagType = std.meta.Tag(T);
var unionTag: TagType = undefined;
if (!try matchEnum(TagType, &iter, &unionTag)) return false;
inline for (comptime std.meta.tags(TagType)) |t_| {
if (t_ == unionTag) {
const Payload = std.meta.TagPayload(T, t_);
if (Payload == void) {
if (n != 1) return false;
val_.* = t_;
iter_.* = iter;
return true;
} else {
if (n != 2) return false;
var val: Payload = undefined;
if (try matchValue(&iter, extractAlloc(&val, allocator))) {
val_.* = @unionInit(T, @tagName(t_), val);
iter_.* = iter;
return true;
}
}
}
}
return false;
}
fn matchUnionValue(comptime T: type, iter_: *[]const u8, val: T) Error!bool {
switch (val) {
inline else => |v, t| {
var iter = iter_.*;
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
error.InvalidArrayType => return false,
error.InvalidPIntType => return e,
error.TooShort => return e,
};
if (n == 0) return false;
if (!try matchEnumValue(std.meta.Tag(T), &iter, t)) return false;
if (std.meta.TagPayload(T, t) != void) {
if (n != 2) return false;
if (!try matchValue(&iter, v)) return false;
} else {
if (n != 1) return false;
}
iter_.* = iter;
return true;
},
}
}
fn matchStructScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool {
var iter = iter_.*;
const info = @typeInfo(T).@"struct";
const len = decodeMapHeader(&iter) catch |err| switch (err) {
error.TooShort => return false,
error.InvalidMapType => return err,
error.InvalidPIntType => return err,
};
if (len != info.fields.len) return false;
if (info.fields.len == 0) {
iter_.* = iter;
val_.* = .{};
return true;
}
var val: T = undefined;
fields: for (0..info.fields.len) |_| {
var fieldName: []const u8 = undefined;
if (!try matchString(&iter, &fieldName)) return false;
inline for (info.fields) |f| {
if (std.mem.eql(u8, f.name, fieldName)) {
var fieldVal: @FieldType(T, f.name) = undefined;
if (!try matchValue(&iter, extract(&fieldVal))) return false;
@field(val, f.name) = fieldVal;
continue :fields;
}
}
return false;
}
val_.* = val;
iter_.* = iter;
return true;
}
fn matchStructAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool {
var iter = iter_.*;
const info = @typeInfo(T).@"struct";
const len = decodeMapHeader(&iter) catch |err| switch (err) {
error.TooShort => return false,
error.InvalidMapType => return err,
error.InvalidPIntType => return err,
};
if (len != info.fields.len) return false;
if (info.fields.len == 0) {
iter_.* = iter;
val_.* = .{};
return true;
}
var val: T = undefined;
for (0..info.fields.len) |_| {
var fieldName: []const u8 = undefined;
if (!try matchString(&iter, &fieldName)) return false;
inline for (info.fields) |f| {
if (std.mem.eql(u8, f.name, fieldName)) {
var fieldVal: @FieldType(T, f.name) = undefined;
if (!try matchValue(&iter, extractAlloc(&fieldVal, allocator))) return false;
@field(val, f.name) = fieldVal;
break;
}
} else return false;
}
val_.* = val;
iter_.* = iter;
return true;
}
fn skipString(iter: *[]const u8, minor: u5) Error!void {
const len: usize = @intCast(try decodePInt(iter, minor));
if (iter.len < len)
@ -683,14 +901,20 @@ pub fn matchValue(iter: *[]const u8, value: anytype) Error!bool {
.many, .c => matchError(T),
.slice => if (info.child == u8) matchStringValue(iter, value) else matchArray(iter, value, info),
},
.optional => if (value) |v| matchValue(iter, v) else matchNull(iter),
.@"struct" => |info| if (info.is_tuple)
matchArray(iter, value, info)
// TODO: Add case for matching struct here
else
matchError(T),
.array => |info| if (info.child == u8) matchStringValue(iter, &value) else matchArray(iter, value, info),
.float => matchFloatValue(T, iter, value),
.comptime_float => matchFloatValue(f64, iter, value),
.@"enum" => matchEnumValue(T, iter, value),
.@"union" => |info| if (info.tag_type) |_|
matchUnionValue(T, iter, value)
else
@compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"),
else => @compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"),
};
}
@ -949,7 +1173,13 @@ fn GenericExtractorAlloc(T: type) type {
.float => return matchFloat(T, iter, self.dest),
.@"enum" => return matchEnum(T, iter, self.dest),
.array => return matchArrayScalar(iter, self.dest),
else => return self.dest.cborExtract(iter),
else => if (@hasDecl(T, "cborExtract")) {
return self.dest.cborExtract(iter);
} else switch (comptime @typeInfo(T)) {
.@"union" => return matchUnionAlloc(T, iter, self.dest, self.allocator),
.@"struct" => return matchStructAlloc(T, iter, self.dest, self.allocator),
else => @compileError(@typeName(T) ++ " (" ++ @tagName(@typeInfo(T)) ++ ") is and unsupported or invalid type for cbor extract, or implement cborExtract function"),
},
}
}
}
@ -1024,7 +1254,13 @@ fn Extractor(comptime T: type) type {
.float => return matchFloat(T, iter, self.dest),
.@"enum" => return matchEnum(T, iter, self.dest),
.array => return matchArrayScalar(iter, self.dest),
else => return self.dest.cborExtract(iter),
else => if (@hasDecl(T, "cborExtract")) {
return self.dest.cborExtract(iter);
} else switch (comptime @typeInfo(T)) {
.@"union" => return matchUnionScalar(T, iter, self.dest),
.@"struct" => return matchStructScalar(T, iter, self.dest),
else => @compileError("cannot extract type " ++ @typeName(T)),
},
}
}
};

View file

@ -24,6 +24,7 @@ const writeArrayHeader = cbor_mod.writeArrayHeader;
const writeMapHeader = cbor_mod.writeMapHeader;
const writeValue = cbor_mod.writeValue;
const extract = cbor_mod.extract;
const extractAlloc = cbor_mod.extractAlloc;
const extract_cbor = cbor_mod.extract_cbor;
const more = cbor_mod.more;
@ -520,5 +521,183 @@ test "cbor.extract_cbor f64" {
const json2 = try toJsonAlloc(std.testing.allocator, sub);
defer std.testing.allocator.free(json2);
try expectEqualStrings("[0.9689138531684875]", json2);
try expectEqualStrings("[9.689138531684875e-1]", json2);
}
test "cbor.writeValue enum" {
const TestEnum = enum { a, b };
var buf: [128]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
try writeValue(writer, TestEnum.a);
const expected = @tagName(TestEnum.a);
var m = std.json.Value{ .null = {} };
try expect(try match(stream.getWritten(), extract(&m)));
try expectEqualStrings(expected, m.string);
}
test "cbor.extract enum" {
const TestEnum = enum { a, b };
var buf: [128]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
try writeValue(writer, TestEnum.a);
var m: TestEnum = undefined;
try expect(try match(stream.getWritten(), extract(&m)));
try expectEqualStrings(@tagName(m), @tagName(TestEnum.a));
try expect(m == .a);
}
test "cbor.writeValue tagged union" {
const TestUnion = union(enum) { a: f32, b: []const u8 };
var buf: [128]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
try writeValue(writer, TestUnion{ .b = "should work" });
var tagName = std.json.Value{ .null = {} };
var value = std.json.Value{ .null = {} };
var iter: []const u8 = stream.getWritten();
try expect(try matchValue(&iter, .{ extract(&tagName), extract(&value) }));
try expectEqualStrings(@tagName(TestUnion.b), tagName.string);
try expectEqualStrings("should work", value.string);
}
test "cbor.writeValue tagged union no payload types" {
const TestUnion = union(enum) { a, b };
var buf: [128]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
try writeValue(writer, TestUnion.b);
try expect(try match(stream.getWritten(), @tagName(TestUnion.b)));
var tagName = std.json.Value{ .null = {} };
try expect(try match(stream.getWritten(), extract(&tagName)));
try expectEqualStrings(@tagName(TestUnion.b), tagName.string);
}
test "cbor.writeValue nested union json" {
const TestUnion = union(enum) {
a: f32,
b: []const u8,
c: []const union(enum) {
d: f32,
e: i32,
f,
},
};
var buf: [256]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
defer arena.deinit();
try writeValue(writer, TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } });
var json_buf: [128]u8 = undefined;
const json = try toJson(stream.getWritten(), &json_buf);
try expectEqualStrings(json,
\\["c",[["d",1.5e0],["e",5],["f"]]]
);
}
test "cbor.extract tagged union" {
const TestUnion = union(enum) { a: f32, b: []const u8 };
var buf: [128]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
try writeValue(writer, TestUnion{ .b = "should work" });
var m: TestUnion = undefined;
try expect(try match(stream.getWritten(), extract(&m)));
try expectEqualDeep(TestUnion{ .b = "should work" }, m);
}
test "cbor.extract nested union" {
const TestUnion = union(enum) {
a: f32,
b: []const u8,
c: []const union(enum) {
d: f32,
e: i32,
f,
},
};
var buf: [256]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
const allocator = arena.allocator();
defer arena.deinit();
try writeValue(writer, TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } });
var m: TestUnion = undefined;
try expect(try match(stream.getWritten(), extractAlloc(&m, allocator)));
try expectEqualDeep(TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } }, m);
}
test "cbor.extract struct no fields" {
const TestStruct = struct {};
var buf: [128]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
try writeValue(writer, TestStruct{});
var m: TestStruct = undefined;
try expect(try match(stream.getWritten(), extract(&m)));
try expectEqualDeep(TestStruct{}, m);
}
test "cbor.extract_cbor struct" {
const TestStruct = struct {
a: f32,
b: []const u8,
};
var buf: [128]u8 = undefined;
const v = TestStruct{ .a = 1.5, .b = "hello" };
const m = fmt(&buf, v);
var map_cbor: []const u8 = undefined;
try expect(try match(m, extract_cbor(&map_cbor)));
var json_buf: [256]u8 = undefined;
const json = try toJson(map_cbor, &json_buf);
try expectEqualStrings(json,
\\{"a":1.5e0,"b":"hello"}
);
}
test "cbor.extract struct" {
const TestStruct = struct {
a: f32,
b: []const u8,
};
var buf: [128]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
const writer = stream.writer();
const v = TestStruct{ .a = 1.5, .b = "hello" };
try writeValue(writer, v);
var obj: TestStruct = undefined;
var json_buf: [256]u8 = undefined;
var iter: []const u8 = stream.getWritten();
const t = try decodeType(&iter);
try expectEqual(5, t.major);
const json = try toJson(stream.getWritten(), &json_buf);
try expectEqualStrings(json,
\\{"a":1.5e0,"b":"hello"}
);
try expect(try match(stream.getWritten(), extract(&obj)));
try expectEqual(1.5, obj.a);
try expectEqualStrings("hello", obj.b);
}
test "cbor.extractAlloc struct" {
const TestStruct = struct {
a: f32,
b: []const u8,
};
var buf: [128]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
const allocator = arena.allocator();
defer arena.deinit();
const writer = stream.writer();
const v = TestStruct{ .a = 1.5, .b = "hello" };
try writeValue(writer, v);
var obj: TestStruct = undefined;
try expect(try match(stream.getWritten(), extractAlloc(&obj, allocator)));
try expectEqual(1.5, obj.a);
try expectEqualStrings("hello", obj.b);
}