diff --git a/clients/new-js/packages/ai-embeddings/openai/src/index.ts b/clients/new-js/packages/ai-embeddings/openai/src/index.ts index 6c76109de73..0530a8ac7d6 100644 --- a/clients/new-js/packages/ai-embeddings/openai/src/index.ts +++ b/clients/new-js/packages/ai-embeddings/openai/src/index.ts @@ -5,6 +5,7 @@ import { registerEmbeddingFunction, } from "chromadb"; import OpenAI from "openai"; +import type { ClientOptions } from 'openai'; import { validateConfigSchema } from "@chroma-core/ai-embeddings-common"; const NAME = "openai"; @@ -14,6 +15,7 @@ export interface OpenAIConfig { model_name: string; organization_id?: string; dimensions?: number; + api_base?: string; } export interface OpenAIArgs { @@ -22,6 +24,7 @@ export interface OpenAIArgs { organizationId?: string; dimensions?: number; apiKey?: string; + apiBase?: string; } export class OpenAIEmbeddingFunction implements EmbeddingFunction { @@ -30,6 +33,7 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction { private readonly modelName: string; private readonly dimensions: number | undefined; private readonly organizationId: string | undefined; + private readonly apiBase: string | undefined; private client: OpenAI; constructor(args: OpenAIArgs) { @@ -38,6 +42,7 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction { modelName, dimensions, organizationId, + apiBase, } = args; const apiKey = args.apiKey || process.env[apiKeyEnvVar]; @@ -51,8 +56,11 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction { this.organizationId = organizationId; this.apiKeyEnvVar = apiKeyEnvVar; this.dimensions = dimensions; + this.apiBase = apiBase; - this.client = new OpenAI({ apiKey, organization: this.organizationId }); + let clientConfig: ClientOptions = { apiKey, organization: this.organizationId }; + if (this.apiBase) clientConfig.baseURL = this.apiBase; + this.client = new OpenAI(clientConfig); } public async generate(texts: string[]): Promise { @@ -78,6 +86,7 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction { modelName: config.model_name, organizationId: config.organization_id, dimensions: config.dimensions, + apiBase: config.api_base, }); } @@ -87,6 +96,7 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction { model_name: this.modelName, organization_id: this.organizationId, dimensions: this.dimensions, + api_base: this.apiBase, }; } diff --git a/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/openai.md b/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/openai.md index 4a74ee0956a..4efb986690b 100644 --- a/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/openai.md +++ b/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/openai.md @@ -76,6 +76,18 @@ collection = await client.getCollection({ }); ``` +To use the OpenAI embedding models on other platforms such as GitHub AI Models, you can use the `apiBase` parameter: +```typescript +import { OpenAIEmbeddingFunction } from "@chroma-core/openai"; + +const embeddingFunction = new OpenAIEmbeddingFunction({ + apiKey: "apiKey", + apiBase: "https://models.github.ai/inference", + modelName: "text-embedding-3-small", +}); +``` + + {% /Tab %} {% /Tabs %}