diff --git a/src/napi/async.zig b/src/napi/async.zig index b8d7148..435067a 100644 --- a/src/napi/async.zig +++ b/src/napi/async.zig @@ -14,6 +14,8 @@ const AbortRegistration = @import("./abort_signal.zig").AbortRegistration; var threaded_runtime_mutex: std.atomic.Mutex = .unlocked; var threaded_runtime_initialized = false; var threaded_runtime_active_operations: usize = 0; +var threaded_runtime_cleanup_registered = false; +var threaded_runtime_cleanup_requested = false; var threaded_runtime: std.Io.Threaded = undefined; pub const RuntimeModel = enum { @@ -70,12 +72,53 @@ fn singleIo() std.Io { return std.Io.Threaded.global_single_threaded.io(); } -fn acquireThreadedRuntime() std.Io { +fn deinitThreadedRuntimeLocked() void { + if (!threaded_runtime_initialized) { + threaded_runtime_cleanup_requested = false; + return; + } + + threaded_runtime.deinit(); + threaded_runtime_initialized = false; + threaded_runtime_cleanup_requested = false; + threaded_runtime = undefined; +} + +fn threadedRuntimeCleanupHook(_: ?*anyopaque) callconv(.c) void { while (!threaded_runtime_mutex.tryLock()) { std.Thread.yield() catch {}; } defer threaded_runtime_mutex.unlock(); + threaded_runtime_cleanup_registered = false; + threaded_runtime_cleanup_requested = true; + if (threaded_runtime_active_operations == 0) { + deinitThreadedRuntimeLocked(); + } +} + +fn ensureThreadedRuntimeCleanupHookLocked(env_raw: napi.napi_env) !void { + if (threaded_runtime_cleanup_registered) return; + + const status = napi.napi_add_env_cleanup_hook(env_raw, threadedRuntimeCleanupHook, null); + if (status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(status)); + } + threaded_runtime_cleanup_registered = true; +} + +fn acquireThreadedRuntime(env_raw: napi.napi_env) !std.Io { + while (!threaded_runtime_mutex.tryLock()) { + std.Thread.yield() catch {}; + } + defer threaded_runtime_mutex.unlock(); + + if (threaded_runtime_cleanup_requested) { + return NapiError.Error.fromStatus(NapiError.Status.Closing); + } + + try ensureThreadedRuntimeCleanupHookLocked(env_raw); + if (!threaded_runtime_initialized) { threaded_runtime = std.Io.Threaded.init(GlobalAllocator.globalAllocator(), .{}); threaded_runtime_initialized = true; @@ -107,10 +150,8 @@ fn releaseThreadedRuntime() void { } threaded_runtime_active_operations -= 1; - if (threaded_runtime_active_operations == 0) { - threaded_runtime.deinit(); - threaded_runtime_initialized = false; - threaded_runtime = undefined; + if (threaded_runtime_active_operations == 0 and threaded_runtime_cleanup_requested) { + deinitThreadedRuntimeLocked(); } } @@ -363,8 +404,8 @@ fn AsyncTaskOperation( listener_ref: ?napi.napi_ref = null, abort_registration: ?*AbortRegistration = null, cancel_token: CancelToken = .{}, - controller_thread: ?std.Thread = null, future: ?std.Io.Future(void) = null, + controller_future: ?std.Io.Future(void) = null, tsfn_raw: napi.napi_threadsafe_function = null, state_mutex: std.Io.Mutex = .init, state_cond: std.Io.Condition = .init, @@ -420,7 +461,7 @@ fn AsyncTaskOperation( switch (effectiveRuntime(runtime)) { .single => self.runSingle(), .thread => { - const io = acquireThreadedRuntime(); + const io = try acquireThreadedRuntime(self.env); self.uses_threaded_runtime = true; try self.initThreadDispatcher(); self.future = std.Io.concurrent(io, runTask, .{self}) catch |err| { @@ -428,13 +469,21 @@ fn AsyncTaskOperation( self.dispatchCompletion(self.env); return promise; }; - self.controller_thread = try std.Thread.spawn(.{}, controllerThreadMain, .{self}); + self.controller_future = std.Io.concurrent(io, controllerTask, .{self}) catch |err| { + if (self.future) |*future| { + future.cancel(io); + self.future = null; + } + self.err = mapAnyError(err); + self.dispatchCompletion(self.env); + return promise; + }; }, } return promise; } - fn controllerThreadMain(self: *Self) void { + fn controllerTask(self: *Self) void { const io = self.operationIo(); const should_cancel = self.waitForTaskDoneOrAbort(); if (self.future) |*future| { @@ -620,9 +669,13 @@ fn AsyncTaskOperation( } fn dispatchCompletion(self: *Self, env_raw: napi.napi_env) void { - if (self.controller_thread) |thread| { - thread.join(); - self.controller_thread = null; + if (self.controller_future) |*controller_future| { + controller_future.await(self.operationIo()); + self.controller_future = null; + self.future = null; + } else if (self.future) |*future| { + future.await(self.operationIo()); + self.future = null; } const promise = self.promise;