Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket web #1005

Merged
merged 11 commits into from
Jan 14, 2025
1 change: 1 addition & 0 deletions docs/src/content/docs/reference/cli/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,10 @@
-c, --cors <string> Enable CORS and sets the allowed origin. Use '*' to
allow any origin.
--remote <string> Remote repository URL to serve
--remote-branch <string> Branch to serve from the remote
--remote-force Force pull from remote repository
--remote-install Install dependencies from remote repository
--dispatch-progress Dispatch progress events to all clients

Check warning on line 402 in docs/src/content/docs/reference/cli/commands.md

View workflow job for this annotation

GitHub Actions / build

The `--dispatch-progress` option is added but lacks a description. Please provide a brief explanation of what this option does and its purpose.
pelikhan marked this conversation as resolved.
Show resolved Hide resolved
-h, --help display help for command
```

Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"cache:clear": "cd packages/sample/ && yarn cache:clear",
"test:scripts": "cd packages/sample/ && yarn test:scripts",
"test:scripts:view": "cd packages/sample/ && yarn test:scripts:view",
"serve:cli": "node --watch --watch-path=packages/cli/built packages/cli/built/genaiscript.cjs serve",
"serve:cli": "node --watch --watch-path=packages/cli/built packages/cli/built/genaiscript.cjs serve --dispatch-progress",
"serve:web": "yarn --cwd packages/web watch",
"serve": "run-p serve:*",
"docs": "cd docs && ./node_modules/.bin/astro telemetry disable && ./node_modules/.bin/astro dev --host",
Expand Down
4 changes: 4 additions & 0 deletions packages/cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ export async function cli() {
"--remote-install",
"Install dependencies from remote repository"
)
.option(
"--dispatch-progress",
"Dispatch progress events to all clients"
)
.action(startServer) // Action to start the server

// Define 'parse' command group for parsing tasks
Expand Down
1 change: 0 additions & 1 deletion packages/cli/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ export async function runScriptInternal(
CancellationOptions & {
outputTrace?: MarkdownTrace
cli?: boolean
workspaceFiles?: WorkspaceFile[]
infoCb?: (partialResponse: { text: string }) => void
partialCb?: (progress: ChatCompletionsProgressReport) => void
}
Expand Down
114 changes: 66 additions & 48 deletions packages/cli/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ import {
UNHANDLED_ERROR_CODE,
MODEL_PROVIDER_CLIENT,
} from "../../core/src/constants"
import {
isCancelError,
errorMessage,
serializeError,
} from "../../core/src/error"
import { ServerResponse, host, runtimeHost } from "../../core/src/host"
import { isCancelError, serializeError } from "../../core/src/error"
import { host, runtimeHost } from "../../core/src/host"
import { MarkdownTrace, TraceChunkEvent } from "../../core/src/trace"
import {
logVerbose,
Expand All @@ -37,8 +33,8 @@ import {
ResponseStatus,
LanguageModelConfiguration,
ServerEnvResponse,
ServerResponse,
} from "../../core/src/server/messages"
import { envInfo } from "./info"
import { LanguageModel } from "../../core/src/chat"
import {
ChatCompletionResponse,
Expand Down Expand Up @@ -70,13 +66,15 @@ export async function startServer(options: {
remoteBranch?: string
remoteForce?: boolean
remoteInstall?: boolean
dispatchProgress?: boolean
}) {
// Parse and set the server port, using a default if not specified.
const corsOrigin = options.cors || process.env.GENAISCRIPT_CORS_ORIGIN
const port = parseInt(options.port) || SERVER_PORT
const apiKey = options.apiKey || process.env.GENAISCRIPT_API_KEY
const serverHost = options.network ? "0.0.0.0" : "127.0.0.1"
const remote = options.remote
const dispatchProgress = !!options.dispatchProgress

// store original working directory
const cwd = process.cwd()
Expand Down Expand Up @@ -266,6 +264,45 @@ export async function startServer(options: {
logVerbose(`clients: closed (${wss.clients.size} clients)`)
)

const send = (payload: object) => {
const cmsg = JSON.stringify(payload)
if (dispatchProgress)
for (const client of this.clients) client.send(cmsg)
else ws?.send(cmsg)
}
const sendProgress = (
runId: string,
payload: Omit<PromptScriptProgressResponseEvent, "type" | "runId">
) => {
send({
type: "script.progress",
runId,
...payload,
} satisfies PromptScriptProgressResponseEvent)
}

// send traces of in-flight runs
for (const [runId, run] of Object.entries(runs)) {
chunkString(run.outputTrace.content).forEach((c) =>
ws.send(
JSON.stringify({
type: "script.progress",
runId,
trace: c,
} satisfies PromptScriptProgressResponseEvent)
)
)
chunkString(run.trace.content).forEach((c) =>
ws.send(
JSON.stringify({
type: "script.progress",
runId,
trace: c,
} satisfies PromptScriptProgressResponseEvent)
)
)
}

// Handle incoming messages based on their type.
ws.on("message", async (msg) => {
const data = JSON.parse(msg.toString()) as RequestMessages
Expand Down Expand Up @@ -340,31 +377,16 @@ export async function startServer(options: {
new AbortSignalCancellationController()
const trace = new MarkdownTrace()
const outputTrace = new MarkdownTrace()
const send = (
payload: Omit<
PromptScriptProgressResponseEvent,
"type" | "runId"
>
) =>
ws?.send(
JSON.stringify(<
PromptScriptProgressResponseEvent
>{
type: "script.progress",
runId,
...payload,
})
)
trace.addEventListener(TRACE_CHUNK, (ev) => {
const tev = ev as TraceChunkEvent
chunkString(tev.chunk, 2 << 14).forEach((c) =>
send({ trace: c })
sendProgress(runId, { trace: c })
)
})
outputTrace.addEventListener(TRACE_CHUNK, (ev) => {
const tev = ev as TraceChunkEvent
chunkString(tev.chunk, 2 << 14).forEach((c) =>
send({ output: c })
sendProgress(runId, { output: c })
)
})
logVerbose(`run ${runId}: starting ${script}`)
Expand All @@ -375,15 +397,15 @@ export async function startServer(options: {
outputTrace,
cancellationToken: canceller.token,
infoCb: ({ text }) => {
send({ progress: text })
sendProgress(runId, { progress: text })
},
partialCb: ({
responseChunk,
responseSoFar,
tokensSoFar,
responseTokens,
}) => {
send({
sendProgress(runId, {
response: responseSoFar,
responseChunk,
tokens: tokensSoFar,
Expand All @@ -396,33 +418,29 @@ export async function startServer(options: {
logVerbose(
`\nrun ${runId}: completed with ${exitCode}`
)
ws?.send(
JSON.stringify(<
PromptScriptEndResponseEvent
>{
type: "script.end",
runId,
exitCode,
result,
})
)
send({
type: "script.end",
runId,
exitCode,
result,
} satisfies PromptScriptEndResponseEvent)
})
.catch((e) => {
if (canceller.controller.signal.aborted) return
if (!isCancelError(e)) trace.error(e)
logError(`\nrun ${runId}: failed`)
logError(e)
ws?.send(
JSON.stringify(<
PromptScriptEndResponseEvent
>{
type: "script.end",
runId,
exitCode: isCancelError(e)
? USER_CANCELLED_ERROR_CODE
: UNHANDLED_ERROR_CODE,
})
)
send({
type: "script.end",
runId,
result: {
status: "error",
error: serializeError(e),
},
exitCode: isCancelError(e)
? USER_CANCELLED_ERROR_CODE
: UNHANDLED_ERROR_CODE,
} satisfies PromptScriptEndResponseEvent)
})
runs[runId] = {
runner,
Expand Down Expand Up @@ -467,7 +485,7 @@ export async function startServer(options: {
} finally {
assert(!!response)
if (response.error) logError(response.error)
ws.send(JSON.stringify({ id, response }))
send({ id, type, response })
}
})
})
Expand Down
13 changes: 13 additions & 0 deletions packages/core/src/assert.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
export function assert(
cond: boolean,
msg = "Assertion failed",
// eslint-disable-next-line @typescript-eslint/no-explicit-any
debugData?: any
) {
if (!cond) {
if (debugData) console.error(msg || `assertion failed`, debugData)
// eslint-disable-next-line no-debugger
debugger
throw new Error(msg)
}
}
4 changes: 4 additions & 0 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ export const TRACE_CHUNK = "traceChunk"
export const TRACE_DETAILS = "traceDetails"
export const RECONNECT = "reconnect"
export const OPEN = "open"
export const CLOSE = "close"
export const MESSAGE = "message"
export const ERROR = "error"
export const CONNECT = "connect"
export const MAX_TOOL_CALLS = 10000

// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
Expand Down
8 changes: 0 additions & 8 deletions packages/core/src/host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,6 @@ export interface RetrievalService {
): Promise<RetrievalSearchResponse>
}

export interface ServerResponse extends ResponseStatus {
version: string
node: string
platform: string
arch: string
pid: number
}

export interface ServerManager {
start(): Promise<void>
close(): Promise<void>
Expand Down
Loading
Loading