diff --git a/src/cbor.zig b/src/cbor.zig index 2d791c8..9e2b4cd 100644 --- a/src/cbor.zig +++ b/src/cbor.zig @@ -16,6 +16,11 @@ pub const Error = error{ InvalidType, TooShort, OutOfMemory, + InvalidFloatType, + InvalidArrayType, + InvalidPIntType, + JsonIncompatibleType, + NotAnObject, }; pub const JsonEncodeError = (Error || error{ @@ -122,8 +127,8 @@ pub fn writeMapHeader(writer: anytype, sz: usize) @TypeOf(writer).Error!void { pub fn writeArray(writer: anytype, args: anytype) @TypeOf(writer).Error!void { const args_type_info = @typeInfo(@TypeOf(args)); - if (args_type_info != .Struct) @compileError("expected tuple or struct argument"); - const fields_info = args_type_info.Struct.fields; + if (args_type_info != .@"struct") @compileError("expected tuple or struct argument"); + const fields_info = args_type_info.@"struct".fields; try writeArrayHeader(writer, fields_info.len); inline for (fields_info) |field_info| try writeValue(writer, @field(args, field_info.name)); @@ -215,6 +220,9 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void { .error_union => return if (value) |v| writeValue(writer, v) else |err| writeValue(writer, err), .error_set => return writeErrorset(writer, value), .@"union" => |info| { + if (std.meta.hasFn(T, "cborEncode")) { + return value.cborEncode(writer); + } if (info.tag_type) |TagType| { comptime var v = void; inline for (info.fields) |u_field| { @@ -231,6 +239,9 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void { } }, .@"struct" => |info| { + if (std.meta.hasFn(T, "cborEncode")) { + return value.cborEncode(writer); + } if (info.is_tuple) { if (info.fields.len == 0) return writeNull(writer); try writeArrayHeader(writer, info.fields.len); @@ -277,6 +288,12 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void { 64 => try writeF64(writer, value), else => @compileError("cannot write type '" ++ @typeName(T) ++ "' to cbor stream"), }, + .@"enum" => { + if (std.meta.hasFn(T, "cborEncode")) { + return value.cborEncode(writer); + } + return writeString(writer, @tagName(value)); + }, else => @compileError("cannot write type '" ++ @typeName(T) ++ "' to cbor stream"), } } @@ -298,7 +315,7 @@ pub fn decodeType(iter: *[]const u8) error{TooShort}!CborType { return .{ .type = type_, .minor = bits.minor, .major = bits.major }; } -fn decodeUIntLengthRecurse(iter: *[]const u8, length: usize, acc: u64) !u64 { +fn decodeUIntLengthRecurse(iter: *[]const u8, length: usize, acc: u64) error{TooShort}!u64 { if (iter.len < 1) return error.TooShort; const v: u8 = iter.*[0]; @@ -315,14 +332,14 @@ fn decodeUIntLength(iter: *[]const u8, length: usize) !u64 { return decodeUIntLengthRecurse(iter, length, 0); } -fn decodePInt(iter: *[]const u8, minor: u5) !u64 { +fn decodePInt(iter: *[]const u8, minor: u5) error{ TooShort, InvalidPIntType }!u64 { if (minor < 24) return minor; return switch (minor) { 24 => decodeUIntLength(iter, 1), // 1 byte 25 => decodeUIntLength(iter, 2), // 2 byte 26 => decodeUIntLength(iter, 4), // 4 byte 27 => decodeUIntLength(iter, 8), // 8 byte - else => error.InvalidType, + else => error.InvalidPIntType, }; } @@ -335,16 +352,16 @@ pub fn decodeMapHeader(iter: *[]const u8) Error!usize { if (t.type == cbor_magic_null) return 0; if (t.major != 5) - return error.InvalidType; + return error.InvalidMapType; return @intCast(try decodePInt(iter, t.minor)); } -pub fn decodeArrayHeader(iter: *[]const u8) Error!usize { +pub fn decodeArrayHeader(iter: *[]const u8) error{ TooShort, InvalidArrayType, InvalidPIntType }!usize { const t = try decodeType(iter); if (t.type == cbor_magic_null) return 0; if (t.major != 4) - return error.InvalidType; + return error.InvalidArrayType; return @intCast(try decodePInt(iter, t.minor)); } @@ -449,7 +466,7 @@ fn decodeFloat(comptime T: type, iter_: *[]const u8, t: CborType) Error!T { v = @floatCast(f); iter = iter[8..]; }, - else => return error.InvalidType, + else => return error.InvalidFloatType, } iter_.* = iter; return v; @@ -743,7 +760,11 @@ fn matchArrayMore(iter_: *[]const u8, n_: u64) Error!bool { fn matchArray(iter_: *[]const u8, arr: anytype, info: anytype) Error!bool { var iter = iter_.*; - var n = try decodeArrayHeader(&iter); + var n = decodeArrayHeader(&iter) catch |e| switch (e) { + error.InvalidArrayType => return false, + error.InvalidPIntType => return e, + error.TooShort => return e, + }; inline for (info.fields) |f| { const value = @field(arr, f.name); if (isMore(value)) @@ -768,13 +789,25 @@ fn matchArray(iter_: *[]const u8, arr: anytype, info: anytype) Error!bool { return n == 0; } +fn matchArrayScalar(iter: *[]const u8, arr: anytype) Error!bool { + var i: usize = 0; + var n = try decodeArrayHeader(iter); + if (n != arr.len) return false; + while (n > 0) : (n -= 1) { + if (!(matchValue(iter, extract(&arr[i])) catch return false)) + return false; + i += 1; + } + return true; +} + fn matchJsonObject(iter_: *[]const u8, obj: *json.ObjectMap) !bool { var iter = iter_.*; const t = try decodeType(&iter); if (t.type == cbor_magic_null) return true; if (t.major != 5) - return error.InvalidType; + return error.NotAnObject; const ret = try decodeJsonObject(&iter, t.minor, obj); if (ret) iter_.* = iter; return ret; @@ -786,7 +819,11 @@ pub fn match(buf: []const u8, pattern: anytype) Error!bool { } fn extractError(comptime T: type) noreturn { - @compileError("cannot extract type '" ++ @typeName(T) ++ "' from cbor stream"); + @compileError("cannot extract type '" ++ @typeName(T) ++ "' from a cbor stream"); +} + +fn extractErrorAlloc(comptime T: type) noreturn { + @compileError("extracting type '" ++ @typeName(T) ++ "' from a cbor stream requires an allocating extractor, use extractAlloc"); } fn hasExtractorTag(info: anytype) bool { @@ -805,6 +842,79 @@ fn isExtractor(comptime T: type) bool { }; } +fn ExtractDef(comptime T: type) type { + return fn (*T, *[]const u8) Error!bool; +} + +fn hasExtractMethod(T: type, info: anytype) bool { + const result = blk: { + if (info.is_tuple) break :blk false; + for (info.decls) |decl| { + if (std.mem.eql(u8, decl.name, "cborExtract") and @TypeOf(@field(T, decl.name)) == ExtractDef(T)) + break :blk true; + } + break :blk false; + }; + // @compileLog("hasExtractMethod", @typeName(T), result); + return result; +} + +pub fn isExtractable(comptime T: type) bool { + return comptime switch (@typeInfo(T)) { + .@"struct" => |info| hasExtractMethod(T, info), + .@"enum" => |info| hasExtractMethod(T, info), + .@"union" => |info| hasExtractMethod(T, info), + else => false, + }; +} + +fn ExtractAllocDef(comptime T: type) type { + return fn (*T, *[]const u8, std.mem.Allocator) Error!bool; +} + +fn hasExtractMethodAlloc(T: type, info: anytype) bool { + const result = blk: { + if (@hasField(@TypeOf(info), "is_tuple") and info.is_tuple) break :blk false; + for (info.decls) |decl| { + if (std.mem.eql(u8, decl.name, "cborExtract") and @TypeOf(@field(T, decl.name)) == ExtractAllocDef(T)) + break :blk true; + } + break :blk false; + }; + // @compileLog("hasExtractMethodAlloc", @typeName(T), result); + return result; +} + +pub fn isExtractableAlloc(comptime T: type) bool { + return comptime switch (@typeInfo(T)) { + .@"struct" => |info| hasExtractMethodAlloc(T, info), + .@"enum" => |info| hasExtractMethodAlloc(T, info), + .@"union" => |info| hasExtractMethodAlloc(T, info), + else => false, + }; +} + +fn GenericExtractorAlloc(T: type) type { + return struct { + dest: *T, + allocator: std.mem.Allocator, + const Self = @This(); + pub const EXTRACTOR_TAG = struct {}; + + pub fn init(dest: *T, allocator: std.mem.Allocator) Self { + return .{ .dest = dest, .allocator = allocator }; + } + + pub fn extract(self: Self, iter: *[]const u8) Error!bool { + if (comptime isExtractableAlloc(T)) { + return self.dest.cborExtract(iter, self.allocator); + } else { + return self.dest.cborExtract(iter); + } + } + }; +} + const JsonValueExtractor = struct { dest: *T, const Self = @This(); @@ -872,7 +982,8 @@ fn Extractor(comptime T: type) type { }, .float => return matchFloat(T, iter, self.dest), .@"enum" => return matchEnum(T, iter, self.dest), - else => extractError(T), + .array => return matchArrayScalar(iter, self.dest), + else => return self.dest.cborExtract(iter), } } }; @@ -881,7 +992,10 @@ fn Extractor(comptime T: type) type { fn ExtractorType(comptime T: type) type { const T_type_info = @typeInfo(T); if (T_type_info != .pointer) @compileError("extract requires a pointer argument"); - return Extractor(T_type_info.pointer.child); + return if (isExtractableAlloc(T_type_info.pointer.child)) + extractErrorAlloc(T_type_info.pointer.child) + else + Extractor(T_type_info.pointer.child); } pub fn extract(dest: anytype) ExtractorType(@TypeOf(dest)) { @@ -892,6 +1006,21 @@ pub fn extract(dest: anytype) ExtractorType(@TypeOf(dest)) { return ExtractorType(@TypeOf(dest)).init(dest); } +fn ExtractorTypeAlloc(comptime T: type) type { + const T_type_info = @typeInfo(T); + if (T_type_info != .pointer) @compileError("extractAlloc requires a pointer argument"); + // @compileLog("ExtractorTypeAlloc", @typeName(T), isExtractableAlloc(T_type_info.pointer.child)); + return GenericExtractorAlloc(T_type_info.pointer.child); +} + +pub fn extractAlloc(dest: anytype, allocator: std.mem.Allocator) ExtractorTypeAlloc(@TypeOf(dest)) { + comptime { + if (!isExtractor(ExtractorTypeAlloc(@TypeOf(dest)))) + @compileError("isExtractor self check failed for " ++ @typeName(ExtractorTypeAlloc(@TypeOf(dest)))); + } + return ExtractorTypeAlloc(@TypeOf(dest)).init(dest, allocator); +} + const CborExtractor = struct { dest: *[]const u8, const Self = @This(); @@ -960,7 +1089,7 @@ pub fn JsonStreamWriter(comptime Writer: type) type { 3 => w.write(try decodeString(iter, t.minor)), // string 4 => jsonWriteArray(w, iter, t.minor), // array 5 => jsonWriteMap(w, iter, t.minor), // map - else => error.InvalidType, + else => error.JsonIncompatibleType, }; } };