Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions src/server/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ export async function createWsServer(

const wss = new WebSocketServer({ noServer: true })
const inputHandler = new InputHandler(inputThrottleMs)
let currentProvider: WebSocket | null = null
const MAX_CONSUMERS = 5
let LAN_IP = "127.0.0.1"
try {
LAN_IP = await getLocalIp()
Expand Down Expand Up @@ -110,6 +112,10 @@ export async function createWsServer(
},
)

function isPrivileged(ws: WebSocket, isLocal: boolean): boolean {
return isLocal // simplest safe rule: only localhost can do dangerous stuff
Comment thread
Akash504-ai marked this conversation as resolved.
}

wss.on(
"connection",
(
Expand Down Expand Up @@ -144,7 +150,7 @@ export async function createWsServer(
try {
if (isBinary) {
// Relay frames from Providers to Consumers
if ((ws as ExtWebSocket).isProvider) {
if (ws === currentProvider) {
for (const client of wss.clients) {
if (
client !== ws &&
Expand All @@ -165,7 +171,13 @@ export async function createWsServer(
return
}

const msg = JSON.parse(raw)
let msg
try {
msg = JSON.parse(raw)
} catch {
logger.warn("Invalid JSON received")
return
}

// Throttle token touch to once per second — avoids crypto comparison on every event
if (token && msg.type !== "get-ip" && msg.type !== "generate-token") {
Expand Down Expand Up @@ -215,6 +227,20 @@ export async function createWsServer(
}

if (msg.type === "start-mirror") {
if (!isPrivileged(ws, isLocal)) {
logger.warn("Unauthorized mirror attempt")
return
}

const consumerCount = [...wss.clients].filter(
c => (c as ExtWebSocket).isConsumer
).length

if (consumerCount >= MAX_CONSUMERS) {
logger.warn("Too many consumers, rejecting")
return
}

startMirror()
return
}
Expand All @@ -225,12 +251,33 @@ export async function createWsServer(
}

if (msg.type === "start-provider") {
if (!isPrivileged(ws, isLocal)) {
logger.warn("Unauthorized provider attempt")
return
}

if (currentProvider) {
logger.warn("Provider already exists, rejecting new one")
return
}

currentProvider = ws
;(ws as ExtWebSocket).isProvider = true
logger.info("Client registered as Screen Provider")
return
}

if (msg.type === "update-config") {
if (!isPrivileged(ws, isLocal)) {
logger.warn("Unauthorized config update attempt")
ws.send(JSON.stringify({
type: "config-updated",
success: false,
error: "Not authorized",
}))
return
}

try {
if (
!msg.config ||
Expand Down Expand Up @@ -351,6 +398,11 @@ export async function createWsServer(
return
}

if (!isPrivileged(ws, isLocal)) {
logger.warn("Unauthorized input injection attempt")
return
}

await inputHandler.handleMessage(msg as InputMessage)
} catch (err: unknown) {
logger.error(
Expand All @@ -362,6 +414,11 @@ export async function createWsServer(
})

ws.on("close", () => {
if (ws === currentProvider) {
currentProvider = null
logger.info("Provider disconnected, slot freed")
}

stopMirror()
logger.info("Client disconnected")
})
Expand Down
Loading