Added struct and union cases in extract and extractAlloc and match (except for structs). Added tests for enums, unions and structs.
This commit is contained in:
parent
6eccce0b98
commit
3bdf25183e
2 changed files with 437 additions and 22 deletions
278
src/cbor.zig
278
src/cbor.zig
|
@ -18,7 +18,9 @@ pub const Error = error{
|
|||
OutOfMemory,
|
||||
InvalidFloatType,
|
||||
InvalidArrayType,
|
||||
InvalidMapType,
|
||||
InvalidPIntType,
|
||||
InvalidUnion,
|
||||
JsonIncompatibleType,
|
||||
NotAnObject,
|
||||
BadArrayAllocExtract,
|
||||
|
@ -212,6 +214,44 @@ fn writeErrorset(writer: anytype, err: anyerror) @TypeOf(writer).Error!void {
|
|||
return writeString(writer, stream.getWritten());
|
||||
}
|
||||
|
||||
fn writeEnum(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
|
||||
const T = @TypeOf(value);
|
||||
|
||||
if (std.meta.hasFn(T, "cborEncode")) {
|
||||
return value.cborEncode(writer);
|
||||
}
|
||||
|
||||
return writeString(writer, @tagName(value));
|
||||
}
|
||||
|
||||
fn writeUnion(writer: anytype, value: anytype, info: std.builtin.Type.Union) @TypeOf(writer).Error!void {
|
||||
const T = @TypeOf(value);
|
||||
|
||||
if (std.meta.hasFn(T, "cborEncode")) {
|
||||
return value.cborEncode(writer);
|
||||
}
|
||||
if (info.tag_type) |TagType| {
|
||||
inline for (info.fields) |u_field| {
|
||||
const t = @field(TagType, u_field.name);
|
||||
if (value == t) {
|
||||
const Payload = std.meta.TagPayload(T, t);
|
||||
if (Payload != void) {
|
||||
try writeArrayHeader(writer, 2);
|
||||
try writeEnum(writer, value);
|
||||
return try writeValue(writer, @field(value, u_field.name));
|
||||
} else {
|
||||
try writeArrayHeader(writer, 1);
|
||||
try writeEnum(writer, value);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
} else unreachable;
|
||||
} else {
|
||||
@compileError("cannot write untagged union '" ++ @typeName(T) ++ "' to cbor stream");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
|
||||
const T = @TypeOf(value);
|
||||
switch (@typeInfo(T)) {
|
||||
|
@ -220,25 +260,7 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void {
|
|||
.optional => return if (value) |v| writeValue(writer, v) else writeNull(writer),
|
||||
.error_union => return if (value) |v| writeValue(writer, v) else |err| writeValue(writer, err),
|
||||
.error_set => return writeErrorset(writer, value),
|
||||
.@"union" => |info| {
|
||||
if (std.meta.hasFn(T, "cborEncode")) {
|
||||
return value.cborEncode(writer);
|
||||
}
|
||||
if (info.tag_type) |TagType| {
|
||||
comptime var v = void;
|
||||
inline for (info.fields) |u_field| {
|
||||
if (value == @field(TagType, u_field.name))
|
||||
v = @field(value, u_field.name);
|
||||
}
|
||||
try writeArray(writer, .{
|
||||
@typeName(T),
|
||||
@tagName(@as(TagType, value)),
|
||||
v,
|
||||
});
|
||||
} else {
|
||||
try writeArray(writer, .{@typeName(T)});
|
||||
}
|
||||
},
|
||||
.@"union" => |info| return writeUnion(writer, value, info),
|
||||
.@"struct" => |info| {
|
||||
if (std.meta.hasFn(T, "cborEncode")) {
|
||||
return value.cborEncode(writer);
|
||||
|
@ -501,6 +523,17 @@ pub fn matchIntValue(comptime T: type, iter: *[]const u8, val: T) Error!bool {
|
|||
return if (try matchInt(T, iter, &v)) v == val else false;
|
||||
}
|
||||
|
||||
pub fn matchNull(iter_: *[]const u8) Error!bool {
|
||||
var iter = iter_.*;
|
||||
|
||||
if (iter.len > 0 and iter[0] == cbor_magic_null) {
|
||||
iter_.* = iter[1..];
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
pub fn matchBool(iter_: *[]const u8, v: *bool) Error!bool {
|
||||
var iter = iter_.*;
|
||||
const t = try decodeType(&iter);
|
||||
|
@ -555,6 +588,191 @@ fn matchEnumValue(comptime T: type, iter: *[]const u8, val: T) Error!bool {
|
|||
return matchStringValue(iter, @tagName(val));
|
||||
}
|
||||
|
||||
fn matchUnionScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool {
|
||||
var iter = iter_.*;
|
||||
|
||||
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
|
||||
error.InvalidArrayType => return false,
|
||||
error.InvalidPIntType => return e,
|
||||
error.TooShort => return e,
|
||||
};
|
||||
if (n == 0) return false;
|
||||
|
||||
const TagType = std.meta.Tag(T);
|
||||
var unionTag: TagType = undefined;
|
||||
if (!try matchEnum(TagType, &iter, &unionTag)) return false;
|
||||
|
||||
inline for (comptime std.meta.tags(TagType)) |t_| {
|
||||
if (t_ == unionTag) {
|
||||
const Payload = std.meta.TagPayload(T, t_);
|
||||
|
||||
if (Payload == void) {
|
||||
if (n != 1) return false;
|
||||
val_.* = t_;
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
} else {
|
||||
if (n != 2) return false;
|
||||
var val: Payload = undefined;
|
||||
if (try matchValue(&iter, extract(&val))) {
|
||||
val_.* = @unionInit(T, @tagName(t_), val);
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
fn matchUnionAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool {
|
||||
var iter = iter_.*;
|
||||
|
||||
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
|
||||
error.InvalidArrayType => return false,
|
||||
error.InvalidPIntType => return e,
|
||||
error.TooShort => return e,
|
||||
};
|
||||
if (n == 0) return false;
|
||||
|
||||
const TagType = std.meta.Tag(T);
|
||||
var unionTag: TagType = undefined;
|
||||
if (!try matchEnum(TagType, &iter, &unionTag)) return false;
|
||||
|
||||
inline for (comptime std.meta.tags(TagType)) |t_| {
|
||||
if (t_ == unionTag) {
|
||||
const Payload = std.meta.TagPayload(T, t_);
|
||||
|
||||
if (Payload == void) {
|
||||
if (n != 1) return false;
|
||||
val_.* = t_;
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
} else {
|
||||
if (n != 2) return false;
|
||||
var val: Payload = undefined;
|
||||
if (try matchValue(&iter, extractAlloc(&val, allocator))) {
|
||||
val_.* = @unionInit(T, @tagName(t_), val);
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
fn matchUnionValue(comptime T: type, iter_: *[]const u8, val: T) Error!bool {
|
||||
switch (val) {
|
||||
inline else => |v, t| {
|
||||
var iter = iter_.*;
|
||||
|
||||
const n = decodeArrayHeader(&iter) catch |e| switch (e) {
|
||||
error.InvalidArrayType => return false,
|
||||
error.InvalidPIntType => return e,
|
||||
error.TooShort => return e,
|
||||
};
|
||||
if (n == 0) return false;
|
||||
|
||||
if (!try matchEnumValue(std.meta.Tag(T), &iter, t)) return false;
|
||||
|
||||
if (std.meta.TagPayload(T, t) != void) {
|
||||
if (n != 2) return false;
|
||||
if (!try matchValue(&iter, v)) return false;
|
||||
} else {
|
||||
if (n != 1) return false;
|
||||
}
|
||||
|
||||
iter_.* = iter;
|
||||
return true;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn matchStructScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool {
|
||||
var iter = iter_.*;
|
||||
const info = @typeInfo(T).@"struct";
|
||||
|
||||
const len = decodeMapHeader(&iter) catch |err| switch (err) {
|
||||
error.TooShort => return false,
|
||||
error.InvalidMapType => return err,
|
||||
error.InvalidPIntType => return err,
|
||||
};
|
||||
|
||||
if (len != info.fields.len) return false;
|
||||
|
||||
if (info.fields.len == 0) {
|
||||
iter_.* = iter;
|
||||
val_.* = .{};
|
||||
return true;
|
||||
}
|
||||
|
||||
var val: T = undefined;
|
||||
|
||||
fields: for (0..info.fields.len) |_| {
|
||||
var fieldName: []const u8 = undefined;
|
||||
if (!try matchString(&iter, &fieldName)) return false;
|
||||
|
||||
inline for (info.fields) |f| {
|
||||
if (std.mem.eql(u8, f.name, fieldName)) {
|
||||
var fieldVal: @FieldType(T, f.name) = undefined;
|
||||
if (!try matchValue(&iter, extract(&fieldVal))) return false;
|
||||
@field(val, f.name) = fieldVal;
|
||||
continue :fields;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
val_.* = val;
|
||||
iter_.* = iter;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn matchStructAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool {
|
||||
var iter = iter_.*;
|
||||
const info = @typeInfo(T).@"struct";
|
||||
|
||||
const len = decodeMapHeader(&iter) catch |err| switch (err) {
|
||||
error.TooShort => return false,
|
||||
error.InvalidMapType => return err,
|
||||
error.InvalidPIntType => return err,
|
||||
};
|
||||
|
||||
if (len != info.fields.len) return false;
|
||||
|
||||
if (info.fields.len == 0) {
|
||||
iter_.* = iter;
|
||||
val_.* = .{};
|
||||
return true;
|
||||
}
|
||||
|
||||
var val: T = undefined;
|
||||
|
||||
for (0..info.fields.len) |_| {
|
||||
var fieldName: []const u8 = undefined;
|
||||
if (!try matchString(&iter, &fieldName)) return false;
|
||||
|
||||
inline for (info.fields) |f| {
|
||||
if (std.mem.eql(u8, f.name, fieldName)) {
|
||||
var fieldVal: @FieldType(T, f.name) = undefined;
|
||||
if (!try matchValue(&iter, extractAlloc(&fieldVal, allocator))) return false;
|
||||
@field(val, f.name) = fieldVal;
|
||||
break;
|
||||
}
|
||||
} else return false;
|
||||
}
|
||||
|
||||
val_.* = val;
|
||||
iter_.* = iter;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn skipString(iter: *[]const u8, minor: u5) Error!void {
|
||||
const len: usize = @intCast(try decodePInt(iter, minor));
|
||||
if (iter.len < len)
|
||||
|
@ -683,14 +901,20 @@ pub fn matchValue(iter: *[]const u8, value: anytype) Error!bool {
|
|||
.many, .c => matchError(T),
|
||||
.slice => if (info.child == u8) matchStringValue(iter, value) else matchArray(iter, value, info),
|
||||
},
|
||||
.optional => if (value) |v| matchValue(iter, v) else matchNull(iter),
|
||||
.@"struct" => |info| if (info.is_tuple)
|
||||
matchArray(iter, value, info)
|
||||
// TODO: Add case for matching struct here
|
||||
else
|
||||
matchError(T),
|
||||
.array => |info| if (info.child == u8) matchStringValue(iter, &value) else matchArray(iter, value, info),
|
||||
.float => matchFloatValue(T, iter, value),
|
||||
.comptime_float => matchFloatValue(f64, iter, value),
|
||||
.@"enum" => matchEnumValue(T, iter, value),
|
||||
.@"union" => |info| if (info.tag_type) |_|
|
||||
matchUnionValue(T, iter, value)
|
||||
else
|
||||
@compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"),
|
||||
else => @compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"),
|
||||
};
|
||||
}
|
||||
|
@ -949,7 +1173,13 @@ fn GenericExtractorAlloc(T: type) type {
|
|||
.float => return matchFloat(T, iter, self.dest),
|
||||
.@"enum" => return matchEnum(T, iter, self.dest),
|
||||
.array => return matchArrayScalar(iter, self.dest),
|
||||
else => return self.dest.cborExtract(iter),
|
||||
else => if (@hasDecl(T, "cborExtract")) {
|
||||
return self.dest.cborExtract(iter);
|
||||
} else switch (comptime @typeInfo(T)) {
|
||||
.@"union" => return matchUnionAlloc(T, iter, self.dest, self.allocator),
|
||||
.@"struct" => return matchStructAlloc(T, iter, self.dest, self.allocator),
|
||||
else => @compileError(@typeName(T) ++ " (" ++ @tagName(@typeInfo(T)) ++ ") is and unsupported or invalid type for cbor extract, or implement cborExtract function"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1024,7 +1254,13 @@ fn Extractor(comptime T: type) type {
|
|||
.float => return matchFloat(T, iter, self.dest),
|
||||
.@"enum" => return matchEnum(T, iter, self.dest),
|
||||
.array => return matchArrayScalar(iter, self.dest),
|
||||
else => return self.dest.cborExtract(iter),
|
||||
else => if (@hasDecl(T, "cborExtract")) {
|
||||
return self.dest.cborExtract(iter);
|
||||
} else switch (comptime @typeInfo(T)) {
|
||||
.@"union" => return matchUnionScalar(T, iter, self.dest),
|
||||
.@"struct" => return matchStructScalar(T, iter, self.dest),
|
||||
else => @compileError("cannot extract type " ++ @typeName(T)),
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
181
test/tests.zig
181
test/tests.zig
|
@ -24,6 +24,7 @@ const writeArrayHeader = cbor_mod.writeArrayHeader;
|
|||
const writeMapHeader = cbor_mod.writeMapHeader;
|
||||
const writeValue = cbor_mod.writeValue;
|
||||
const extract = cbor_mod.extract;
|
||||
const extractAlloc = cbor_mod.extractAlloc;
|
||||
const extract_cbor = cbor_mod.extract_cbor;
|
||||
|
||||
const more = cbor_mod.more;
|
||||
|
@ -520,5 +521,183 @@ test "cbor.extract_cbor f64" {
|
|||
const json2 = try toJsonAlloc(std.testing.allocator, sub);
|
||||
defer std.testing.allocator.free(json2);
|
||||
|
||||
try expectEqualStrings("[0.9689138531684875]", json2);
|
||||
try expectEqualStrings("[9.689138531684875e-1]", json2);
|
||||
}
|
||||
|
||||
test "cbor.writeValue enum" {
|
||||
const TestEnum = enum { a, b };
|
||||
var buf: [128]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
try writeValue(writer, TestEnum.a);
|
||||
const expected = @tagName(TestEnum.a);
|
||||
var m = std.json.Value{ .null = {} };
|
||||
try expect(try match(stream.getWritten(), extract(&m)));
|
||||
try expectEqualStrings(expected, m.string);
|
||||
}
|
||||
|
||||
test "cbor.extract enum" {
|
||||
const TestEnum = enum { a, b };
|
||||
var buf: [128]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
try writeValue(writer, TestEnum.a);
|
||||
var m: TestEnum = undefined;
|
||||
try expect(try match(stream.getWritten(), extract(&m)));
|
||||
try expectEqualStrings(@tagName(m), @tagName(TestEnum.a));
|
||||
try expect(m == .a);
|
||||
}
|
||||
|
||||
test "cbor.writeValue tagged union" {
|
||||
const TestUnion = union(enum) { a: f32, b: []const u8 };
|
||||
var buf: [128]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
try writeValue(writer, TestUnion{ .b = "should work" });
|
||||
var tagName = std.json.Value{ .null = {} };
|
||||
var value = std.json.Value{ .null = {} };
|
||||
var iter: []const u8 = stream.getWritten();
|
||||
try expect(try matchValue(&iter, .{ extract(&tagName), extract(&value) }));
|
||||
try expectEqualStrings(@tagName(TestUnion.b), tagName.string);
|
||||
try expectEqualStrings("should work", value.string);
|
||||
}
|
||||
|
||||
test "cbor.writeValue tagged union no payload types" {
|
||||
const TestUnion = union(enum) { a, b };
|
||||
var buf: [128]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
try writeValue(writer, TestUnion.b);
|
||||
|
||||
try expect(try match(stream.getWritten(), @tagName(TestUnion.b)));
|
||||
var tagName = std.json.Value{ .null = {} };
|
||||
try expect(try match(stream.getWritten(), extract(&tagName)));
|
||||
try expectEqualStrings(@tagName(TestUnion.b), tagName.string);
|
||||
}
|
||||
|
||||
test "cbor.writeValue nested union json" {
|
||||
const TestUnion = union(enum) {
|
||||
a: f32,
|
||||
b: []const u8,
|
||||
c: []const union(enum) {
|
||||
d: f32,
|
||||
e: i32,
|
||||
f,
|
||||
},
|
||||
};
|
||||
var buf: [256]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||
defer arena.deinit();
|
||||
try writeValue(writer, TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } });
|
||||
var json_buf: [128]u8 = undefined;
|
||||
const json = try toJson(stream.getWritten(), &json_buf);
|
||||
try expectEqualStrings(json,
|
||||
\\["c",[["d",1.5e0],["e",5],["f"]]]
|
||||
);
|
||||
}
|
||||
|
||||
test "cbor.extract tagged union" {
|
||||
const TestUnion = union(enum) { a: f32, b: []const u8 };
|
||||
var buf: [128]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
try writeValue(writer, TestUnion{ .b = "should work" });
|
||||
var m: TestUnion = undefined;
|
||||
try expect(try match(stream.getWritten(), extract(&m)));
|
||||
try expectEqualDeep(TestUnion{ .b = "should work" }, m);
|
||||
}
|
||||
|
||||
test "cbor.extract nested union" {
|
||||
const TestUnion = union(enum) {
|
||||
a: f32,
|
||||
b: []const u8,
|
||||
c: []const union(enum) {
|
||||
d: f32,
|
||||
e: i32,
|
||||
f,
|
||||
},
|
||||
};
|
||||
var buf: [256]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||
const allocator = arena.allocator();
|
||||
defer arena.deinit();
|
||||
try writeValue(writer, TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } });
|
||||
var m: TestUnion = undefined;
|
||||
try expect(try match(stream.getWritten(), extractAlloc(&m, allocator)));
|
||||
try expectEqualDeep(TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } }, m);
|
||||
}
|
||||
|
||||
test "cbor.extract struct no fields" {
|
||||
const TestStruct = struct {};
|
||||
var buf: [128]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
try writeValue(writer, TestStruct{});
|
||||
var m: TestStruct = undefined;
|
||||
try expect(try match(stream.getWritten(), extract(&m)));
|
||||
try expectEqualDeep(TestStruct{}, m);
|
||||
}
|
||||
|
||||
test "cbor.extract_cbor struct" {
|
||||
const TestStruct = struct {
|
||||
a: f32,
|
||||
b: []const u8,
|
||||
};
|
||||
var buf: [128]u8 = undefined;
|
||||
const v = TestStruct{ .a = 1.5, .b = "hello" };
|
||||
const m = fmt(&buf, v);
|
||||
var map_cbor: []const u8 = undefined;
|
||||
try expect(try match(m, extract_cbor(&map_cbor)));
|
||||
var json_buf: [256]u8 = undefined;
|
||||
const json = try toJson(map_cbor, &json_buf);
|
||||
try expectEqualStrings(json,
|
||||
\\{"a":1.5e0,"b":"hello"}
|
||||
);
|
||||
}
|
||||
|
||||
test "cbor.extract struct" {
|
||||
const TestStruct = struct {
|
||||
a: f32,
|
||||
b: []const u8,
|
||||
};
|
||||
var buf: [128]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
const writer = stream.writer();
|
||||
const v = TestStruct{ .a = 1.5, .b = "hello" };
|
||||
try writeValue(writer, v);
|
||||
var obj: TestStruct = undefined;
|
||||
var json_buf: [256]u8 = undefined;
|
||||
var iter: []const u8 = stream.getWritten();
|
||||
const t = try decodeType(&iter);
|
||||
try expectEqual(5, t.major);
|
||||
const json = try toJson(stream.getWritten(), &json_buf);
|
||||
try expectEqualStrings(json,
|
||||
\\{"a":1.5e0,"b":"hello"}
|
||||
);
|
||||
try expect(try match(stream.getWritten(), extract(&obj)));
|
||||
try expectEqual(1.5, obj.a);
|
||||
try expectEqualStrings("hello", obj.b);
|
||||
}
|
||||
|
||||
test "cbor.extractAlloc struct" {
|
||||
const TestStruct = struct {
|
||||
a: f32,
|
||||
b: []const u8,
|
||||
};
|
||||
var buf: [128]u8 = undefined;
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||
const allocator = arena.allocator();
|
||||
defer arena.deinit();
|
||||
const writer = stream.writer();
|
||||
const v = TestStruct{ .a = 1.5, .b = "hello" };
|
||||
try writeValue(writer, v);
|
||||
var obj: TestStruct = undefined;
|
||||
try expect(try match(stream.getWritten(), extractAlloc(&obj, allocator)));
|
||||
try expectEqual(1.5, obj.a);
|
||||
try expectEqualStrings("hello", obj.b);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue