Skip to content

Commit

Permalink
wasi-sockets: udp refactor (bytecodealliance#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
guybedford authored Jan 20, 2024
1 parent c12dca2 commit f637555
Show file tree
Hide file tree
Showing 10 changed files with 890 additions and 874 deletions.
14 changes: 0 additions & 14 deletions packages/preview2-shim/lib/browser/sockets.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ export const tcp = {
},
addressFamily() {

},
ipv6Only() {

},
setIpv6Only() {

},
setListenBacklogSize() {

Expand Down Expand Up @@ -152,14 +146,6 @@ export const udp = {

},

ipv6Only () {

},

setIpv6Only () {

},

unicastHopLimit () {

},
Expand Down
19 changes: 13 additions & 6 deletions packages/preview2-shim/lib/io/calls.js
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,25 @@ export const SOCKET_TCP_SET_LISTEN_BACKLOG_SIZE = ++call_id << CALL_SHIFT;
export const SOCKET_TCP_DISPOSE = ++call_id << CALL_SHIFT;
// Udp
export const SOCKET_UDP_CREATE_HANDLE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_BIND = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_CONNECT = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_DISCONNECT = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_CHECK_SEND = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_SEND = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_RECEIVE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_BIND_START = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_BIND_FINISH = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_STREAM = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_SUBSCRIBE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_DISPOSE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_GET_LOCAL_ADDRESS = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_GET_RECEIVE_BUFFER_SIZE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_GET_REMOTE_ADDRESS = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_GET_SEND_BUFFER_SIZE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_GET_UNICAST_HOP_LIMIT = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_SET_RECEIVE_BUFFER_SIZE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_SET_SEND_BUFFER_SIZE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_SET_UNICAST_HOP_LIMIT = ++call_id << CALL_SHIFT;
export const SOCKET_INCOMING_DATAGRAM_STREAM_RECEIVE = ++call_id << CALL_SHIFT;
export const SOCKET_OUTGOING_DATAGRAM_STREAM_CHECK_SEND = ++call_id << CALL_SHIFT;
export const SOCKET_OUTGOING_DATAGRAM_STREAM_SEND = ++call_id << CALL_SHIFT;
export const SOCKET_DATAGRAM_STREAM_SUBSCRIBE = ++call_id << CALL_SHIFT;
export const SOCKET_DATAGRAM_STREAM_DISPOSE = ++call_id << CALL_SHIFT;

// Name lookup
export const SOCKET_RESOLVE_ADDRESS_CREATE_REQUEST = ++call_id << CALL_SHIFT;
export const SOCKET_RESOLVE_ADDRESS_TAKE_REQUEST = ++call_id << CALL_SHIFT;
Expand Down
133 changes: 38 additions & 95 deletions packages/preview2-shim/lib/io/worker-socket-tcp.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {
createFuture,
createPoll,
createReadableStream,
createReadableStreamPollState,
createWritableStream,
Expand All @@ -10,62 +9,49 @@ import {
verifyPollsDroppedForDrop,
} from "./worker-thread.js";
const { TCP, constants: TCPConstants } = process.binding("tcp_wrap");
import {
deserializeIpAddress,
serializeIpAddress,
isWildcardAddress,
isUnicastIpAddress,
isMulticastIpAddress,
isIPv4MappedAddress,
} from "../nodejs/sockets/socket-common.js";
import {
convertSocketError,
convertSocketErrorCode,
ipSocketAddress,
isIPv4MappedAddress,
isMulticastIpAddress,
isUnicastIpAddress,
isWildcardAddress,
noLookup,
serializeIpAddress,
SOCKET_STATE_BIND,
SOCKET_STATE_BOUND,
SOCKET_STATE_CLOSED,
SOCKET_STATE_CONNECT,
SOCKET_STATE_CONNECTION,
SOCKET_STATE_INIT,
SOCKET_STATE_LISTEN,
SOCKET_STATE_LISTENER,
} from "./worker-sockets.js";
import { Socket, Server } from "node:net";
import { platform } from "node:os";

// As a workaround, we store the bound address in a global map
// this is needed because 'address-in-use' is not always thrown when binding
// more than one socket to the same address
// TODO: remove this workaround when we figure out why!
const globalBoundAddresses = new Set();

const isWindows = platform() === "win32";

let stateCnt = 0;
export const SOCKET_STATE_INIT = ++stateCnt;
export const SOCKET_STATE_BIND = ++stateCnt;
export const SOCKET_STATE_BOUND = ++stateCnt;
export const SOCKET_STATE_LISTEN = ++stateCnt;
export const SOCKET_STATE_LISTENER = ++stateCnt;
export const SOCKET_STATE_CONNECT = ++stateCnt;
export const SOCKET_STATE_CONNECTION = ++stateCnt;
export const SOCKET_STATE_CLOSED = ++stateCnt;

/**
* @typedef {import("../../types/interfaces/wasi-sockets-network.js").IpSocketAddress} IpSocketAddress
* @typedef {import("../../../types/interfaces/wasi-sockets-tcp.js").IpAddressFamily} IpAddressFamily
*
* @typedef {{
* tcpSocket: number | null,
* tcpSocket: number,
* err: Error | null,
* pollState: PollState | null,
* pollState: PollState,
* }} PendingAccept
*
* @typedef {{
* state: number,
* future: number | null,
* serializedLocalAddress: string | null,
* listenBacklogSize: number,
* handle: TCP,
* pendingAccepts: PendingAccept[],
* pollState: PollState,
* }} SocketRecord
* }} TcpSocketRecord
*/

/**
* @type {Map<number, SocketRecord>}
* @type {Map<number, TcpSocketRecord>}
*/
export const tcpSockets = new Map();

Expand All @@ -79,7 +65,6 @@ export function createTcpSocket() {
tcpSockets.set(++tcpSocketCnt, {
state: SOCKET_STATE_INIT,
future: null,
serializedLocalAddress: null,
listenBacklogSize: 128,
handle,
pendingAccepts: [],
Expand All @@ -88,10 +73,6 @@ export function createTcpSocket() {
return tcpSocketCnt;
}

export function socketTcpSubscribe(id) {
return createPoll(tcpSockets.get(id).pollState);
}

export function socketTcpFinish(id, fromState, toState) {
const socket = tcpSockets.get(id);
if (socket.state !== fromState) throw "not-in-progress";
Expand All @@ -105,39 +86,37 @@ export function socketTcpFinish(id, fromState, toState) {
} else {
socket.state = toState;
// for the listener, we must immediately transition back to unresolved
if (toState === SOCKET_STATE_LISTENER)
socket.pollState.ready = false;
if (toState === SOCKET_STATE_LISTENER) socket.pollState.ready = false;
return val;
}
}

export function socketTcpBindStart(id, localAddress, family) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_INIT) throw "invalid-state";
if (family !== localAddress.tag || !isUnicastIpAddress(localAddress))
if (
family !== localAddress.tag ||
!isUnicastIpAddress(localAddress) ||
isIPv4MappedAddress(localAddress)
)
throw "invalid-argument";
if (isIPv4MappedAddress(localAddress)) throw "invalid-argument";
socket.state = SOCKET_STATE_BIND;
const { handle } = socket;
socket.future = createFuture(
(async () => {
const address = serializeIpAddress(localAddress);
const port = localAddress.val.port;
if (globalBoundAddresses.has(`${address}:${port}`))
throw "address-in-use";
const code =
localAddress.tag === "ipv6"
? handle.bind6(address, port, TCPConstants.UV_TCP_IPV6ONLY)
: handle.bind(address, port);
if (code !== 0) throw convertSocketErrorCode(-code);
// This is a Node.js / libuv quirk to force the bind error to be thrown
// (specifically address-in-use).
{
const localAddress = socketTcpGetLocalAddress(id);
const serializedLocalAddress = `${serializeIpAddress(localAddress)}:${
localAddress.val.port
}`;
globalBoundAddresses.add(
(socket.serializedLocalAddress = serializedLocalAddress)
);
const out = {};
const code = handle.getsockname(out);
if (code !== 0) throw convertSocketErrorCode(-code);
}
})(),
socket.pollState
Expand All @@ -148,17 +127,16 @@ export function socketTcpConnectStart(id, remoteAddress, family) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_INIT && socket.state !== SOCKET_STATE_BOUND)
throw "invalid-state";
if (remoteAddress.val.port === 0 && isWindows) throw "invalid-argument";
if (
isWildcardAddress(remoteAddress) ||
family !== remoteAddress.tag ||
!isUnicastIpAddress(remoteAddress) ||
isMulticastIpAddress(remoteAddress) ||
remoteAddress.val.port === 0
remoteAddress.val.port === 0 ||
isIPv4MappedAddress(remoteAddress)
) {
throw "invalid-argument";
}
if (isIPv4MappedAddress(remoteAddress)) throw "invalid-argument";
socket.state = SOCKET_STATE_CONNECT;
socket.future = createFuture(
new Promise((resolve, reject) => {
Expand All @@ -169,19 +147,10 @@ export function socketTcpConnectStart(id, remoteAddress, family) {
});
function handleErr(err) {
tcpSocket.off("connect", handleConnect);
reject(err);
reject(convertSocketError(err));
}
function handleConnect() {
tcpSocket.off("error", handleErr);
if (!tcpSocket.serializedLocalAddress) {
const localAddress = socketTcpGetLocalAddress(id);
const serializedLocalAddress = `${serializeIpAddress(localAddress)}:${
localAddress.val.port
}`;
globalBoundAddresses.add(
(tcpSocket.serializedLocalAddress = serializedLocalAddress)
);
}
resolve([
createReadableStream(tcpSocket),
createWritableStream(tcpSocket),
Expand All @@ -192,9 +161,7 @@ export function socketTcpConnectStart(id, remoteAddress, family) {
tcpSocket.connect({
port: remoteAddress.val.port,
host: serializeIpAddress(remoteAddress),
lookup: () => {
throw "invalid-argument";
},
lookup: noLookup,
});
}),
socket.pollState
Expand All @@ -211,7 +178,7 @@ export function socketTcpListenStart(id) {
const server = new Server({ pauseOnConnect: true, allowHalfOpen: true });
function handleErr(err) {
server.off("listening", handleListen);
reject(err);
reject(convertSocketError(err));
}
function handleListen() {
server.off("error", handleErr);
Expand Down Expand Up @@ -243,12 +210,10 @@ export function socketTcpAccept(id) {
socket.state = SOCKET_STATE_CLOSED;
throw convertSocketError(accept.err);
}
if (socket.pendingAccepts.length === 0)
socket.pollState.ready = false;
if (socket.pendingAccepts.length === 0) socket.pollState.ready = false;
tcpSockets.set(++tcpSocketCnt, {
state: SOCKET_STATE_CONNECTION,
future: null,
serializedLocalAddress: null,
listenBacklogSize: 128,
handle: accept.tcpSocket._handle,
pendingAccepts: [],
Expand All @@ -261,10 +226,6 @@ export function socketTcpAccept(id) {
];
}

export function socketTcpIsListening(id) {
return tcpSockets.get(id).state === SOCKET_STATE_LISTENER;
}

export function socketTcpSetListenBacklogSize(id, backlogSize) {
const socket = tcpSockets.get(id);
if (
Expand All @@ -286,31 +247,15 @@ export function socketTcpGetLocalAddress(id) {
const out = {};
const code = handle.getsockname(out);
if (code !== 0) throw convertSocketErrorCode(-code);
const family = out.family.toLowerCase();
const { address, port } = out;
return {
tag: family,
val: {
address: deserializeIpAddress(address, family),
port,
},
};
return ipSocketAddress(out.family.toLowerCase(), out.address, out.port);
}

export function socketTcpGetRemoteAddress(id) {
const { handle } = tcpSockets.get(id);
const out = {};
const code = handle.getpeername(out);
if (code !== 0) throw convertSocketErrorCode(-code);
const family = out.family.toLowerCase();
const { address, port } = out;
return {
tag: family,
val: {
address: deserializeIpAddress(address, family),
port,
},
};
return ipSocketAddress(out.family.toLowerCase(), out.address, out.port);
}

// Node.js only supports a write shutdown
Expand All @@ -333,8 +278,6 @@ export function socketTcpSetKeepAlive(id, { keepAlive, keepAliveIdleTime }) {
export function socketTcpDispose(id) {
const socket = tcpSockets.get(id);
verifyPollsDroppedForDrop(socket.pollState, "tcp socket");
if (socket.serializedLocalAddress)
globalBoundAddresses.delete(socket.serializedLocalAddress);
socket.handle.close();
tcpSockets.delete(id);
}
Loading

0 comments on commit f637555

Please sign in to comment.