Added struct and union cases in extract and extractAlloc and match (except for structs). Added tests for enums, unions and structs.
This commit is contained in:
parent
6eccce0b98
commit
3bdf25183e
2 changed files with 437 additions and 22 deletions
278
src/cbor.zig
278
src/cbor.zig
|
|
@ -18,7 +18,9 @@ pub const Error = error{
|
|||
OutOfMemory,
|
||||
InvalidFloatType,
|
||||
InvalidArrayType,
|
||||
InvalidMapType,
|
||||
InvalidPIntType,
|
||||
InvalidUnion,
|
||||
JsonIncompatibleType,
|
||||
NotAnObject,
|
||||
BadArrayAllocExtract,
|
||||
|
|
@ -212,6 +214,44 @@ fn writeErrorset(writer: anytype, err: anyerror) @TypeOf(writer).Error!void {
|
|||
return writeString(writer, stream.getWritten());
|
||||
}
|
||||
|
||||
fn writeEnum(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
|
||||
const T = @TypeOf(value);
|
||||
|
||||
if (std.meta.hasFn(T, "cborEncode")) {
|
||||
return value.cborEncode(writer);
|
||||
}
|
||||
|
||||
return writeString(writer, @tagName(value));
|
||||
}
|
||||
|
||||
fn writeUnion(writer: anytype, value: anytype, info: std.builtin.Type.Union) @TypeOf(writer).Error!void {
|
||||
const T = @TypeOf(value);
|
||||
|
||||
if (std.meta.hasFn(T, "cborEncode")) {
|
||||
return value.cborEncode(writer);
|
||||
}
|
||||
if (info.tag_type) |TagType| {
|
||||
inline for (info.fields) |u_field| {
|
||||
const t = @field(TagType, u_field.name);
|
||||
if (value == t) {
|
||||
const Payload = std.meta.TagPayload(T, t);
|
||||
if (Payload != void) {
|
||||
try writeArrayHeader(writer, 2);
|
||||
try writeEnum(writer, value);
|
||||
return try writeValue(writer, @field(value, u_field.name));
|
||||
} else {
|
||||
try writeArrayHeader(writer, 1);
|
||||
try writeEnum(writer, value);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
} else unreachable;
|
||||
} else {
|
||||
@compileError("cannot write untagged union '" ++ @typeName(T) ++ "' to cbor stream");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
|
||||
const T = @TypeOf(value);
|
||||
switch (@typeInfo(T)) {
|
||||
|
|
@ -220,25 +260,7 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
|
|||
.optional => return if (value) |v| writeValue(writer, v) else writeNull(writer),
|
||||
.error_union => return if (value) |v| writeValue(writer, v) else |err| writeValue(writer, err),
|
||||
.error_set => return writeErrorset(writer, value),
|
||||
.@"union" => |info| {
|
||||
if (std.meta.hasFn(T, "cborEncode")) {
|
||||
return value.cborEncode(writer);
|
||||
}
|
||||
if (info.tag_type) |TagType| {
|
||||
comptime var v = void;
|
||||
inline for (info.fields) |u_field| {
|
||||
if (value == @field(TagType, u_field.name))
|
||||
v = @field(value, u_field.name);
|
||||
}
|
||||
try writeArray(writer, .{
|
||||
@typeName(T),
|
||||
@tagName(@as(TagType, value)),
|
||||
v,
|
||||
});
|
||||
} else {
|
||||
try writeArray(writer, .{@typeName(T)});
|
||||
}
|
||||
},
|
||||
.@"union" => |info| return writeUnion(writer, value, info),
|
||||
.@"struct" => |info| {
|
||||
if (std.meta.hasFn(T, "cborEncode")) {
|
||||
return value.cborEncode(writer);
|
||||
|
|
@ -501,6 +523,17 @@ pub fn matchIntValue(comptime T: type, iter: *[]const u8, val: T) Error!bool {
|
|||
return if (try matchInt(T, iter, &v)) v == val else false;
|
||||
}
|
||||
|
||||
pub fn matchNull(iter_: *[]const u8) Error!bool {
|
||||
var iter = iter_.*;
|
||||
|
||||
if (iter.len > 0 and iter[0] == cbor_magic_null) {
|
||||
iter_.* = iter[1..];
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
pub fn matchBool(iter_: *[]const u8, v: *bool) Error!bool {
|
||||
var iter = iter_.*;
|
||||
const t = try decodeType(&iter);
|
||||
|
|
@ -555,6 +588,191 @@ fn matchEnumValue(comptime T: type, iter: *[]const u8, val: T) Error!bool {
|
|||
return matchStringValue(iter, @tagName(val));
|
||||
}
|
||||
|
||||
fn matchUnionScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool {
|
||||
var iter = iter_.*;
|
||||
|
||||
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
|
||||
error.InvalidArrayType => return false,
|
||||
error.InvalidPIntType => return e,
|
||||
error.TooShort => return e,
|
||||
};
|
||||
if (n == 0) return false;
|
||||
|
||||
const TagType = std.meta.Tag(T);
|
||||
var unionTag: TagType = undefined;
|
||||
if (!try matchEnum(TagType, &iter, &unionTag)) return false;
|
||||
|
||||
inline for (comptime std.meta.tags(TagType)) |t_| {
|
||||
if (t_ == unionTag) {
|
||||
const Payload = std.meta.TagPayload(T, t_);
|
||||
|
||||
if (Payload == void) {
|
||||
if (n != 1) return false;
|
||||
val_.* = t_;
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
} else {
|
||||
if (n != 2) return false;
|
||||
var val: Payload = undefined;
|
||||
if (try matchValue(&iter, extract(&val))) {
|
||||
val_.* = @unionInit(T, @tagName(t_), val);
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
fn matchUnionAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool {
|
||||
var iter = iter_.*;
|
||||
|
||||
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
|
||||
error.InvalidArrayType => return false,
|
||||
error.InvalidPIntType => return e,
|
||||
error.TooShort => return e,
|
||||
};
|
||||
if (n == 0) return false;
|
||||
|
||||
const TagType = std.meta.Tag(T);
|
||||
var unionTag: TagType = undefined;
|
||||
if (!try matchEnum(TagType, &iter, &unionTag)) return false;
|
||||
|
||||
inline for (comptime std.meta.tags(TagType)) |t_| {
|
||||
if (t_ == unionTag) {
|
||||
const Payload = std.meta.TagPayload(T, t_);
|
||||
|
||||
if (Payload == void) {
|
||||
if (n != 1) return false;
|
||||
val_.* = t_;
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
} else {
|
||||
if (n != 2) return false;
|
||||
var val: Payload = undefined;
|
||||
if (try matchValue(&iter, extractAlloc(&val, allocator))) {
|
||||
val_.* = @unionInit(T, @tagName(t_), val);
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
fn matchUnionValue(comptime T: type, iter_: *[]const u8, val: T) Error!bool {
|
||||
switch (val) {
|
||||
inline else => |v, t| {
|
||||
var iter = iter_.*;
|
||||
|
||||
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
|
||||
error.InvalidArrayType => return false,
|
||||
error.InvalidPIntType => return e,
|
||||
error.TooShort => return e,
|
||||
};
|
||||
if (n == 0) return false;
|
||||
|
||||
if (!try matchEnumValue(std.meta.Tag(T), &iter, t)) return false;
|
||||
|
||||
if (std.meta.TagPayload(T, t) != void) {
|
||||
if (n != 2) return false;
|
||||
if (!try matchValue(&iter, v)) return false;
|
||||
} else {
|
||||
if (n != 1) return false;
|
||||
}
|
||||
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn matchStructScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool {
|
||||
var iter = iter_.*;
|
||||
const info = @typeInfo(T).@"struct";
|
||||
|
||||
const len = decodeMapHeader(&iter) catch |err| switch (err) {
|
||||
error.TooShort => return false,
|
||||
error.InvalidMapType => return err,
|
||||
error.InvalidPIntType => return err,
|
||||
};
|
||||
|
||||
if (len != info.fields.len) return false;
|
||||
|
||||
if (info.fields.len == 0) {
|
||||
iter_.* = iter;
|
||||
val_.* = .{};
|
||||
return true;
|
||||
}
|
||||
|
||||
var val: T = undefined;
|
||||
|
||||
fields: for (0..info.fields.len) |_| {
|
||||
var fieldName: []const u8 = undefined;
|
||||
if (!try matchString(&iter, &fieldName)) return false;
|
||||
|
||||
inline for (info.fields) |f| {
|
||||
if (std.mem.eql(u8, f.name, fieldName)) {
|
||||
var fieldVal: @FieldType(T, f.name) = undefined;
|
||||
if (!try matchValue(&iter, extract(&fieldVal))) return false;
|
||||
@field(val, f.name) = fieldVal;
|
||||
continue :fields;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
val_.* = val;
|
||||
iter_.* = iter;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn matchStructAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool {
|
||||
var iter = iter_.*;
|
||||
const info = @typeInfo(T).@"struct";
|
||||
|
||||
const len = decodeMapHeader(&iter) catch |err| switch (err) {
|
||||
error.TooShort => return false,
|
||||
error.InvalidMapType => return err,
|
||||
error.InvalidPIntType => return err,
|
||||
};
|
||||
|
||||
if (len != info.fields.len) return false;
|
||||
|
||||
if (info.fields.len == 0) {
|
||||
iter_.* = iter;
|
||||
val_.* = .{};
|
||||
return true;
|
||||
}
|
||||
|
||||
var val: T = undefined;
|
||||
|
||||
for (0..info.fields.len) |_| {
|
||||
var fieldName: []const u8 = undefined;
|
||||
if (!try matchString(&iter, &fieldName)) return false;
|
||||
|
||||
inline for (info.fields) |f| {
|
||||
if (std.mem.eql(u8, f.name, fieldName)) {
|
||||
var fieldVal: @FieldType(T, f.name) = undefined;
|
||||
if (!try matchValue(&iter, extractAlloc(&fieldVal, allocator))) return false;
|
||||
@field(val, f.name) = fieldVal;
|
||||
break;
|
||||
}
|
||||
} else return false;
|
||||
}
|
||||
|
||||
val_.* = val;
|
||||
iter_.* = iter;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn skipString(iter: *[]const u8, minor: u5) Error!void {
|
||||
const len: usize = @intCast(try decodePInt(iter, minor));
|
||||
if (iter.len < len)
|
||||
|
|
@ -683,14 +901,20 @@ pub fn matchValue(iter: *[]const u8, value: anytype) Error!bool {
|
|||
.many, .c => matchError(T),
|
||||
.slice => if (info.child == u8) matchStringValue(iter, value) else matchArray(iter, value, info),
|
||||
},
|
||||
.optional => if (value) |v| matchValue(iter, v) else matchNull(iter),
|
||||
.@"struct" => |info| if (info.is_tuple)
|
||||
matchArray(iter, value, info)
|
||||
// TODO: Add case for matching struct here
|
||||
else
|
||||
matchError(T),
|
||||
.array => |info| if (info.child == u8) matchStringValue(iter, &value) else matchArray(iter, value, info),
|
||||
.float => matchFloatValue(T, iter, value),
|
||||
.comptime_float => matchFloatValue(f64, iter, value),
|
||||
.@"enum" => matchEnumValue(T, iter, value),
|
||||
.@"union" => |info| if (info.tag_type) |_|
|
||||
matchUnionValue(T, iter, value)
|
||||
else
|
||||
@compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"),
|
||||
else => @compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"),
|
||||
};
|
||||
}
|
||||
|
|
@ -949,7 +1173,13 @@ fn GenericExtractorAlloc(T: type) type {
|
|||
.float => return matchFloat(T, iter, self.dest),
|
||||
.@"enum" => return matchEnum(T, iter, self.dest),
|
||||
.array => return matchArrayScalar(iter, self.dest),
|
||||
else => return self.dest.cborExtract(iter),
|
||||
else => if (@hasDecl(T, "cborExtract")) {
|
||||
return self.dest.cborExtract(iter);
|
||||
} else switch (comptime @typeInfo(T)) {
|
||||
.@"union" => return matchUnionAlloc(T, iter, self.dest, self.allocator),
|
||||
.@"struct" => return matchStructAlloc(T, iter, self.dest, self.allocator),
|
||||
else => @compileError(@typeName(T) ++ " (" ++ @tagName(@typeInfo(T)) ++ ") is and unsupported or invalid type for cbor extract, or implement cborExtract function"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1024,7 +1254,13 @@ fn Extractor(comptime T: type) type {
|
|||
.float => return matchFloat(T, iter, self.dest),
|
||||
.@"enum" => return matchEnum(T, iter, self.dest),
|
||||
.array => return matchArrayScalar(iter, self.dest),
|
||||
else => return self.dest.cborExtract(iter),
|
||||
else => if (@hasDecl(T, "cborExtract")) {
|
||||
return self.dest.cborExtract(iter);
|
||||
} else switch (comptime @typeInfo(T)) {
|
||||
.@"union" => return matchUnionScalar(T, iter, self.dest),
|
||||
.@"struct" => return matchStructScalar(T, iter, self.dest),
|
||||
else => @compileError("cannot extract type " ++ @typeName(T)),
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue