Skip to content

Commit

Permalink
Ensure completions are executed on the currently connected client
Browse files Browse the repository at this point in the history
For the time being, given that we only allow 1 client at a time, I took a
shortcut to implement this. The server has an incrementing "current_client_id"
which is part of every completion. On completion callback, we just check if
its client_id is still equal to the server's current_client_id.
  • Loading branch information
karlseguin committed Feb 21, 2025
1 parent 09505db commit 756d662
Showing 1 changed file with 87 additions and 42 deletions.
129 changes: 87 additions & 42 deletions src/server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub const Client = ClientT(*Server, CDP);
const Server = struct {
allocator: Allocator,
loop: *jsruntime.Loop,
current_client_id: usize = 0,

// internal fields
listener: posix.socket_t,
Expand All @@ -64,10 +65,9 @@ const Server = struct {
// a memory poor for our Clietns
client_pool: std.heap.MemoryPool(Client),

timeout_completion_pool: std.heap.MemoryPool(Completion),
completion_state_pool: std.heap.MemoryPool(CompletionState),

// I/O fields
conn_completion: Completion,
close_completion: Completion,
accept_completion: Completion,

Expand All @@ -77,7 +77,7 @@ const Server = struct {
fn deinit(self: *Server) void {
self.send_pool.deinit();
self.client_pool.deinit();
self.timeout_completion_pool.deinit();
self.completion_state_pool.deinit();
self.allocator.free(self.json_version_response);
}

Expand All @@ -99,40 +99,37 @@ const Server = struct {
) void {
std.debug.assert(self.client == null);
std.debug.assert(completion == &self.accept_completion);

const socket = result catch |err| {
self.doCallbackAccept(result) catch |err| {
log.err("accept error: {any}", .{err});
self.queueAccept();
return;
};
}

const client = self.client_pool.create() catch |err| {
log.err("failed to create client: {any}", .{err});
posix.close(socket);
return;
};
fn doCallbackAccept(
self: *Server,
result: AcceptError!posix.socket_t,
) !void {
const socket = try result;
const client = try self.client_pool.create();
errdefer self.client_pool.destroy(client);

self.current_client_id += 1;
client.* = Client.init(socket, self);

self.client = client;

log.info("client connected", .{});
self.queueRead();
self.queueTimeout();
try self.queueRead();
try self.queueTimeout();
}

fn queueTimeout(self: *Server) void {
const completion = self.timeout_completion_pool.create() catch |err| {
log.err("failed to create timeout completion: {any}", .{err});
return;
};

fn queueTimeout(self: *Server) !void {
const cs = try self.createCompletionState();
self.loop.io.timeout(
*Server,
self,
callbackTimeout,
completion,
&cs.completion,
TimeoutCheck,
);
}
Expand All @@ -142,7 +139,16 @@ const Server = struct {
completion: *Completion,
result: TimeoutError!void,
) void {
self.timeout_completion_pool.destroy(completion);
const cs: *CompletionState = @alignCast(
@fieldParentPtr("completion", completion),
);
defer self.completion_state_pool.destroy(cs);

if (cs.client_id != self.current_client_id) {
// completion for a previously-connected client
return;
}

const client = self.client orelse return;

if (result) |_| {
Expand All @@ -160,28 +166,39 @@ const Server = struct {
// very unlikely IO timeout error.
// AKA: we don't requeue this if the connection timed out and we
// closed the connection.s
self.queueTimeout();
self.queueTimeout() catch |err| {
log.err("queueTimeout error: {any}", .{err});
};
}

fn queueRead(self: *Server) void {
if (self.client) |client| {
self.loop.io.recv(
*Server,
self,
callbackRead,
&self.conn_completion,
client.socket,
client.readBuf(),
);
}
fn queueRead(self: *Server) !void {
var client = self.client orelse return;

const cs = try self.createCompletionState();
self.loop.io.recv(
*Server,
self,
callbackRead,
&cs.completion,
client.socket,
client.readBuf(),
);
}

fn callbackRead(
self: *Server,
completion: *Completion,
result: RecvError!usize,
) void {
std.debug.assert(completion == &self.conn_completion);
const cs: *CompletionState = @alignCast(
@fieldParentPtr("completion", completion),
);
defer self.completion_state_pool.destroy(cs);

if (cs.client_id != self.current_client_id) {
// completion for a previously-connected client
return;
}

var client = self.client orelse return;

Expand All @@ -205,7 +222,10 @@ const Server = struct {

// if more == false, the client is disconnecting
if (more) {
self.queueRead();
self.queueRead() catch |err| {
log.err("queueRead error: {any}", .{err});
client.close(null);
};
}
}

Expand All @@ -218,12 +238,15 @@ const Server = struct {
const sd = try self.send_pool.create();
errdefer self.send_pool.destroy(sd);

const cs = try self.createCompletionState();
errdefer self.completion_state_pool.destroy(cs);

sd.* = .{
.unsent = data,
.server = self,
.socket = socket,
.completion = undefined,
.arena = arena,
.completion_state = cs,
};
sd.queueSend();
}
Expand All @@ -246,6 +269,18 @@ const Server = struct {
std.debug.assert(completion == &self.close_completion);
self.queueAccept();
}

fn createCompletionState(self: *Server) !*CompletionState {
var cs = try self.completion_state_pool.create();
cs.client_id = self.current_client_id;
cs.completion = undefined;
return cs;
}
};

const CompletionState = struct {
client_id: usize,
completion: Completion,
};

// I/O Send
Expand All @@ -259,17 +294,19 @@ const Send = struct { // Any unsent data we have.
unsent: []const u8,

server: *Server,
completion: Completion,
socket: posix.socket_t,
completion_state: *CompletionState,

// If we need to free anything when we're done
arena: ?ArenaAllocator,

fn deinit(self: *Send) void {
var server = self.server;
if (self.arena) |arena| {
arena.deinit();
}

var server = self.server;
server.completion_state_pool.destroy(self.completion_state);
server.send_pool.destroy(self);
}

Expand All @@ -278,16 +315,25 @@ const Send = struct { // Any unsent data we have.
*Send,
self,
sendCallback,
&self.completion,
&self.completion_state.completion,
self.socket,
self.unsent,
);
}

fn sendCallback(self: *Send, _: *Completion, result: SendError!usize) void {
const server = self.server;
const cs = self.completion_state;

if (cs.client_id != server.current_client_id) {
// completion for a previously-connected client
self.deinit();
return;
}

const sent = result catch |err| {
log.info("send error: {any}", .{err});
if (self.server.client) |client| {
if (server.client) |client| {
client.close(null);
}
self.deinit();
Expand Down Expand Up @@ -1011,13 +1057,12 @@ pub fn run(
.timeout = timeout,
.listener = listener,
.allocator = allocator,
.conn_completion = undefined,
.close_completion = undefined,
.accept_completion = undefined,
.json_version_response = json_version_response,
.send_pool = std.heap.MemoryPool(Send).init(allocator),
.client_pool = std.heap.MemoryPool(Client).init(allocator),
.timeout_completion_pool = std.heap.MemoryPool(Completion).init(allocator),
.completion_state_pool = std.heap.MemoryPool(CompletionState).init(allocator),
};
defer server.deinit();

Expand Down

0 comments on commit 756d662

Please sign in to comment.