From 4d1665ba932922a0fabcfab585e01da2d67ec8f9 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Tue, 21 May 2024 15:26:20 +0200 Subject: [PATCH] linux: add cancel support --- io/darwin.zig | 22 +++++++++++++++++++ io/linux.zig | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++ io/test.zig | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+) diff --git a/io/darwin.zig b/io/darwin.zig index 37f7dae..7d3008c 100644 --- a/io/darwin.zig +++ b/io/darwin.zig @@ -557,6 +557,28 @@ pub const IO = struct { ); } + pub const CancelError = error{ NotFound, ExpirationInProgress } || posix.UnexpectedError; + + pub fn cancel( + self: *IO, + comptime Context: type, + context: Context, + comptime callback: fn ( + context: Context, + completion: *Completion, + result: CancelError!void, + ) void, + completion: *Completion, + cancel_completion: *Completion, + ) void { + _ = self; + _ = context; + _ = callback; + _ = completion; + _ = cancel_completion; + // TODO implement cancellation w/ kqueue. + } + pub const TimeoutError = error{Canceled} || posix.UnexpectedError; pub fn timeout( diff --git a/io/linux.zig b/io/linux.zig index 72c3fc0..76726c8 100644 --- a/io/linux.zig +++ b/io/linux.zig @@ -239,6 +239,9 @@ pub const IO = struct { op.offset, ); }, + .cancel => |op| { + sqe.prep_cancel(op.c, 0); + }, } sqe.user_data = @intFromPtr(completion); } @@ -464,6 +467,22 @@ pub const IO = struct { }; completion.callback(completion.context, completion, &result); }, + .cancel => { + const result: CancelError!void = blk: { + if (completion.result < 0) { + const err = switch (@as(posix.E, @enumFromInt(-completion.result))) { + .SUCCESS => {}, + .NOENT => error.NotFound, + .ALREADY => error.ExpirationInProgress, + else => |errno| posix.unexpectedErrno(errno), + }; + break :blk err; + } else { + break :blk; + } + }; + completion.callback(completion.context, completion, &result); + }, } } }; @@ -503,6 +522,9 @@ pub const IO = struct { buffer: []const u8, offset: u64, }, + cancel: struct { + c: u64, + }, }; pub const AcceptError = error{ @@ -793,6 +815,42 @@ pub const IO = struct { self.enqueue(completion); } + pub const CancelError = error{ NotFound, ExpirationInProgress } || posix.UnexpectedError; + + pub fn cancel( + self: *IO, + comptime Context: type, + context: Context, + comptime callback: fn ( + context: Context, + completion: *Completion, + result: CancelError!void, + ) void, + completion: *Completion, + cancel_completion: *Completion, + ) void { + completion.* = .{ + .io = self, + .context = context, + .callback = struct { + fn wrapper(ctx: ?*anyopaque, comp: *Completion, res: *const anyopaque) void { + callback( + @as(Context, @ptrFromInt(@intFromPtr(ctx))), + comp, + @as(*const CancelError!void, @ptrFromInt(@intFromPtr(res))).*, + ); + } + }.wrapper, + .operation = .{ + .cancel = .{ + .c = @intFromPtr(cancel_completion), + }, + }, + }; + + self.enqueue(completion); + } + pub const TimeoutError = error{Canceled} || posix.UnexpectedError; pub fn timeout( diff --git a/io/test.zig b/io/test.zig index 518c0c9..803f950 100644 --- a/io/test.zig +++ b/io/test.zig @@ -640,3 +640,62 @@ test "pipe data over socket" { } }.run(); } + +test "cancel" { + try struct { + const Context = @This(); + + io: IO, + timeout_res: IO.TimeoutError!void = undefined, + timeout_done: bool = false, + cancel_done: bool = false, + + fn run_test() !void { + var self: Context = .{ + .io = try IO.init(32, 0), + }; + defer self.io.deinit(); + + var completion: IO.Completion = undefined; + self.io.timeout( + *Context, + &self, + timeout_callback, + &completion, + 100 * std.time.ns_per_ms, + ); + + var cancel_completion: IO.Completion = undefined; + self.io.cancel( + *Context, + &self, + cancel_callback, + &cancel_completion, + &completion, + ); + while (!self.cancel_done and !self.timeout_done) try self.io.tick(); + + try testing.expectEqual(true, self.timeout_done); + try testing.expectEqual(true, self.cancel_done); + try testing.expectError(IO.TimeoutError.Canceled, self.timeout_res); + } + + fn timeout_callback( + self: *Context, + _: *IO.Completion, + result: IO.TimeoutError!void, + ) void { + self.timeout_res = result; + self.timeout_done = true; + } + + fn cancel_callback( + self: *Context, + _: *IO.Completion, + result: IO.CancelError!void, + ) void { + result catch |err| std.debug.panic("cancel error: {}", .{err}); + self.cancel_done = true; + } + }.run_test(); +}