Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from chromadb.api.types import (
SparseEmbeddingFunction,
SparseEmbeddings,
Documents,
)
from typing import Dict, Any
from enum import Enum
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.sparse_embedding_utils import _sort_sparse_vectors
import os
from typing import Union


class ChromaCloudSpladeEmbeddingModel(Enum):
SPLADE_PP_EN_V1 = "prithivida/Splade_PP_en_v1"


class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]):
def __init__(
self,
api_key_env_var: str = "CHROMA_API_KEY",
model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
):
"""
Initialize the ChromaCloudSpladeEmbeddingFunction.

Args:
api_key_env_var (str, optional): Environment variable name that contains your API key.
Defaults to "CHROMA_API_KEY".
"""
try:
import httpx
except ImportError:
raise ValueError(
"The httpx python package is not installed. Please install it with `pip install httpx`"
)
self.api_key_env_var = api_key_env_var
self.api_key = os.getenv(self.api_key_env_var)
if not self.api_key:
raise ValueError(
f"API key not found in environment variable {self.api_key_env_var}"
)
self.model = model
self._api_url = "https://embed.trychroma.com/embed_sparse"
self._session = httpx.Client()
self._session.headers.update(
{
"x-chroma-token": self.api_key,
"x-chroma-embedding-model": self.model.value,
}
)

def __del__(self) -> None:
"""
Cleanup the HTTP client session when the object is destroyed.
"""
if hasattr(self, "_session"):
self._session.close()

def close(self) -> None:
"""
Explicitly close the HTTP client session.
Call this method when you're done using the embedding function.
"""
if hasattr(self, "_session"):
self._session.close()

def __call__(self, input: Documents) -> SparseEmbeddings:
"""
Generate embeddings for the given documents.

Args:
input (Documents): The documents to generate embeddings for.
"""
if not input:
return []

payload: Dict[str, Union[str, Documents]] = {
"texts": list(input),
"task": "",
"target": "",
}

try:
import httpx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BestPractice]

Redundant httpx import inside method. The httpx library is already imported and validated in the constructor (line 32). The import on line 85 inside the __call__ method is unnecessary and could impact performance with repeated calls.

Suggested change
import httpx
response = self._session.post(self._api_url, json=payload, timeout=60)

Note: Reusing the existing httpx.Client instance (self._session) is more efficient as it leverages connection pooling and avoids redundant imports.

Committable suggestion

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Context for Agents
[**BestPractice**]

Redundant httpx import inside method. The httpx library is already imported and validated in the constructor (line 32). The import on line 85 inside the `__call__` method is unnecessary and could impact performance with repeated calls.

```suggestion
response = self._session.post(self._api_url, json=payload, timeout=60)
```

*Note: Reusing the existing httpx.Client instance (self._session) is more efficient as it leverages connection pooling and avoids redundant imports.*

⚡ **Committable suggestion**

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

File: chromadb/utils/embedding_functions/chroma_cloud_splade_embedding_function.py
Line: 85


response = self._session.post(self._api_url, json=payload, timeout=60)
response.raise_for_status()
json_response = response.json()
return self._parse_response(json_response)
except httpx.HTTPStatusError as e:
raise RuntimeError(
f"Failed to get embeddings from Chroma Cloud API: HTTP {e.response.status_code} - {e.response.text}"
)
except httpx.TimeoutException:
raise RuntimeError("Request to Chroma Cloud API timed out after 60 seconds")
except httpx.HTTPError as e:
raise RuntimeError(f"Failed to get embeddings from Chroma Cloud API: {e}")
except Exception as e:
raise RuntimeError(f"Unexpected error calling Chroma Cloud API: {e}")

def _parse_response(self, response: Any) -> SparseEmbeddings:
"""
Parse the response from the Chroma Cloud Sparse Embedding API.
"""
embeddings: SparseEmbeddings = response["embeddings"]

# Ensure indices are sorted in ascending order
_sort_sparse_vectors(embeddings)

return embeddings

@staticmethod
def name() -> str:
return "chroma-cloud-splade"

@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "SparseEmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model = config.get("model")
if model is None:
raise ValueError("model must be provided in config")
if not api_key_env_var:
raise ValueError("api_key_env_var must be provided in config")
return ChromaCloudSpladeEmbeddingFunction(
api_key_env_var=api_key_env_var,
model=ChromaCloudSpladeEmbeddingModel(model),
)

def get_config(self) -> Dict[str, Any]:
return {"api_key_env_var": self.api_key_env_var, "model": self.model.value}

