diff --git a/src/server/websocket.ts b/src/server/websocket.ts index 0388a12a..2e09e401 100644 --- a/src/server/websocket.ts +++ b/src/server/websocket.ts @@ -48,6 +48,8 @@ export async function createWsServer( const wss = new WebSocketServer({ noServer: true }) const inputHandler = new InputHandler(inputThrottleMs) + let currentProvider: WebSocket | null = null + const MAX_CONSUMERS = 5 let LAN_IP = "127.0.0.1" try { LAN_IP = await getLocalIp() @@ -110,6 +112,10 @@ export async function createWsServer( }, ) + function isPrivileged(ws: WebSocket, isLocal: boolean): boolean { + return isLocal // simplest safe rule: only localhost can do dangerous stuff + } + wss.on( "connection", ( @@ -144,7 +150,7 @@ export async function createWsServer( try { if (isBinary) { // Relay frames from Providers to Consumers - if ((ws as ExtWebSocket).isProvider) { + if (ws === currentProvider) { for (const client of wss.clients) { if ( client !== ws && @@ -165,7 +171,13 @@ export async function createWsServer( return } - const msg = JSON.parse(raw) + let msg + try { + msg = JSON.parse(raw) + } catch { + logger.warn("Invalid JSON received") + return + } // Throttle token touch to once per second — avoids crypto comparison on every event if (token && msg.type !== "get-ip" && msg.type !== "generate-token") { @@ -215,6 +227,20 @@ export async function createWsServer( } if (msg.type === "start-mirror") { + if (!isPrivileged(ws, isLocal)) { + logger.warn("Unauthorized mirror attempt") + return + } + + const consumerCount = [...wss.clients].filter( + c => (c as ExtWebSocket).isConsumer + ).length + + if (consumerCount >= MAX_CONSUMERS) { + logger.warn("Too many consumers, rejecting") + return + } + startMirror() return } @@ -225,12 +251,33 @@ export async function createWsServer( } if (msg.type === "start-provider") { + if (!isPrivileged(ws, isLocal)) { + logger.warn("Unauthorized provider attempt") + return + } + + if (currentProvider) { + logger.warn("Provider already exists, rejecting new one") + return + } + + currentProvider = ws ;(ws as ExtWebSocket).isProvider = true logger.info("Client registered as Screen Provider") return } if (msg.type === "update-config") { + if (!isPrivileged(ws, isLocal)) { + logger.warn("Unauthorized config update attempt") + ws.send(JSON.stringify({ + type: "config-updated", + success: false, + error: "Not authorized", + })) + return + } + try { if ( !msg.config || @@ -351,6 +398,11 @@ export async function createWsServer( return } + if (!isPrivileged(ws, isLocal)) { + logger.warn("Unauthorized input injection attempt") + return + } + await inputHandler.handleMessage(msg as InputMessage) } catch (err: unknown) { logger.error( @@ -362,6 +414,11 @@ export async function createWsServer( }) ws.on("close", () => { + if (ws === currentProvider) { + currentProvider = null + logger.info("Provider disconnected, slot freed") + } + stopMirror() logger.info("Client disconnected") })