speak api #1001

Merged · 4 commits · Jan 14, 2025
4 changes: 4 additions & 0 deletions docs/public/schemas/llms.json
@@ -60,6 +60,10 @@
"type": "boolean",
"description": "Indicates if speech transcription is supported"
},
"speech": {
"type": "boolean",
"description": "Indicates if speech synthesis is supported"
},
"openaiCompatibility": {
"type": "string",
"description": "Uses OpenAI API compatibility layer documentation URL"
18 changes: 18 additions & 0 deletions packages/core/src/chat.ts
@@ -152,12 +152,30 @@ export type TranscribeFunction = (
options: TraceOptions & CancellationOptions
) => Promise<TranscriptionResult>

export type CreateSpeechRequest = {
input: string
model: string
voice?: string
}

export type CreateSpeechResult = {
audio: Uint8Array
error?: SerializedError
}

export type SpeechFunction = (
req: CreateSpeechRequest,
cfg: LanguageModelConfiguration,
options: TraceOptions & CancellationOptions
) => Promise<CreateSpeechResult>

export interface LanguageModel {
id: string
completer: ChatCompletionHandler
listModels?: ListModelsFunction
pullModel?: PullModelFunction
transcriber?: TranscribeFunction
speaker?: SpeechFunction
}

async function runToolCalls(
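
The new SpeechFunction mirrors TranscribeFunction: it receives a request, a resolved provider configuration, and trace/cancellation options, and resolves to a result carrying either audio bytes or a serialized error. A minimal sketch of a custom speaker, assuming a hypothetical synthesize helper in place of a real TTS backend:

    import { SpeechFunction } from "./chat"
    import { serializeError } from "./error"

    // Hypothetical stand-in for a provider's actual TTS call.
    declare function synthesize(opts: {
        base: string
        model: string
        input: string
        voice?: string
    }): Promise<Uint8Array>

    // A custom speaker satisfying the SpeechFunction contract:
    // resolve audio bytes on success, surface a serialized error on failure.
    const mySpeaker: SpeechFunction = async (req, cfg, options) => {
        try {
            const audio = await synthesize({
                base: cfg.base,
                model: req.model,
                input: req.input,
                voice: req.voice,
            })
            return { audio }
        } catch (e) {
            return { audio: undefined, error: serializeError(e) }
        }
    }
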
2 changes: 2 additions & 0 deletions packages/core/src/constants.ts
@@ -61,6 +61,7 @@ export const SMALL_MODEL_ID = "small"
export const LARGE_MODEL_ID = "large"
export const VISION_MODEL_ID = "vision"
export const TRANSCRIPTION_MODEL_ID = "transcription"
export const SPEECH_MODEL_ID = "speech"
export const DEFAULT_FENCE_FORMAT: FenceFormat = "xml"
export const DEFAULT_TEMPERATURE = 0.8
export const BUILTIN_PREFIX = "_builtin/"
@@ -200,6 +201,7 @@ export const MODEL_PROVIDERS = Object.freeze<
bearerToken?: boolean
listModels?: boolean
transcribe?: boolean
speech?: boolean
aliases?: Record<string, string>
}[]
>(CONFIGURATION_DATA.providers)
8 changes: 6 additions & 2 deletions packages/core/src/llms.json
@@ -6,22 +6,26 @@
"detail": "OpenAI (or compatible)",
"bearerToken": true,
"transcribe": true,
"speech": true,
"aliases": {
"large": "gpt-4o",
"small": "gpt-4o-mini",
"vision": "gpt-4o",
"embeddings": "text-embedding-3-small",
"reasoning": "o1",
"reasoning_small": "o1-mini",
"transcription": "whisper-1"
"transcription": "whisper-1",
"speech": "tts-1"
}
},
{
"id": "azure",
"detail": "Azure OpenAI deployment",
"listModels": false,
"bearerToken": false,
"prediction": false
"prediction": false,
"transcribe": true,
"speech": true
},
{
"id": "azure_serverless",
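
These alias tables are what the new SPEECH_MODEL_ID in constants.ts resolves against: a script that asks for the generic "speech" model on the openai provider lands on tts-1. An illustrative lookup (not necessarily how the resolver is actually factored):

    import { MODEL_PROVIDERS, SPEECH_MODEL_ID } from "./constants"

    // The openai provider declares "speech": "tts-1" in its alias table,
    // so the generic alias resolves to a concrete model id.
    const openai = MODEL_PROVIDERS.find((p) => p.id === "openai")
    const concrete = openai?.aliases?.[SPEECH_MODEL_ID] // "tts-1"
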
2 changes: 1 addition & 1 deletion packages/core/src/lm.ts
@@ -1,4 +1,3 @@
import { Transform } from "stream"
import { AICIModel } from "./aici"
import { AnthropicBedrockModel, AnthropicModel } from "./anthropic"
import { LanguageModel } from "./chat"
@@ -39,5 +38,6 @@ export function resolveLanguageModel(provider: string): LanguageModel {
return LocalOpenAICompatibleModel(provider, {
listModels: features?.listModels !== false,
transcribe: features?.transcribe,
speech: features?.speech,
})
}
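
Because the speech flag is threaded through the provider features, resolveLanguageModel only attaches a speaker for providers that declare support, so callers must guard before use. A short sketch under that assumption:

    import { resolveLanguageModel } from "./lm"

    // Providers with "speech": true in llms.json expose a speaker;
    // for all others the field is stripped as undefined.
    const { speaker } = resolveLanguageModel("openai")
    if (!speaker)
        throw new Error("speech synthesis not supported by this provider")
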
58 changes: 52 additions & 6 deletions packages/core/src/openai.ts
@@ -20,12 +20,13 @@ import {
import { estimateTokens } from "./tokens"
import {
ChatCompletionHandler,
CreateSpeechRequest,
CreateSpeechResult,
CreateTranscriptionRequest,
LanguageModel,
PullModelFunction,
} from "./chat"
import { RequestError, errorMessage, serializeError } from "./error"
import { createFetch, iterateBody, traceFetchPost } from "./fetch"
import { createFetch, traceFetchPost } from "./fetch"
import { parseModelIdentifier } from "./models"
import { JSON5TryParse } from "./json5"
import {
@@ -435,7 +436,6 @@ export async function OpenAITranscribe(
options: TraceOptions & CancellationOptions
): Promise<TranscriptionResult> {
const { trace } = options || {}
const fetch = await createFetch(options)
try {
logVerbose(`${cfg.provider}: transcribe with ${cfg.model}`)
const route = req.translate ? "translations" : "transcriptions"
@@ -453,33 +453,79 @@
method: "POST",
headers: {
...getConfigHeaders(cfg),
ContentType: "multipart/form-data",
"Content-Type": "multipart/form-data",
Accept: "application/json",
},
body: body,
}
traceFetchPost(trace, url, freq.headers, freq.body)
// TODO: switch back to cross-fetch in the future
const res = await global.fetch(url, freq as any)
trace.itemValue(`status`, `${res.status} ${res.statusText}`)
const j = await res.json()
return j
if (!res.ok) return { text: undefined, error: j?.error }
else return j
} catch (e) {
logError(e)
trace?.error(e)
return { text: undefined, error: serializeError(e) }
}
}

export async function OpenAISpeech(
req: CreateSpeechRequest,
cfg: LanguageModelConfiguration,
options: TraceOptions & CancellationOptions
): Promise<CreateSpeechResult> {
const { model, input, voice = "alloy", ...rest } = req
const { trace } = options || {}
const fetch = await createFetch(options)
try {
logVerbose(`${cfg.provider}: speak with ${cfg.model}`)
const url = `${cfg.base}/audio/speech`
trace.itemValue(`url`, `[${url}](${url})`)
const body = {
model,
input,
voice,
}
const freq = {
method: "POST",
headers: {
...getConfigHeaders(cfg),
"Content-Type": "application/json",
},
body: JSON.stringify(body),
}
traceFetchPost(trace, url, freq.headers, body)
// TODO: switch back to cross-fetch in the future
const res = await fetch(url, freq as any)
trace.itemValue(`status`, `${res.status} ${res.statusText}`)
if (!res.ok)
return { audio: undefined, error: (await res.json())?.error }
const j = await res.arrayBuffer()
return { audio: new Uint8Array(j) } satisfies CreateSpeechResult
} catch (e) {
logError(e)
trace?.error(e)
return {
audio: undefined,
error: serializeError(e),
} satisfies CreateSpeechResult
}
}

export function LocalOpenAICompatibleModel(
providerId: string,
options: { listModels?: boolean; transcribe?: boolean }
options: { listModels?: boolean; transcribe?: boolean; speech?: boolean }
) {
return Object.freeze<LanguageModel>(
deleteUndefinedValues({
completer: OpenAIChatCompletion,
id: providerId,
listModels: options?.listModels ? OpenAIListModels : undefined,
transcriber: options?.transcribe ? OpenAITranscribe : undefined,
speaker: options?.speech ? OpenAISpeech : undefined,
})
)
}
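
OpenAISpeech posts a JSON body to {base}/audio/speech and, unlike the transcription path, reads the success response as a raw ArrayBuffer rather than JSON, since the endpoint returns encoded audio bytes. A direct-call sketch, assuming cfg, trace, and cancellationToken were resolved by the host beforehand:

    import { OpenAISpeech } from "./openai"
    import { errorMessage } from "./error"

    // cfg: a resolved LanguageModelConfiguration (provider, base URL, credentials).
    const { audio, error } = await OpenAISpeech(
        { model: "tts-1", input: "Hello from GenAIScript", voice: "alloy" },
        cfg,
        { trace, cancellationToken }
    )
    if (error) throw new Error(errorMessage(error))
    // audio is a Uint8Array of encoded audio, ready to write to disk.
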
103 changes: 94 additions & 9 deletions packages/core/src/runpromptcontext.ts
@@ -30,7 +30,15 @@ import { GenerationOptions } from "./generation"
import { promptParametersSchemaToJSONSchema } from "./parameters"
import { consoleLogFormat, stdout } from "./logging"
import { isGlobMatch } from "./glob"
import { arrayify, assert, logError, logVerbose, logWarn } from "./util"
import {
arrayify,
assert,
deleteUndefinedValues,
dotGenaiscriptPath,
logError,
logVerbose,
logWarn,
} from "./util"
import { renderShellOutput } from "./chatrender"
import { jinjaRender } from "./jinja"
import { mustacheRender } from "./mustache"
@@ -39,6 +47,7 @@ import { delay, uniq } from "es-toolkit"
import {
addToolDefinitionsMessage,
appendSystemMessage,
CreateSpeechRequest,
executeChatSession,
mergeGenerationOptions,
toChatCompletionUserMessage,
@@ -55,6 +64,7 @@ import {
DOCS_DEF_FILES_IS_EMPTY_URL,
TRANSCRIPTION_MEMORY_CACHE_NAME,
TRANSCRIPTION_MODEL_ID,
SPEECH_MODEL_ID,
} from "./constants"
import { renderAICI } from "./aici"
import { resolveSystems, resolveTools } from "./systems"
@@ -82,6 +92,9 @@ import { BufferToBlob } from "./bufferlike"
import { host } from "./host"
import { srtVttRender } from "./transcription"
import { deleteEmptyValues } from "./clone"
import { hash } from "./crypto"
import { fileTypeFromBuffer } from "file-type"
import { writeFile } from "fs"

export function createChatTurnGenerationContext(
options: GenerationOptions,
@@ -660,15 +673,18 @@ export function createChatGenerationContext(
configuration.provider
)
if (!transcriber)
throw new Error("model driver not found for " + info.model)
throw new Error("audio transcribe not found for " + info.model)
const audioFile = await videoExtractAudio(audio, {
trace: transcriptionTrace,
})
const file = await BufferToBlob(await host.readFile(audioFile))
const update: () => Promise<TranscriptionResult> = async () => {
trace.itemValue(`model`, configuration.model)
trace.itemValue(`file size`, prettyBytes(file.size))
trace.itemValue(`file type`, file.type)
transcriptionTrace.itemValue(`model`, configuration.model)
transcriptionTrace.itemValue(
`file size`,
prettyBytes(file.size)
)
transcriptionTrace.itemValue(`file type`, file.type)
const res = await transcriber(
{
file,
@@ -703,12 +719,15 @@
update,
(res) => !res.error
)
trace.itemValue(`cache ${hit.cached ? "hit" : "miss"}`, hit.key)
transcriptionTrace.itemValue(
`cache ${hit.cached ? "hit" : "miss"}`,
hit.key
)
res = hit.value
} else res = await update()
trace.fence(res.text, "markdown")
if (res.error) trace.error(errorMessage(res.error))
if (res.segments) trace.fence(res.segments, "yaml")
transcriptionTrace.fence(res.text, "markdown")
if (res.error) transcriptionTrace.error(errorMessage(res.error))
if (res.segments) transcriptionTrace.fence(res.segments, "yaml")
return res
} catch (e) {
logError(e)
@@ -722,6 +741,71 @@
}
}

const speak = async (
input: string,
options?: SpeechOptions
): Promise<SpeechResult> => {
const { cache, voice, ...rest } = options || {}
const speechTrace = trace.startTraceDetails("🦜 speak")
try {
const conn: ModelConnectionOptions = {
model: options?.model || SPEECH_MODEL_ID,
}
const { info, configuration } = await resolveModelConnectionInfo(
conn,
{
trace: speechTrace,
cancellationToken,
token: true,
}
)
if (info.error) throw new Error(info.error)
if (!configuration) throw new Error("model configuration not found")
checkCancelled(cancellationToken)
const { ok } = await runtimeHost.pullModel(configuration, {
trace: speechTrace,
cancellationToken,
})
if (!ok) throw new Error(`failed to pull model ${conn.model}`)
checkCancelled(cancellationToken)
const { speaker } = await resolveLanguageModel(
configuration.provider
)
if (!speaker)
throw new Error("speech converter not found for " + info.model)
speechTrace.itemValue(`model`, configuration.model)
const req = deleteUndefinedValues({
input,
model: configuration.model,
voice,
}) satisfies CreateSpeechRequest
const res = await speaker(req, configuration, {
trace: speechTrace,
cancellationToken,
})
if (res.error) {
speechTrace.error(errorMessage(res.error))
return { error: res.error } satisfies SpeechResult
}
const h = await hash(res.audio, { length: 20 })
const { ext } = (await fileTypeFromBuffer(res.audio)) || {}
const filename = dotGenaiscriptPath("speech", h + "." + ext)
await host.writeFile(filename, res.audio)
return {
filename,
} satisfies SpeechResult
} catch (e) {
logError(e)
speechTrace.error(e)
return {
filename: undefined,
error: serializeError(e),
} satisfies SpeechResult
} finally {
speechTrace.endDetails()
}
}

const runPrompt = async (
generator: string | PromptGenerator,
runOptions?: PromptGeneratorOptions
@@ -968,6 +1052,7 @@
prompt,
runPrompt,
transcribe,
speak,
})

return ctx
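
Taken together, the PR surfaces a speak helper on the chat generation context: it resolves the "speech" model alias, pulls the model if needed, synthesizes audio, and writes the bytes to a content-hashed file under the .genaiscript/speech folder. A script-level usage sketch (the exact user-facing surface may differ slightly):

    // Inside a GenAIScript context; voice falls back to the provider
    // default ("alloy" for OpenAI) when omitted.
    const res = await speak("Welcome to the speech API.", { voice: "alloy" })
    if (res.error) throw new Error("speech synthesis failed")
    console.log(`audio written to ${res.filename}`)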