From 3ace3163087a0260b30e2c420de76235dd82451f Mon Sep 17 00:00:00 2001 From: CJ van den Berg Date: Tue, 25 Jun 2024 21:04:51 +0200 Subject: [PATCH] feat: add support for floats to cbor --- src/cbor.zig | 195 +++++++++++++++++++++++++++++++++++++------- test/tests_cbor.zig | 88 ++++++++++++++++++++ 2 files changed, 253 insertions(+), 30 deletions(-) diff --git a/src/cbor.zig b/src/cbor.zig index 1fedf07..caa360a 100644 --- a/src/cbor.zig +++ b/src/cbor.zig @@ -1,5 +1,7 @@ const std = @import("std"); +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); const eql = std.mem.eql; const bufPrint = std.fmt.bufPrint; const fixedBufferStream = std.io.fixedBufferStream; @@ -33,6 +35,9 @@ pub const CborJsonError = error{ const cbor_magic_null: u8 = 0xf6; const cbor_magic_true: u8 = 0xf5; const cbor_magic_false: u8 = 0xf4; +const cbor_magic_float16: u8 = 0xf9; +const cbor_magic_float32: u8 = 0xfa; +const cbor_magic_float64: u8 = 0xfb; const cbor_magic_type_array: u8 = 4; const cbor_magic_type_map: u8 = 5; @@ -46,6 +51,7 @@ const value_type = enum(u8) { tag, boolean, null, + float, any, more, unknown, @@ -138,6 +144,50 @@ fn writeU64(writer: anytype, value: u64) @TypeOf(writer).Error!void { return writeTypedVal(writer, 0, value); } +fn writeF16(writer: anytype, value: f16) @TypeOf(writer).Error!void { + try write(writer, cbor_magic_float16); + const value_bytes = std.mem.asBytes(&value); + switch (native_endian) { + .big => try write(writer, value_bytes), + .little => { + try write(writer, value_bytes[1]); + try write(writer, value_bytes[0]); + }, + } +} + +fn writeF32(writer: anytype, value: f32) @TypeOf(writer).Error!void { + try write(writer, cbor_magic_float32); + const value_bytes = std.mem.asBytes(&value); + switch (native_endian) { + .big => try write(writer, value_bytes), + .little => { + try write(writer, value_bytes[3]); + try write(writer, value_bytes[2]); + try write(writer, value_bytes[1]); + try write(writer, value_bytes[0]); + }, + } +} + +fn writeF64(writer: anytype, value: f64) @TypeOf(writer).Error!void { + try write(writer, cbor_magic_float64); + const value_bytes = std.mem.asBytes(&value); + switch (native_endian) { + .big => try write(writer, value_bytes), + .little => { + try write(writer, value_bytes[7]); + try write(writer, value_bytes[6]); + try write(writer, value_bytes[5]); + try write(writer, value_bytes[4]); + try write(writer, value_bytes[3]); + try write(writer, value_bytes[2]); + try write(writer, value_bytes[1]); + try write(writer, value_bytes[0]); + }, + } +} + fn writeString(writer: anytype, s: []const u8) @TypeOf(writer).Error!void { try writeTypedVal(writer, 3, s.len); _ = try writer.write(s); @@ -225,6 +275,12 @@ pub fn writeValue(writer: anytype, value: anytype) @TypeOf(writer).Error!void { } }, .Null => try writeNull(writer), + .Float => |info| switch (info.bits) { + 16 => try writeF16(writer, value), + 32 => try writeF32(writer, value), + 64 => try writeF64(writer, value), + else => @compileError("cannot write type '" ++ @typeName(T) ++ "' to cbor stream"), + }, else => @compileError("cannot write type '" ++ @typeName(T) ++ "' to cbor stream"), } } @@ -343,6 +399,66 @@ fn decodeJsonObject(iter_: *[]const u8, minor: u5, obj: *json.ObjectMap) CborErr return true; } +fn decodeFloat(comptime T: type, iter_: *[]const u8, t: CborType) CborError!T { + var v: T = undefined; + var iter = iter_.*; + switch (t.type) { + cbor_magic_float16 => { + if (iter.len < 2) return error.CborTooShort; + var f: f16 = undefined; + var f_bytes = std.mem.asBytes(&f); + switch (native_endian) { + .big => @memcpy(f_bytes, iter[0..2]), + .little => { + f_bytes[0] = iter[1]; + f_bytes[1] = iter[0]; + }, + } + v = @floatCast(f); + iter = iter[2..]; + }, + cbor_magic_float32 => { + if (iter.len < 4) return error.CborTooShort; + var f: f32 = undefined; + var f_bytes = std.mem.asBytes(&f); + switch (native_endian) { + .big => @memcpy(f_bytes, iter[0..4]), + .little => { + f_bytes[0] = iter[3]; + f_bytes[1] = iter[2]; + f_bytes[2] = iter[1]; + f_bytes[3] = iter[0]; + }, + } + v = @floatCast(f); + iter = iter[4..]; + }, + cbor_magic_float64 => { + if (iter.len < 8) return error.CborTooShort; + var f: f64 = undefined; + var f_bytes = std.mem.asBytes(&f); + switch (native_endian) { + .big => @memcpy(f_bytes, iter[0..8]), + .little => { + f_bytes[0] = iter[7]; + f_bytes[1] = iter[6]; + f_bytes[2] = iter[5]; + f_bytes[3] = iter[4]; + f_bytes[4] = iter[3]; + f_bytes[5] = iter[2]; + f_bytes[6] = iter[1]; + f_bytes[7] = iter[0]; + }, + } + v = @floatCast(f); + iter = iter[8..]; + }, + else => return error.CborInvalidType, + } + iter_.* = iter; + return v; +} + pub fn matchInt(comptime T: type, iter_: *[]const u8, val: *T) CborError!bool { var iter = iter_.*; const t = try decodeType(&iter); @@ -394,6 +510,22 @@ fn matchBoolValue(iter: *[]const u8, val: bool) CborError!bool { return if (try matchBool(iter, &v)) v == val else false; } +fn matchFloat(comptime T: type, iter_: *[]const u8, v: *T) CborError!bool { + var iter = iter_.*; + const t = try decodeType(&iter); + v.* = decodeFloat(T, &iter, t) catch |e| switch (e) { + error.CborInvalidType => return false, + else => return e, + }; + iter_.* = iter; + return true; +} + +fn matchFloatValue(comptime T: type, iter: *[]const u8, val: T) CborError!bool { + var v: T = 0.0; + return if (try matchFloat(T, iter, &v)) v == val else false; +} + fn skipString(iter: *[]const u8, minor: u5) CborError!void { const len: usize = @intCast(try decodePInt(iter, minor)); if (iter.len < len) @@ -423,35 +555,38 @@ fn skipMap(iter: *[]const u8, minor: u5) CborError!void { } pub fn skipValue(iter: *[]const u8) CborError!void { - const t = try decodeType(iter); - try skipValueType(iter, t.major, t.minor); + try skipValueType(iter, try decodeType(iter)); } -fn skipValueType(iter: *[]const u8, major: u3, minor: u5) CborError!void { - switch (major) { +fn skipValueType(iter: *[]const u8, t: CborType) CborError!void { + switch (t.major) { 0 => { // positive integer - _ = try decodePInt(iter, minor); + _ = try decodePInt(iter, t.minor); }, 1 => { // negative integer - _ = try decodeNInt(iter, minor); + _ = try decodeNInt(iter, t.minor); }, 2 => { // bytes - try skipBytes(iter, minor); + try skipBytes(iter, t.minor); }, 3 => { // string - try skipString(iter, minor); + try skipString(iter, t.minor); }, 4 => { // array - try skipArray(iter, minor); + try skipArray(iter, t.minor); }, 5 => { // map - try skipMap(iter, minor); + try skipMap(iter, t.minor); }, 6 => { // tag return error.CborInvalidType; }, - 7 => { // special - return; + 7 => switch (t.type) { // special + cbor_magic_null, cbor_magic_false, cbor_magic_true => return, + cbor_magic_float16 => iter.* = iter.*[2..], + cbor_magic_float32 => iter.* = iter.*[4..], + cbor_magic_float64 => iter.* = iter.*[8..], + else => return error.CborInvalidType, }, } } @@ -459,23 +594,18 @@ fn skipValueType(iter: *[]const u8, major: u3, minor: u5) CborError!void { fn matchType(iter_: *[]const u8, v: *value_type) CborError!bool { var iter = iter_.*; const t = try decodeType(&iter); - try skipValueType(&iter, t.major, t.minor); + try skipValueType(&iter, t); switch (t.major) { 0, 1 => v.* = value_type.number, // positive integer or negative integer 2 => v.* = value_type.bytes, // bytes 3 => v.* = value_type.string, // string 4 => v.* = value_type.array, // array 5 => v.* = value_type.map, // map - 7 => { // special - if (t.type == cbor_magic_null) { - v.* = value_type.null; - } else { - if (t.type == cbor_magic_false or t.type == cbor_magic_true) { - v.* = value_type.boolean; - } else { - return false; - } - } + 7 => switch (t.type) { // special + cbor_magic_null => v.* = value_type.null, + cbor_magic_false, cbor_magic_true => v.* = value_type.boolean, + cbor_magic_float16, cbor_magic_float32, cbor_magic_float64 => v.* = value_type.float, + else => return false, }, else => return false, } @@ -529,6 +659,8 @@ pub fn matchValue(iter: *[]const u8, value: anytype) CborError!bool { else matchError(T), .Array => |info| if (info.child == u8) matchStringValue(iter, &value) else matchArray(iter, value, info), + .Float => return matchFloatValue(T, iter, value), + .ComptimeFloat => matchFloatValue(f64, iter, value), else => @compileError("cannot match value type '" ++ @typeName(T) ++ "' to cbor stream"), }; } @@ -726,6 +858,7 @@ fn Extractor(comptime T: type) type { } return false; }, + .Float => return matchFloat(T, iter, self.dest), else => extractError(T), } } @@ -795,12 +928,15 @@ pub fn JsonStream(comptime T: type) type { pub fn jsonWriteValue(w: *JsonWriter, iter: *[]const u8) (CborJsonError || Writer.Error)!void { const t = try decodeType(iter); - if (t.type == cbor_magic_false) - return w.write(false); - if (t.type == cbor_magic_true) - return w.write(true); - if (t.type == cbor_magic_null) - return w.write(null); + switch (t.type) { + cbor_magic_false => return w.write(false), + cbor_magic_true => return w.write(true), + cbor_magic_null => return w.write(null), + cbor_magic_float16 => return w.write(try decodeFloat(f16, iter, t)), + cbor_magic_float32 => return w.write(try decodeFloat(f32, iter, t)), + cbor_magic_float64 => return w.write(try decodeFloat(f64, iter, t)), + else => {}, + } return switch (t.major) { 0 => w.write(try decodePInt(iter, t.minor)), // positive integer 1 => w.write(try decodeNInt(iter, t.minor)), // negative integer @@ -862,7 +998,6 @@ fn writeJsonValue(writer: anytype, value: json.Value) !void { .array => |_| unreachable, .object => |_| unreachable, .null => writeNull(writer), - .float => |_| error.CborUnsupportedType, inline else => |v| writeValue(writer, v), }; } diff --git a/test/tests_cbor.zig b/test/tests_cbor.zig index 6ba45ca..9229706 100644 --- a/test/tests_cbor.zig +++ b/test/tests_cbor.zig @@ -10,6 +10,9 @@ const fmt = cbor_mod.fmt; const toJson = cbor_mod.toJson; const toJsonPretty = cbor_mod.toJsonPretty; const fromJson = cbor_mod.fromJson; +const fromJsonAlloc = cbor_mod.fromJsonAlloc; +const toJsonAlloc = cbor_mod.toJsonAlloc; +const toJsonPrettyAlloc = cbor_mod.toJsonPrettyAlloc; const decodeType = cbor_mod.decodeType; const matchInt = cbor_mod.matchInt; const matchIntValue = cbor_mod.matchIntValue; @@ -433,3 +436,88 @@ test "cbor.fromJson_object" { const cbor = try fromJson(json_buf, &cbor_buf); try expect(try match(cbor, map)); } + +test "cbor f32" { + var buf: [128]u8 = undefined; + try expectEqualDeep( + fmt(&buf, .{ "float", @as(f32, 0.96891385316848755) }), + &[_]u8{ + 0x82, // 82 # array(2) + 0x65, // 65 # text(5) + 0x66, // 666C6F6174 # "float" + 0x6C, + 0x6F, + 0x61, + 0x74, + 0xfa, // FA 3F780ABD # primitive(1064831677) + 0x3F, + 0x78, + 0x0A, + 0xBD, + }, + ); +} + +test "cbor.fromJson_object f32" { + var buf: [128]u8 = undefined; + const json_buf: []const u8 = + \\["float",0.96891385316848755] + ; + const cbor = try fromJson(json_buf, &buf); + try expect(try match(cbor, array)); + + try expectEqualDeep( + &[_]u8{ + 0x82, // 82 # array(2) + 0x65, // 65 # text(5) + 0x66, // 666C6F6174 # "float" + 0x6C, + 0x6F, + 0x61, + 0x74, + 0xfb, // FB 3FEF0157A0000000 # primitive(4606902419681443840) + 0x3f, + 0xef, + 0x01, + 0x57, + 0xa0, + 0x00, + 0x00, + 0x00, + }, + cbor, + ); +} + +test "cbor.extract_match_f32" { + var buf: [128]u8 = undefined; + const json_buf: []const u8 = + \\["float",0.96891385316848755] + ; + const m = try fromJson(json_buf, &buf); + + try expect(try match(m, .{ "float", @as(f64, 0.96891385316848755) })); + + var f: f64 = undefined; + try expect(try match(m, .{ "float", extract(&f) })); + try expectEqual(0.96891385316848755, f); +} + +test "cbor.extract_cbor f64" { + var buf: [128]u8 = undefined; + const json_buf: []const u8 = + \\["float",[0.96891385316848755],"check"] + ; + const m = try fromJson(json_buf, &buf); + + var sub: []const u8 = undefined; + try expect(try match(m, .{ "float", extract_cbor(&sub), "check" })); + + const json = try toJsonPrettyAlloc(std.testing.allocator, sub); + defer std.testing.allocator.free(json); + + const json2 = try toJsonAlloc(std.testing.allocator, sub); + defer std.testing.allocator.free(json2); + + try expectEqualDeep("[9.689138531684875e-1]", json2); +}