diff --git a/include/thespian/c/instance.h b/include/thespian/c/instance.h index 3a2eb06..29e5977 100644 --- a/include/thespian/c/instance.h +++ b/include/thespian/c/instance.h @@ -11,8 +11,10 @@ extern "C" { typedef thespian_result (*thespian_receiver)(thespian_behaviour_state, thespian_handle from, cbor_buffer); +typedef void (*thespian_receiver_dtor)(thespian_behaviour_state); -void thespian_receive(thespian_receiver, thespian_behaviour_state); +void thespian_receive(thespian_receiver, thespian_behaviour_state, + thespian_receiver_dtor); bool thespian_get_trap(); bool thespian_set_trap(bool); diff --git a/src/c/instance.cpp b/src/c/instance.cpp index 48232b8..5007f58 100644 --- a/src/c/instance.cpp +++ b/src/c/instance.cpp @@ -7,11 +7,19 @@ using std::string_view; extern "C" { -void thespian_receive(thespian_receiver r, thespian_behaviour_state s) { - thespian::receive([r, s](auto from, cbor::buffer msg) -> thespian::result { +void thespian_receive(thespian_receiver r, thespian_behaviour_state s, + thespian_receiver_dtor dtor) { + struct receiver_wrapper { + thespian_receiver r; + thespian_behaviour_state s; + thespian_receiver_dtor dtor; + ~receiver_wrapper() { dtor(s); } + }; + auto wrapper = std::make_shared(r, s, dtor); + thespian::receive([wrapper](auto from, cbor::buffer msg) -> thespian::result { thespian_handle from_handle = reinterpret_cast( // NOLINT &from); - auto *ret = r(s, from_handle, {msg.data(), msg.size()}); + auto *ret = wrapper->r(wrapper->s, from_handle, {msg.data(), msg.size()}); if (ret) { auto err = cbor::buffer(); const uint8_t *data = ret->base; diff --git a/src/thespian.zig b/src/thespian.zig index 1afd814..d05e261 100644 --- a/src/thespian.zig +++ b/src/thespian.zig @@ -541,23 +541,29 @@ pub fn receive(r: anytype) void { }, else => @compileError("invalid receiver type"), }; - c.thespian_receive(T.run, r); + c.thespian_receive(T.run, r, T.dtor); } pub fn Receiver(comptime T: type) type { return struct { f: FunT, + deinit_f: DeinitFunT, data: T, const FunT: type = *const fn (T, from: pid_ref, m: message) result; + const DeinitFunT: type = *const fn (T) void; const Self = @This(); - pub fn init(f: FunT, data: T) Self { - return .{ .f = f, .data = data }; + pub fn init(f: FunT, deinit_fn: DeinitFunT, data: T) Self { + return .{ .f = f, .deinit_f = deinit_fn, .data = data }; } pub fn run(ostate: c.thespian_behaviour_state, from: c.thespian_handle, m: c.cbor_buffer) callconv(.c) c.thespian_result { const state: *Self = @ptrCast(@alignCast(ostate orelse unreachable)); reset_error(); return to_result(state.f(state.data, wrap_handle(from), message.from(m))); } + pub fn dtor(ostate: c.thespian_behaviour_state) callconv(.c) void { + const state: *Self = @ptrCast(@alignCast(ostate orelse unreachable)); + state.deinit_f(state.data); + } }; } @@ -980,11 +986,10 @@ const CallContext = struct { .response = null, .a = a, }; - self.receiver = ReceiverT.init(receive_, self); + self.receiver = ReceiverT.init(receive_, deinit_from_dtor, self); const proc = try spawn_link(a, self, start, @typeName(Self)); defer proc.deinit(); try self.done.timedWait(timeout_ns); - defer self.deinit(); // only deinit on success. if we timed out proc will have to deinit return self.response orelse .{}; } @@ -992,22 +997,23 @@ const CallContext = struct { self.a.destroy(self); } + fn deinit_from_dtor(self: *Self) void { + // dtor fires after actor exits; only free if caller has already moved on (timed out or got response) + const expired = self.from.expired(); + self.from.deinit(); + if (expired) self.deinit(); + } + fn start(self: *Self) result { errdefer self.done.set(); - _ = set_trap(true); try self.to.link(); try self.to.send_raw(self.request); receive(&self.receiver); } fn receive_(self: *Self, _: pid_ref, m: message) result { - defer { - const expired = self.from.expired(); - self.from.deinit(); - self.done.set(); - if (expired) self.deinit(); - } self.response = m.clone(self.a) catch |e| return exit_error(e, @errorReturnTrace()); + self.done.set(); return exit_normal(); } }; @@ -1034,17 +1040,17 @@ const DelayedSender = struct { } fn start(self: *DelayedSender) result { - self.receiver = ReceiverT.init(receive_, self); + self.receiver = ReceiverT.init(receive_, deinit, self); const m_ = self.message.?; self.timeout = timeout.init(self.delay_us, m_) catch |e| return exit_error(e, @errorReturnTrace()); self.a.free(m_.buf); - _ = set_trap(true); receive(&self.receiver); } fn deinit(self: *DelayedSender) void { self.timeout.deinit(); self.target.deinit(); + self.a.destroy(self); } fn receive_(self: *DelayedSender, _: pid_ref, m_: message) result { @@ -1052,7 +1058,6 @@ const DelayedSender = struct { self.timeout.cancel() catch |e| return exit_error(e, @errorReturnTrace()); return; } - defer self.deinit(); if (try m_.match(.{ "exit", "timeout_error", any, any })) return exit_normal(); try self.target.send_raw(m_);