def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model" in new_config:
raise ValueError(
"model cannot be changed after the embedding function has been initialized"
)

@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
validate_config_schema(config, "chroma-cloud-splade")
5 changes: 3 additions & 2 deletions clients/new-js/packages/ai-embeddings/all/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@
"@chroma-core/openai": "workspace:^",
"@chroma-core/together-ai": "workspace:^",
"@chroma-core/voyageai": "workspace:^",
"@chroma-core/chroma-cloud-qwen": "workspace:^"
"@chroma-core/chroma-cloud-qwen": "workspace:^",
"@chroma-core/chroma-cloud-splade": "workspace:^"
},
"engines": {
"node": ">=20"
},
"publishConfig": {
"access": "public"
}
}
}
1 change: 1 addition & 0 deletions clients/new-js/packages/ai-embeddings/all/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ export * from "@chroma-core/openai";
export * from "@chroma-core/together-ai";
export * from "@chroma-core/voyageai";
export * from "@chroma-core/chroma-cloud-qwen";
export * from "@chroma-core/chroma-cloud-splade";
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Chroma Cloud Splade Embeddings

This package provides a sparse embedding function for the Splade model family hosted on Chroma's cloud embedding service. Splade (Sparse Lexical and Expansion) embeddings are particularly effective for information retrieval tasks, combining the benefits of sparse representations with learned relevance.

## Installation

```bash
npm install @chroma-core/chroma-cloud-splade
```

## Usage

```typescript
import { ChromaClient } from "chromadb";
import {
ChromaCloudSpladeEmbeddingFunction,
ChromaCloudSpladeEmbeddingModel,
} from "@chroma-core/chroma-cloud-splade";

// Initialize the embedder
const embedder = new ChromaCloudSpladeEmbeddingFunction({
model: ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
apiKeyEnvVar: "CHROMA_API_KEY",
});

## Configuration

Set your Chroma API key as an environment variable:

```bash
export CHROMA_API_KEY=your-api-key
```

Get your API key from [Chroma's dashboard](https://trychroma.com/).

## Configuration Options

- **model**: Model to use for sparse embeddings (default: `SPLADE_PP_EN_V1`)
- **apiKeyEnvVar**: Environment variable name for API key (default: `CHROMA_API_KEY`)

## Supported Models

- `prithivida/Splade_PP_en_v1` - Splade++ English v1 model optimized for information retrieval
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import type { Config } from "jest";

const config: Config = {
preset: "ts-jest",
testEnvironment: "node",
testMatch: ["**/*.test.ts"],
transform: {
"^.+\\.tsx?$": [
"ts-jest",
{
useESM: true,
},
],
},
extensionsToTreatAsEsm: [".ts"],
moduleNameMapper: {
"^(\\.{1,2}/.*)\\.js$": "$1",
},
setupFiles: ["./jest.setup.ts"],
};

export default config;
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import * as dotenv from "dotenv";

dotenv.config({ path: "../../../.env" });
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{
"name": "@chroma-core/chroma-cloud-splade",
"version": "0.1.7",
"private": false,
"description": "Chroma Cloud Splade sparse embedding function",
"main": "dist/cjs/chroma-cloud-splade.cjs",
"types": "dist/chroma-cloud-splade.d.ts",
"module": "dist/chroma-cloud-splade.legacy-esm.js",
"type": "module",
"exports": {
".": {
"import": {
"types": "./dist/chroma-cloud-splade.d.ts",
"default": "./dist/chroma-cloud-splade.mjs"
},
"require": {
"types": "./dist/cjs/chroma-cloud-splade.d.cts",
"default": "./dist/cjs/chroma-cloud-splade.cjs"
}
}
},
"files": [
"src",
"dist"
],
"scripts": {
"clean": "rimraf dist",
"prebuild": "rimraf dist",
"build": "tsup",
"watch": "tsup --watch",
"test": "jest"
},
"devDependencies": {
"@jest/globals": "^29.7.0",
"dotenv": "^16.3.1",
"jest": "^29.7.0",
"rimraf": "^5.0.0",
"ts-jest": "^29.1.2",
"ts-node": "^10.9.2",
"tsup": "^8.3.5"
},
"peerDependencies": {
"chromadb": "workspace:^"
},
"dependencies": {
"@chroma-core/ai-embeddings-common": "workspace:^"
},
"engines": {
"node": ">=20"
},
"publishConfig": {
"access": "public"
}
}
Loading
Loading