speak api (#1001)
* speak api

* feat: ✨ add speech synthesis support to language models

* feat: ✨ add speech synthesis support with new APIs
pelikhan authored Jan 14, 2025
1 parent 368fde7 commit 195df26
Showing 10 changed files with 232 additions and 19 deletions.
4 changes: 4 additions & 0 deletions docs/public/schemas/llms.json
@@ -60,6 +60,10 @@
            "type": "boolean",
            "description": "Indicates if speech transcription is supported"
        },
        "speech": {
            "type": "boolean",
            "description": "Indicates if speech synthesis is supported"
        },
        "openaiCompatibility": {
            "type": "string",
            "description": "Uses OpenAI API compatibility layer documentation URL"
18 changes: 18 additions & 0 deletions packages/core/src/chat.ts
@@ -152,12 +152,30 @@ export type TranscribeFunction = (
    options: TraceOptions & CancellationOptions
) => Promise<TranscriptionResult>

export type CreateSpeechRequest = {
    input: string
    model: string
    voice?: string
}

export type CreateSpeechResult = {
    audio: Uint8Array
    error?: SerializedError
}

export type SpeechFunction = (
    req: CreateSpeechRequest,
    cfg: LanguageModelConfiguration,
    options: TraceOptions & CancellationOptions
) => Promise<CreateSpeechResult>

export interface LanguageModel {
    id: string
    completer: ChatCompletionHandler
    listModels?: ListModelsFunction
    pullModel?: PullModelFunction
    transcriber?: TranscribeFunction
    speaker?: SpeechFunction
}

async function runToolCalls(
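
For reference, any provider driver can now advertise speech support by filling the new optional `speaker` slot with a function matching `SpeechFunction`. A minimal sketch of a conforming driver — the provider and the no-op payload below are hypothetical, not part of this commit:

```ts
import {
    CreateSpeechRequest,
    CreateSpeechResult,
    SpeechFunction,
} from "./chat"
import { serializeError } from "./error"

// Hypothetical speaker: type-checks against SpeechFunction but returns
// an empty audio buffer instead of calling a real TTS endpoint.
const noopSpeaker: SpeechFunction = async (
    req: CreateSpeechRequest,
    cfg,
    options
): Promise<CreateSpeechResult> => {
    try {
        // a real driver would POST { model, input, voice } to cfg.base here
        return { audio: new Uint8Array(0) }
    } catch (e) {
        return { audio: undefined, error: serializeError(e) }
    }
}
// A LanguageModel then registers it as `speaker: noopSpeaker`.
```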
2 changes: 2 additions & 0 deletions packages/core/src/constants.ts
@@ -61,6 +61,7 @@ export const SMALL_MODEL_ID = "small"
export const LARGE_MODEL_ID = "large"
export const VISION_MODEL_ID = "vision"
export const TRANSCRIPTION_MODEL_ID = "transcription"
export const SPEECH_MODEL_ID = "speech"
export const DEFAULT_FENCE_FORMAT: FenceFormat = "xml"
export const DEFAULT_TEMPERATURE = 0.8
export const BUILTIN_PREFIX = "_builtin/"
@@ -200,6 +201,7 @@ export const MODEL_PROVIDERS = Object.freeze<
        bearerToken?: boolean
        listModels?: boolean
        transcribe?: boolean
        speech?: boolean
        aliases?: Record<string, string>
    }[]
>(CONFIGURATION_DATA.providers)
8 changes: 6 additions & 2 deletions packages/core/src/llms.json
@@ -6,22 +6,26 @@
        "detail": "OpenAI (or compatible)",
        "bearerToken": true,
        "transcribe": true,
        "speech": true,
        "aliases": {
            "large": "gpt-4o",
            "small": "gpt-4o-mini",
            "vision": "gpt-4o",
            "embeddings": "text-embedding-3-small",
            "reasoning": "o1",
            "reasoning_small": "o1-mini",
-           "transcription": "whisper-1"
+           "transcription": "whisper-1",
+           "speech": "tts-1"
        }
    },
    {
        "id": "azure",
        "detail": "Azure OpenAI deployment",
        "listModels": false,
        "bearerToken": false,
-       "prediction": false
+       "prediction": false,
+       "transcribe": true,
+       "speech": true
    },
    {
        "id": "azure_serverless",
2 changes: 1 addition & 1 deletion packages/core/src/lm.ts
@@ -1,4 +1,3 @@
-import { Transform } from "stream"
import { AICIModel } from "./aici"
import { AnthropicBedrockModel, AnthropicModel } from "./anthropic"
import { LanguageModel } from "./chat"
@@ -39,5 +38,6 @@ export function resolveLanguageModel(provider: string): LanguageModel {
    return LocalOpenAICompatibleModel(provider, {
        listModels: features?.listModels !== false,
        transcribe: features?.transcribe,
        speech: features?.speech,
    })
}
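
Since `resolveLanguageModel` simply forwards the per-provider feature flags from `llms.json`, a caller can probe for speech support before invoking it. A small sketch of that pattern, assuming the shapes shown in this diff (the provider id is illustrative):

```ts
import { resolveLanguageModel } from "./lm"

// resolve the driver for a provider and check the optional speech slot
const lm = resolveLanguageModel("openai")
if (!lm.speaker)
    throw new Error(`provider ${lm.id} does not support speech synthesis`)
// lm.speaker(req, configuration, options) may now be invoked safely
```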
58 changes: 52 additions & 6 deletions packages/core/src/openai.ts
@@ -20,12 +20,13 @@ import {
import { estimateTokens } from "./tokens"
import {
    ChatCompletionHandler,
    CreateSpeechRequest,
    CreateSpeechResult,
    CreateTranscriptionRequest,
    LanguageModel,
    PullModelFunction,
} from "./chat"
import { RequestError, errorMessage, serializeError } from "./error"
-import { createFetch, iterateBody, traceFetchPost } from "./fetch"
+import { createFetch, traceFetchPost } from "./fetch"
import { parseModelIdentifier } from "./models"
import { JSON5TryParse } from "./json5"
import {
@@ -435,7 +436,6 @@ export async function OpenAITranscribe(
    options: TraceOptions & CancellationOptions
): Promise<TranscriptionResult> {
    const { trace } = options || {}
-   const fetch = await createFetch(options)
    try {
        logVerbose(`${cfg.provider}: transcribe with ${cfg.model}`)
        const route = req.translate ? "translations" : "transcriptions"
@@ -453,33 +453,79 @@
            method: "POST",
            headers: {
                ...getConfigHeaders(cfg),
-               ContentType: "multipart/form-data",
+               "Content-Type": "multipart/form-data",
                Accept: "application/json",
            },
            body: body,
        }
        traceFetchPost(trace, url, freq.headers, freq.body)
        // TODO: switch back to cross-fetch in the future
        const res = await global.fetch(url, freq as any)
        trace.itemValue(`status`, `${res.status} ${res.statusText}`)
        const j = await res.json()
-       return j
+       if (!res.ok) return { text: undefined, error: j?.error }
+       else return j
    } catch (e) {
        logError(e)
        trace?.error(e)
        return { text: undefined, error: serializeError(e) }
    }
}

export async function OpenAISpeech(
    req: CreateSpeechRequest,
    cfg: LanguageModelConfiguration,
    options: TraceOptions & CancellationOptions
): Promise<CreateSpeechResult> {
    const { model, input, voice = "alloy", ...rest } = req
    const { trace } = options || {}
    const fetch = await createFetch(options)
    try {
        logVerbose(`${cfg.provider}: speak with ${cfg.model}`)
        const url = `${cfg.base}/audio/speech`
        trace.itemValue(`url`, `[${url}](${url})`)
        const body = {
            model,
            input,
            voice,
        }
        const freq = {
            method: "POST",
            headers: {
                ...getConfigHeaders(cfg),
                "Content-Type": "application/json",
            },
            body: JSON.stringify(body),
        }
        traceFetchPost(trace, url, freq.headers, body)
        // TODO: switch back to cross-fetch in the future
        const res = await fetch(url, freq as any)
        trace.itemValue(`status`, `${res.status} ${res.statusText}`)
        if (!res.ok)
            return { audio: undefined, error: (await res.json())?.error }
        const j = await res.arrayBuffer()
        return { audio: new Uint8Array(j) } satisfies CreateSpeechResult
    } catch (e) {
        logError(e)
        trace?.error(e)
        return {
            audio: undefined,
            error: serializeError(e),
        } satisfies CreateSpeechResult
    }
}

export function LocalOpenAICompatibleModel(
    providerId: string,
-   options: { listModels?: boolean; transcribe?: boolean }
+   options: { listModels?: boolean; transcribe?: boolean; speech?: boolean }
) {
    return Object.freeze<LanguageModel>(
        deleteUndefinedValues({
            completer: OpenAIChatCompletion,
            id: providerId,
            listModels: options?.listModels ? OpenAIListModels : undefined,
            transcriber: options?.transcribe ? OpenAITranscribe : undefined,
            speaker: options?.speech ? OpenAISpeech : undefined,
        })
    )
}
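
Stripped of tracing and config plumbing, the request `OpenAISpeech` issues is a single JSON POST to the OpenAI `/audio/speech` endpoint, whose success response body is the raw encoded audio rather than JSON. A standalone sketch of the same exchange (model and voice values are the commit's defaults; error handling trimmed):

```ts
// Minimal reproduction of the /audio/speech call used by OpenAISpeech.
async function synthesizeSpeech(
    apiKey: string,
    input: string
): Promise<Uint8Array> {
    const res = await fetch("https://api.openai.com/v1/audio/speech", {
        method: "POST",
        headers: {
            Authorization: `Bearer ${apiKey}`,
            "Content-Type": "application/json",
        },
        body: JSON.stringify({ model: "tts-1", input, voice: "alloy" }),
    })
    if (!res.ok) throw new Error(`speech request failed: ${res.status}`)
    // success responses carry the audio bytes directly in the body
    return new Uint8Array(await res.arrayBuffer())
}
```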
103 changes: 94 additions & 9 deletions packages/core/src/runpromptcontext.ts
@@ -30,7 +30,15 @@ import { GenerationOptions } from "./generation"
import { promptParametersSchemaToJSONSchema } from "./parameters"
import { consoleLogFormat, stdout } from "./logging"
import { isGlobMatch } from "./glob"
-import { arrayify, assert, logError, logVerbose, logWarn } from "./util"
+import {
+    arrayify,
+    assert,
+    deleteUndefinedValues,
+    dotGenaiscriptPath,
+    logError,
+    logVerbose,
+    logWarn,
+} from "./util"
import { renderShellOutput } from "./chatrender"
import { jinjaRender } from "./jinja"
import { mustacheRender } from "./mustache"
@@ -39,6 +47,7 @@ import { delay, uniq } from "es-toolkit"
import {
    addToolDefinitionsMessage,
    appendSystemMessage,
    CreateSpeechRequest,
    executeChatSession,
    mergeGenerationOptions,
    toChatCompletionUserMessage,
@@ -55,6 +64,7 @@ import {
    DOCS_DEF_FILES_IS_EMPTY_URL,
    TRANSCRIPTION_MEMORY_CACHE_NAME,
    TRANSCRIPTION_MODEL_ID,
    SPEECH_MODEL_ID,
} from "./constants"
import { renderAICI } from "./aici"
import { resolveSystems, resolveTools } from "./systems"
@@ -82,6 +92,9 @@ import { BufferToBlob } from "./bufferlike"
import { host } from "./host"
import { srtVttRender } from "./transcription"
import { deleteEmptyValues } from "./clone"
import { hash } from "./crypto"
import { fileTypeFromBuffer } from "file-type"
import { writeFile } from "fs"

export function createChatTurnGenerationContext(
    options: GenerationOptions,
@@ -660,15 +673,18 @@ export function createChatGenerationContext(
                configuration.provider
            )
            if (!transcriber)
-               throw new Error("model driver not found for " + info.model)
+               throw new Error("audio transcribe not found for " + info.model)
            const audioFile = await videoExtractAudio(audio, {
                trace: transcriptionTrace,
            })
            const file = await BufferToBlob(await host.readFile(audioFile))
            const update: () => Promise<TranscriptionResult> = async () => {
-               trace.itemValue(`model`, configuration.model)
-               trace.itemValue(`file size`, prettyBytes(file.size))
-               trace.itemValue(`file type`, file.type)
+               transcriptionTrace.itemValue(`model`, configuration.model)
+               transcriptionTrace.itemValue(
+                   `file size`,
+                   prettyBytes(file.size)
+               )
+               transcriptionTrace.itemValue(`file type`, file.type)
                const res = await transcriber(
                    {
                        file,
@@ -703,12 +719,15 @@
                    update,
                    (res) => !res.error
                )
-               trace.itemValue(`cache ${hit.cached ? "hit" : "miss"}`, hit.key)
+               transcriptionTrace.itemValue(
+                   `cache ${hit.cached ? "hit" : "miss"}`,
+                   hit.key
+               )
                res = hit.value
            } else res = await update()
-           trace.fence(res.text, "markdown")
-           if (res.error) trace.error(errorMessage(res.error))
-           if (res.segments) trace.fence(res.segments, "yaml")
+           transcriptionTrace.fence(res.text, "markdown")
+           if (res.error) transcriptionTrace.error(errorMessage(res.error))
+           if (res.segments) transcriptionTrace.fence(res.segments, "yaml")
            return res
        } catch (e) {
            logError(e)
@@ -722,6 +741,71 @@
            }
        }

    const speak = async (
        input: string,
        options?: SpeechOptions
    ): Promise<SpeechResult> => {
        const { cache, voice, ...rest } = options || {}
        const speechTrace = trace.startTraceDetails("🦜 speak")
        try {
            const conn: ModelConnectionOptions = {
                model: options?.model || SPEECH_MODEL_ID,
            }
            const { info, configuration } = await resolveModelConnectionInfo(
                conn,
                {
                    trace: speechTrace,
                    cancellationToken,
                    token: true,
                }
            )
            if (info.error) throw new Error(info.error)
            if (!configuration) throw new Error("model configuration not found")
            checkCancelled(cancellationToken)
            const { ok } = await runtimeHost.pullModel(configuration, {
                trace: speechTrace,
                cancellationToken,
            })
            if (!ok) throw new Error(`failed to pull model ${conn}`)
            checkCancelled(cancellationToken)
            const { speaker } = await resolveLanguageModel(
                configuration.provider
            )
            if (!speaker)
                throw new Error("speech converter not found for " + info.model)
            speechTrace.itemValue(`model`, configuration.model)
            const req = deleteUndefinedValues({
                input,
                model: configuration.model,
                voice,
            }) satisfies CreateSpeechRequest
            const res = await speaker(req, configuration, {
                trace: speechTrace,
                cancellationToken,
            })
            if (res.error) {
                speechTrace.error(errorMessage(res.error))
                return { error: res.error } satisfies SpeechResult
            }
            const h = await hash(res.audio, { length: 20 })
            const { ext } = (await fileTypeFromBuffer(res.audio)) || {}
            const filename = dotGenaiscriptPath("speech", h + "." + ext)
            await host.writeFile(filename, res.audio)
            return {
                filename,
            } satisfies SpeechResult
        } catch (e) {
            logError(e)
            speechTrace.error(e)
            return {
                filename: undefined,
                error: serializeError(e),
            } satisfies SpeechResult
        } finally {
            speechTrace.endDetails()
        }
    }

    const runPrompt = async (
        generator: string | PromptGenerator,
        runOptions?: PromptGeneratorOptions
@@ -968,6 +1052,7 @@
        prompt,
        runPrompt,
        transcribe,
        speak,
    })

    return ctx
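
End to end, the new `speak` context method mirrors `transcribe`: it resolves the `speech` model alias, synthesizes audio through the provider's `speaker`, and writes the bytes to a content-addressed file under `.genaiscript/speech/`. A hypothetical script-level usage sketch (option names follow the `SpeechOptions` shape used above):

```ts
// inside a genaiscript prompt script (sketch)
const res = await speak("Hello from the new speech API!", {
    voice: "alloy", // forwarded to CreateSpeechRequest.voice
})
if (res.error) throw new Error("speech synthesis failed")
console.log(`audio written to ${res.filename}`) // .genaiscript/speech/<hash>.<ext>
```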