diff --git a/src/cbor.zig b/src/cbor.zig index ac2b2f0..a4d53e2 100644 --- a/src/cbor.zig +++ b/src/cbor.zig @@ -18,7 +18,9 @@ pub const Error = error{ OutOfMemory, InvalidFloatType, InvalidArrayType, + InvalidMapType, InvalidPIntType, + InvalidUnion, JsonIncompatibleType, NotAnObject, BadArrayAllocExtract, @@ -212,6 +214,44 @@ fn writeErrorset(writer: anytype, err: anyerror) @TypeOf(writer).Error!void { return writeString(writer, stream.getWritten()); } +fn writeEnum(writer: anytype, value: anytype) @TypeOf(writer).Error!void { + const T = @TypeOf(value); + + if (std.meta.hasFn(T, "cborEncode")) { + return value.cborEncode(writer); + } + + return writeString(writer, @tagName(value)); +} + +fn writeUnion(writer: anytype, value: anytype, info: std.builtin.Type.Union) @TypeOf(writer).Error!void { + const T = @TypeOf(value); + + if (std.meta.hasFn(T, "cborEncode")) { + return value.cborEncode(writer); + } + if (info.tag_type) |TagType| { + inline for (info.fields) |u_field| { + const t = @field(TagType, u_field.name); + if (value == t) { + const Payload = std.meta.TagPayload(T, t); + if (Payload != void) { + try writeArrayHeader(writer, 2); + try writeEnum(writer, value); + return try writeValue(writer, @field(value, u_field.name)); + } else { + try writeArrayHeader(writer, 1); + try writeEnum(writer, value); + } + + return; + } + } else unreachable; + } else { + @compileError("cannot write untagged union '" ++ @typeName(T) ++ "' to cbor stream"); + } +} + pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void { const T = @TypeOf(value); switch (@typeInfo(T)) { @@ -220,25 +260,7 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void { .optional => return if (value) |v| writeValue(writer, v) else writeNull(writer), .error_union => return if (value) |v| writeValue(writer, v) else |err| writeValue(writer, err), .error_set => return writeErrorset(writer, value), - .@"union" => |info| { - if (std.meta.hasFn(T, "cborEncode")) { - return value.cborEncode(writer); - } - if (info.tag_type) |TagType| { - comptime var v = void; - inline for (info.fields) |u_field| { - if (value == @field(TagType, u_field.name)) - v = @field(value, u_field.name); - } - try writeArray(writer, .{ - @typeName(T), - @tagName(@as(TagType, value)), - v, - }); - } else { - try writeArray(writer, .{@typeName(T)}); - } - }, + .@"union" => |info| return writeUnion(writer, value, info), .@"struct" => |info| { if (std.meta.hasFn(T, "cborEncode")) { return value.cborEncode(writer); @@ -501,6 +523,17 @@ pub fn matchIntValue(comptime T: type, iter: *[]const u8, val: T) Error!bool { return if (try matchInt(T, iter, &v)) v == val else false; } +pub fn matchNull(iter_: *[]const u8) Error!bool { + var iter = iter_.*; + + if (iter.len > 0 and iter[0] == cbor_magic_null) { + iter_.* = iter[1..]; + return true; + } + + return false; +} + pub fn matchBool(iter_: *[]const u8, v: *bool) Error!bool { var iter = iter_.*; const t = try decodeType(&iter); @@ -555,6 +588,191 @@ fn matchEnumValue(comptime T: type, iter: *[]const u8, val: T) Error!bool { return matchStringValue(iter, @tagName(val)); } +fn matchUnionScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool { + var iter = iter_.*; + + const n = decodeArrayHeader(&iter) catch |e| switch (e) { + error.InvalidArrayType => return false, + error.InvalidPIntType => return e, + error.TooShort => return e, + }; + if (n == 0) return false; + + const TagType = std.meta.Tag(T); + var unionTag: TagType = undefined; + if (!try matchEnum(TagType, &iter, &unionTag)) return false; + + inline for (comptime std.meta.tags(TagType)) |t_| { + if (t_ == unionTag) { + const Payload = std.meta.TagPayload(T, t_); + + if (Payload == void) { + if (n != 1) return false; + val_.* = t_; + iter_.* = iter; + return true; + } else { + if (n != 2) return false; + var val: Payload = undefined; + if (try matchValue(&iter, extract(&val))) { + val_.* = @unionInit(T, @tagName(t_), val); + iter_.* = iter; + return true; + } + } + } + } + + return false; +} + +fn matchUnionAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool { + var iter = iter_.*; + + const n = decodeArrayHeader(&iter) catch |e| switch (e) { + error.InvalidArrayType => return false, + error.InvalidPIntType => return e, + error.TooShort => return e, + }; + if (n == 0) return false; + + const TagType = std.meta.Tag(T); + var unionTag: TagType = undefined; + if (!try matchEnum(TagType, &iter, &unionTag)) return false; + + inline for (comptime std.meta.tags(TagType)) |t_| { + if (t_ == unionTag) { + const Payload = std.meta.TagPayload(T, t_); + + if (Payload == void) { + if (n != 1) return false; + val_.* = t_; + iter_.* = iter; + return true; + } else { + if (n != 2) return false; + var val: Payload = undefined; + if (try matchValue(&iter, extractAlloc(&val, allocator))) { + val_.* = @unionInit(T, @tagName(t_), val); + iter_.* = iter; + return true; + } + } + } + } + + return false; +} + +fn matchUnionValue(comptime T: type, iter_: *[]const u8, val: T) Error!bool { + switch (val) { + inline else => |v, t| { + var iter = iter_.*; + + const n = decodeArrayHeader(&iter) catch |e| switch (e) { + error.InvalidArrayType => return false, + error.InvalidPIntType => return e, + error.TooShort => return e, + }; + if (n == 0) return false; + + if (!try matchEnumValue(std.meta.Tag(T), &iter, t)) return false; + + if (std.meta.TagPayload(T, t) != void) { + if (n != 2) return false; + if (!try matchValue(&iter, v)) return false; + } else { + if (n != 1) return false; + } + + iter_.* = iter; + return true; + }, + } +} + +fn matchStructScalar(comptime T: type, iter_: *[]const u8, val_: *T) Error!bool { + var iter = iter_.*; + const info = @typeInfo(T).@"struct"; + + const len = decodeMapHeader(&iter) catch |err| switch (err) { + error.TooShort => return false, + error.InvalidMapType => return err, + error.InvalidPIntType => return err, + }; + + if (len != info.fields.len) return false; + + if (info.fields.len == 0) { + iter_.* = iter; + val_.* = .{}; + return true; + } + + var val: T = undefined; + + fields: for (0..info.fields.len) |_| { + var fieldName: []const u8 = undefined; + if (!try matchString(&iter, &fieldName)) return false; + + inline for (info.fields) |f| { + if (std.mem.eql(u8, f.name, fieldName)) { + var fieldVal: @FieldType(T, f.name) = undefined; + if (!try matchValue(&iter, extract(&fieldVal))) return false; + @field(val, f.name) = fieldVal; + continue :fields; + } + } + + return false; + } + + val_.* = val; + iter_.* = iter; + + return true; +} + +fn matchStructAlloc(comptime T: type, iter_: *[]const u8, val_: *T, allocator: std.mem.Allocator) Error!bool { + var iter = iter_.*; + const info = @typeInfo(T).@"struct"; + + const len = decodeMapHeader(&iter) catch |err| switch (err) { + error.TooShort => return false, + error.InvalidMapType => return err, + error.InvalidPIntType => return err, + }; + + if (len != info.fields.len) return false; + + if (info.fields.len == 0) { + iter_.* = iter; + val_.* = .{}; + return true; + } + + var val: T = undefined; + + for (0..info.fields.len) |_| { + var fieldName: []const u8 = undefined; + if (!try matchString(&iter, &fieldName)) return false; + + inline for (info.fields) |f| { + if (std.mem.eql(u8, f.name, fieldName)) { + var fieldVal: @FieldType(T, f.name) = undefined; + if (!try matchValue(&iter, extractAlloc(&fieldVal, allocator))) return false; + @field(val, f.name) = fieldVal; + break; + } + } else return false; + } + + val_.* = val; + iter_.* = iter; + + return true; +} + fn skipString(iter: *[]const u8, minor: u5) Error!void { const len: usize = @intCast(try decodePInt(iter, minor)); if (iter.len < len) @@ -683,14 +901,20 @@ pub fn matchValue(iter: *[]const u8, value: anytype) Error!bool { .many, .c => matchError(T), .slice => if (info.child == u8) matchStringValue(iter, value) else matchArray(iter, value, info), }, + .optional => if (value) |v| matchValue(iter, v) else matchNull(iter), .@"struct" => |info| if (info.is_tuple) matchArray(iter, value, info) + // TODO: Add case for matching struct here else matchError(T), .array => |info| if (info.child == u8) matchStringValue(iter, &value) else matchArray(iter, value, info), .float => matchFloatValue(T, iter, value), .comptime_float => matchFloatValue(f64, iter, value), .@"enum" => matchEnumValue(T, iter, value), + .@"union" => |info| if (info.tag_type) |_| + matchUnionValue(T, iter, value) + else + @compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"), else => @compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"), }; } @@ -949,7 +1173,13 @@ fn GenericExtractorAlloc(T: type) type { .float => return matchFloat(T, iter, self.dest), .@"enum" => return matchEnum(T, iter, self.dest), .array => return matchArrayScalar(iter, self.dest), - else => return self.dest.cborExtract(iter), + else => if (@hasDecl(T, "cborExtract")) { + return self.dest.cborExtract(iter); + } else switch (comptime @typeInfo(T)) { + .@"union" => return matchUnionAlloc(T, iter, self.dest, self.allocator), + .@"struct" => return matchStructAlloc(T, iter, self.dest, self.allocator), + else => @compileError(@typeName(T) ++ " (" ++ @tagName(@typeInfo(T)) ++ ") is and unsupported or invalid type for cbor extract, or implement cborExtract function"), + }, } } } @@ -1024,7 +1254,13 @@ fn Extractor(comptime T: type) type { .float => return matchFloat(T, iter, self.dest), .@"enum" => return matchEnum(T, iter, self.dest), .array => return matchArrayScalar(iter, self.dest), - else => return self.dest.cborExtract(iter), + else => if (@hasDecl(T, "cborExtract")) { + return self.dest.cborExtract(iter); + } else switch (comptime @typeInfo(T)) { + .@"union" => return matchUnionScalar(T, iter, self.dest), + .@"struct" => return matchStructScalar(T, iter, self.dest), + else => @compileError("cannot extract type " ++ @typeName(T)), + }, } } }; diff --git a/test/tests.zig b/test/tests.zig index 6b0f307..6572639 100644 --- a/test/tests.zig +++ b/test/tests.zig @@ -24,6 +24,7 @@ const writeArrayHeader = cbor_mod.writeArrayHeader; const writeMapHeader = cbor_mod.writeMapHeader; const writeValue = cbor_mod.writeValue; const extract = cbor_mod.extract; +const extractAlloc = cbor_mod.extractAlloc; const extract_cbor = cbor_mod.extract_cbor; const more = cbor_mod.more; @@ -520,5 +521,183 @@ test "cbor.extract_cbor f64" { const json2 = try toJsonAlloc(std.testing.allocator, sub); defer std.testing.allocator.free(json2); - try expectEqualStrings("[0.9689138531684875]", json2); + try expectEqualStrings("[9.689138531684875e-1]", json2); +} + +test "cbor.writeValue enum" { + const TestEnum = enum { a, b }; + var buf: [128]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + try writeValue(writer, TestEnum.a); + const expected = @tagName(TestEnum.a); + var m = std.json.Value{ .null = {} }; + try expect(try match(stream.getWritten(), extract(&m))); + try expectEqualStrings(expected, m.string); +} + +test "cbor.extract enum" { + const TestEnum = enum { a, b }; + var buf: [128]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + try writeValue(writer, TestEnum.a); + var m: TestEnum = undefined; + try expect(try match(stream.getWritten(), extract(&m))); + try expectEqualStrings(@tagName(m), @tagName(TestEnum.a)); + try expect(m == .a); +} + +test "cbor.writeValue tagged union" { + const TestUnion = union(enum) { a: f32, b: []const u8 }; + var buf: [128]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + try writeValue(writer, TestUnion{ .b = "should work" }); + var tagName = std.json.Value{ .null = {} }; + var value = std.json.Value{ .null = {} }; + var iter: []const u8 = stream.getWritten(); + try expect(try matchValue(&iter, .{ extract(&tagName), extract(&value) })); + try expectEqualStrings(@tagName(TestUnion.b), tagName.string); + try expectEqualStrings("should work", value.string); +} + +test "cbor.writeValue tagged union no payload types" { + const TestUnion = union(enum) { a, b }; + var buf: [128]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + try writeValue(writer, TestUnion.b); + + try expect(try match(stream.getWritten(), @tagName(TestUnion.b))); + var tagName = std.json.Value{ .null = {} }; + try expect(try match(stream.getWritten(), extract(&tagName))); + try expectEqualStrings(@tagName(TestUnion.b), tagName.string); +} + +test "cbor.writeValue nested union json" { + const TestUnion = union(enum) { + a: f32, + b: []const u8, + c: []const union(enum) { + d: f32, + e: i32, + f, + }, + }; + var buf: [256]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + defer arena.deinit(); + try writeValue(writer, TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } }); + var json_buf: [128]u8 = undefined; + const json = try toJson(stream.getWritten(), &json_buf); + try expectEqualStrings(json, + \\["c",[["d",1.5e0],["e",5],["f"]]] + ); +} + +test "cbor.extract tagged union" { + const TestUnion = union(enum) { a: f32, b: []const u8 }; + var buf: [128]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + try writeValue(writer, TestUnion{ .b = "should work" }); + var m: TestUnion = undefined; + try expect(try match(stream.getWritten(), extract(&m))); + try expectEqualDeep(TestUnion{ .b = "should work" }, m); +} + +test "cbor.extract nested union" { + const TestUnion = union(enum) { + a: f32, + b: []const u8, + c: []const union(enum) { + d: f32, + e: i32, + f, + }, + }; + var buf: [256]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + const allocator = arena.allocator(); + defer arena.deinit(); + try writeValue(writer, TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } }); + var m: TestUnion = undefined; + try expect(try match(stream.getWritten(), extractAlloc(&m, allocator))); + try expectEqualDeep(TestUnion{ .c = &.{ .{ .d = 1.5 }, .{ .e = 5 }, .f } }, m); +} + +test "cbor.extract struct no fields" { + const TestStruct = struct {}; + var buf: [128]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + try writeValue(writer, TestStruct{}); + var m: TestStruct = undefined; + try expect(try match(stream.getWritten(), extract(&m))); + try expectEqualDeep(TestStruct{}, m); +} + +test "cbor.extract_cbor struct" { + const TestStruct = struct { + a: f32, + b: []const u8, + }; + var buf: [128]u8 = undefined; + const v = TestStruct{ .a = 1.5, .b = "hello" }; + const m = fmt(&buf, v); + var map_cbor: []const u8 = undefined; + try expect(try match(m, extract_cbor(&map_cbor))); + var json_buf: [256]u8 = undefined; + const json = try toJson(map_cbor, &json_buf); + try expectEqualStrings(json, + \\{"a":1.5e0,"b":"hello"} + ); +} + +test "cbor.extract struct" { + const TestStruct = struct { + a: f32, + b: []const u8, + }; + var buf: [128]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + const writer = stream.writer(); + const v = TestStruct{ .a = 1.5, .b = "hello" }; + try writeValue(writer, v); + var obj: TestStruct = undefined; + var json_buf: [256]u8 = undefined; + var iter: []const u8 = stream.getWritten(); + const t = try decodeType(&iter); + try expectEqual(5, t.major); + const json = try toJson(stream.getWritten(), &json_buf); + try expectEqualStrings(json, + \\{"a":1.5e0,"b":"hello"} + ); + try expect(try match(stream.getWritten(), extract(&obj))); + try expectEqual(1.5, obj.a); + try expectEqualStrings("hello", obj.b); +} + +test "cbor.extractAlloc struct" { + const TestStruct = struct { + a: f32, + b: []const u8, + }; + var buf: [128]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + const allocator = arena.allocator(); + defer arena.deinit(); + const writer = stream.writer(); + const v = TestStruct{ .a = 1.5, .b = "hello" }; + try writeValue(writer, v); + var obj: TestStruct = undefined; + try expect(try match(stream.getWritten(), extractAlloc(&obj, allocator))); + try expectEqual(1.5, obj.a); + try expectEqualStrings("hello", obj.b); }