Added struct and union cases in extract and extractAlloc and match (except for structs). Added tests for enums, unions and structs.

This commit is contained in:
Lumor Sunil 2025-07-20 22:40:08 +02:00 committed by CJ van den Berg
parent 6eccce0b98
commit 3bdf25183e
2 changed files with 437 additions and 22 deletions

View file

@ -18,7 +18,9 @@ pub const Error = error{
OutOfMemory,
InvalidFloatType,
InvalidArrayType,
InvalidMapType,
InvalidPIntType,
InvalidUnion,
JsonIncompatibleType,
NotAnObject,
BadArrayAllocExtract,
@ -212,6 +214,44 @@ fn writeErrorset(writer: anytype, err: anyerror) @TypeOf(writer).Error!void {
return writeString(writer, stream.getWritten());
}
fn writeEnum(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
const T = @TypeOf(value);
if (std.meta.hasFn(T, "cborEncode")) {
return value.cborEncode(writer);
}
return writeString(writer, @tagName(value));
}
fn writeUnion(writer: anytype, value: anytype, info: std.builtin.Type.Union) @TypeOf(writer).Error!void {
const T = @TypeOf(value);
if (std.meta.hasFn(T, "cborEncode")) {
return value.cborEncode(writer);
}
if (info.tag_type) |TagType| {
inline for (info.fields) |u_field| {
const t = @field(TagType, u_field.name);
if (value == t) {
const Payload = std.meta.TagPayload(T, t);
if (Payload != void) {
try writeArrayHeader(writer, 2);
try writeEnum(writer, value);
return try writeValue(writer, @field(value, u_field.name));
} else {
try writeArrayHeader(writer, 1);
try writeEnum(writer, value);
}
return;
}
} else unreachable;
} else {
@compileError("cannot write untagged union '" ++ @typeName(T) ++ "' to cbor stream");
}
}
pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
const T = @TypeOf(value);
switch (@typeInfo(T)) {
@ -220,25 +260,7 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
.optional => return if (value) |v| writeValue(writer, v) else writeNull(writer),
.error_union => return if (value) |v| writeValue(writer, v) else |err| writeValue(writer, err),
.error_set => return writeErrorset(writer, value),
.@"union" => |info| {
if (std.meta.hasFn(T, "cborEncode")) {
return value.cborEncode(writer);
}
if (info.tag_type) |TagType| {
comptime var v = void;
inline for (info.fields) |u_field| {
if (value == @field(TagType, u_field.name))
v = @field(value, u_field.name);
}
try writeArray(writer, .{
@typeName(T),
@tagName(@as(TagType, value)),
v,
});
} else {
try writeArray(writer, .{@typeName(T)});
}
},
.@"union" => |info| return writeUnion(writer, value, info),
.@"struct" => |info| {
if (std.meta.hasFn(T, "cborEncode")) {
return value.cborEncode(writer);
@ -501,6 +523,17 @@ pub fn matchIntValue(comptime T: type, iter: *[]const u8, val: T) Error!bool {
return if (try matchInt(T, iter, &v)) v == val else false;
}
pub fn matchNull(iter_: *[]const u8) Error!bool {
var iter = iter_.*;
if (iter.len > 0 and iter[0] == cbor_magic_null) {
iter_.* = iter[1..];
return true;
}
return false;
}
pub fn matchBool(iter_: *[]const u8, v: *bool) Error!bool {
var iter = iter_.*;
const t = try decodeType(&iter);
@ -555,6 +588,191 @@ fn matchEnumValue(comptime T: type, iter: *[]const u8, val: T) Error!bool {
return matchStringValue(iter, @tagName(val));
}
fn matchUnionScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool {
var iter = iter_.*;
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
error.InvalidArrayType => return false,
error.InvalidPIntType => return e,
error.TooShort => return e,
};
if (n == 0) return false;
const TagType = std.meta.Tag(T);
var unionTag: TagType = undefined;
if (!try matchEnum(TagType, &iter, &unionTag)) return false;
inline for (comptime std.meta.tags(TagType)) |t_| {
if (t_ == unionTag) {
const Payload = std.meta.TagPayload(T, t_);
if (Payload == void) {
if (n != 1) return false;
val_.* = t_;
iter_.* = iter;
return true;
} else {
if (n != 2) return false;
var val: Payload = undefined;
if (try matchValue(&iter, extract(&val))) {
val_.* = @unionInit(T, @tagName(t_), val);
iter_.* = iter;
return true;
}
}
}
}
return false;
}
fn matchUnionAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool {
var iter = iter_.*;
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
error.InvalidArrayType => return false,
error.InvalidPIntType => return e,
error.TooShort => return e,
};
if (n == 0) return false;
const TagType = std.meta.Tag(T);
var unionTag: TagType = undefined;
if (!try matchEnum(TagType, &iter, &unionTag)) return false;
inline for (comptime std.meta.tags(TagType)) |t_| {
if (t_ == unionTag) {
const Payload = std.meta.TagPayload(T, t_);
if (Payload == void) {
if (n != 1) return false;
val_.* = t_;
iter_.* = iter;
return true;
} else {
if (n != 2) return false;
var val: Payload = undefined;
if (try matchValue(&iter, extractAlloc(&val, allocator))) {
val_.* = @unionInit(T, @tagName(t_), val);
iter_.* = iter;
return true;
}
}
}
}
return false;
}
fn matchUnionValue(comptime T: type, iter_: *[]const u8, val: T) Error!bool {
switch (val) {
inline else => |v, t| {
var iter = iter_.*;
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
error.InvalidArrayType => return false,
error.InvalidPIntType => return e,
error.TooShort => return e,
};
if (n == 0) return false;
if (!try matchEnumValue(std.meta.Tag(T), &iter, t)) return false;
if (std.meta.TagPayload(T, t) != void) {
if (n != 2) return false;
if (!try matchValue(&iter, v)) return false;
} else {
if (n != 1) return false;
}
iter_.* = iter;
return true;
},
}
}
fn matchStructScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool {
var iter = iter_.*;
const info = @typeInfo(T).@"struct";
const len = decodeMapHeader(&iter) catch |err| switch (err) {
error.TooShort => return false,
error.InvalidMapType => return err,
error.InvalidPIntType => return err,
};
if (len != info.fields.len) return false;
if (info.fields.len == 0) {
iter_.* = iter;
val_.* = .{};
return true;
}
var val: T = undefined;
fields: for (0..info.fields.len) |_| {
var fieldName: []const u8 = undefined;
if (!try matchString(&iter, &fieldName)) return false;
inline for (info.fields) |f| {
if (std.mem.eql(u8, f.name, fieldName)) {
var fieldVal: @FieldType(T, f.name) = undefined;
if (!try matchValue(&iter, extract(&fieldVal))) return false;
@field(val, f.name) = fieldVal;
continue :fields;
}
}
return false;
}
val_.* = val;
iter_.* = iter;
return true;
}
fn matchStructAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool {
var iter = iter_.*;
const info = @typeInfo(T).@"struct";
const len = decodeMapHeader(&iter) catch |err| switch (err) {
error.TooShort => return false,
error.InvalidMapType => return err,
error.InvalidPIntType => return err,
};
if (len != info.fields.len) return false;
if (info.fields.len == 0) {
iter_.* = iter;
val_.* = .{};
return true;
}
var val: T = undefined;
for (0..info.fields.len) |_| {
var fieldName: []const u8 = undefined;
if (!try matchString(&iter, &fieldName)) return false;
inline for (info.fields) |f| {
if (std.mem.eql(u8, f.name, fieldName)) {
var fieldVal: @FieldType(T, f.name) = undefined;
if (!try matchValue(&iter, extractAlloc(&fieldVal, allocator))) return false;
@field(val, f.name) = fieldVal;
break;
}
} else return false;
}
val_.* = val;
iter_.* = iter;
return true;
}
fn skipString(iter: *[]const u8, minor: u5) Error!void {
const len: usize = @intCast(try decodePInt(iter, minor));
if (iter.len < len)
@ -683,14 +901,20 @@ pub fn matchValue(iter: *[]const u8, value: anytype) Error!bool {
.many, .c => matchError(T),
.slice => if (info.child == u8) matchStringValue(iter, value) else matchArray(iter, value, info),
},
.optional => if (value) |v| matchValue(iter, v) else matchNull(iter),
.@"struct" => |info| if (info.is_tuple)
matchArray(iter, value, info)
// TODO: Add case for matching struct here
else
matchError(T),
.array => |info| if (info.child == u8) matchStringValue(iter, &value) else matchArray(iter, value, info),
.float => matchFloatValue(T, iter, value),
.comptime_float => matchFloatValue(f64, iter, value),
.@"enum" => matchEnumValue(T, iter, value),
.@"union" => |info| if (info.tag_type) |_|
matchUnionValue(T, iter, value)
else
@compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"),
else => @compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"),
};
}
@ -949,7 +1173,13 @@ fn GenericExtractorAlloc(T: type) type {
.float => return matchFloat(T, iter, self.dest),
.@"enum" => return matchEnum(T, iter, self.dest),
.array => return matchArrayScalar(iter, self.dest),
else => return self.dest.cborExtract(iter),
else => if (@hasDecl(T, "cborExtract")) {
return self.dest.cborExtract(iter);
} else switch (comptime @typeInfo(T)) {
.@"union" => return matchUnionAlloc(T, iter, self.dest, self.allocator),
.@"struct" => return matchStructAlloc(T, iter, self.dest, self.allocator),
else => @compileError(@typeName(T) ++ " (" ++ @tagName(@typeInfo(T)) ++ ") is and unsupported or invalid type for cbor extract, or implement cborExtract function"),
},
}
}
}
@ -1024,7 +1254,13 @@ fn Extractor(comptime T: type) type {
.float => return matchFloat(T, iter, self.dest),
.@"enum" => return matchEnum(T, iter, self.dest),
.array => return matchArrayScalar(iter, self.dest),
else => return self.dest.cborExtract(iter),
else => if (@hasDecl(T, "cborExtract")) {
return self.dest.cborExtract(iter);
} else switch (comptime @typeInfo(T)) {
.@"union" => return matchUnionScalar(T, iter, self.dest),
.@"struct" => return matchStructScalar(T, iter, self.dest),
else => @compileError("cannot extract type " ++ @typeName(T)),
},
}
}
};