-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[ENH] Add hosted splade embedding function to python and js #5610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,145 @@ | ||||||
from chromadb.api.types import ( | ||||||
SparseEmbeddingFunction, | ||||||
SparseEmbeddings, | ||||||
Documents, | ||||||
) | ||||||
from typing import Dict, Any | ||||||
from enum import Enum | ||||||
from chromadb.utils.embedding_functions.schemas import validate_config_schema | ||||||
from chromadb.utils.sparse_embedding_utils import _sort_sparse_vectors | ||||||
import os | ||||||
from typing import Union | ||||||
|
||||||
|
||||||
class ChromaCloudSpladeEmbeddingModel(Enum): | ||||||
SPLADE_PP_EN_V1 = "prithivida/Splade_PP_en_v1" | ||||||
|
||||||
|
||||||
class ChromaCloudSpladeEmbeddingFunction(SparseEmbeddingFunction[Documents]): | ||||||
def __init__( | ||||||
self, | ||||||
api_key_env_var: str = "CHROMA_API_KEY", | ||||||
model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1, | ||||||
): | ||||||
""" | ||||||
Initialize the ChromaCloudSpladeEmbeddingFunction. | ||||||
|
||||||
Args: | ||||||
api_key_env_var (str, optional): Environment variable name that contains your API key. | ||||||
Defaults to "CHROMA_API_KEY". | ||||||
""" | ||||||
try: | ||||||
import httpx | ||||||
except ImportError: | ||||||
raise ValueError( | ||||||
"The httpx python package is not installed. Please install it with `pip install httpx`" | ||||||
) | ||||||
self.api_key_env_var = api_key_env_var | ||||||
self.api_key = os.getenv(self.api_key_env_var) | ||||||
if not self.api_key: | ||||||
raise ValueError( | ||||||
f"API key not found in environment variable {self.api_key_env_var}" | ||||||
) | ||||||
self.model = model | ||||||
self._api_url = "https://embed.trychroma.com/embed_sparse" | ||||||
jairad26 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
self._session = httpx.Client() | ||||||
self._session.headers.update( | ||||||
{ | ||||||
"x-chroma-token": self.api_key, | ||||||
"x-chroma-embedding-model": self.model.value, | ||||||
} | ||||||
) | ||||||
|
||||||
def __del__(self) -> None: | ||||||
""" | ||||||
Cleanup the HTTP client session when the object is destroyed. | ||||||
""" | ||||||
if hasattr(self, "_session"): | ||||||
self._session.close() | ||||||
|
||||||
def close(self) -> None: | ||||||
""" | ||||||
Explicitly close the HTTP client session. | ||||||
Call this method when you're done using the embedding function. | ||||||
""" | ||||||
if hasattr(self, "_session"): | ||||||
self._session.close() | ||||||
|
||||||
def __call__(self, input: Documents) -> SparseEmbeddings: | ||||||
""" | ||||||
Generate embeddings for the given documents. | ||||||
|
||||||
Args: | ||||||
input (Documents): The documents to generate embeddings for. | ||||||
""" | ||||||
if not input: | ||||||
return [] | ||||||
|
||||||
payload: Dict[str, Union[str, Documents]] = { | ||||||
"texts": list(input), | ||||||
"task": "", | ||||||
"target": "", | ||||||
} | ||||||
|
||||||
jairad26 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
try: | ||||||
import httpx | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [BestPractice] Redundant httpx import inside method. The httpx library is already imported and validated in the constructor (line 32). The import on line 85 inside the
Suggested change
Note: Reusing the existing httpx.Client instance (self._session) is more efficient as it leverages connection pooling and avoids redundant imports. ⚡ Committable suggestion Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Context for Agents
|
||||||
|
||||||
response = self._session.post(self._api_url, json=payload, timeout=60) | ||||||
response.raise_for_status() | ||||||
propel-code-bot[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
json_response = response.json() | ||||||
return self._parse_response(json_response) | ||||||
except httpx.HTTPStatusError as e: | ||||||
raise RuntimeError( | ||||||
f"Failed to get embeddings from Chroma Cloud API: HTTP {e.response.status_code} - {e.response.text}" | ||||||
) | ||||||
except httpx.TimeoutException: | ||||||
raise RuntimeError("Request to Chroma Cloud API timed out after 60 seconds") | ||||||
except httpx.HTTPError as e: | ||||||
raise RuntimeError(f"Failed to get embeddings from Chroma Cloud API: {e}") | ||||||
except Exception as e: | ||||||
raise RuntimeError(f"Unexpected error calling Chroma Cloud API: {e}") | ||||||
|
||||||
def _parse_response(self, response: Any) -> SparseEmbeddings: | ||||||
""" | ||||||
Parse the response from the Chroma Cloud Sparse Embedding API. | ||||||
""" | ||||||
jairad26 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
embeddings: SparseEmbeddings = response["embeddings"] | ||||||
|
||||||
# Ensure indices are sorted in ascending order | ||||||
_sort_sparse_vectors(embeddings) | ||||||
|
||||||
return embeddings | ||||||
|
||||||
@staticmethod | ||||||
def name() -> str: | ||||||
return "chroma-cloud-splade" | ||||||
|
||||||
@staticmethod | ||||||
def build_from_config( | ||||||
config: Dict[str, Any] | ||||||
) -> "SparseEmbeddingFunction[Documents]": | ||||||
api_key_env_var = config.get("api_key_env_var") | ||||||
model = config.get("model") | ||||||
if model is None: | ||||||
raise ValueError("model must be provided in config") | ||||||
if not api_key_env_var: | ||||||
raise ValueError("api_key_env_var must be provided in config") | ||||||
return ChromaCloudSpladeEmbeddingFunction( | ||||||
api_key_env_var=api_key_env_var, | ||||||
model=ChromaCloudSpladeEmbeddingModel(model), | ||||||
) | ||||||
|
||||||
def get_config(self) -> Dict[str, Any]: | ||||||
return {"api_key_env_var": self.api_key_env_var, "model": self.model.value} | ||||||
|
||||||
def validate_config_update( | ||||||
self, old_config: Dict[str, Any], new_config: Dict[str, Any] | ||||||
) -> None: | ||||||
if "model" in new_config: | ||||||
raise ValueError( | ||||||
"model cannot be changed after the embedding function has been initialized" | ||||||
) | ||||||
|
||||||
@staticmethod | ||||||
def validate_config(config: Dict[str, Any]) -> None: | ||||||
validate_config_schema(config, "chroma-cloud-splade") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Chroma Cloud Splade Embeddings | ||
|
||
This package provides a sparse embedding function for the Splade model family hosted on Chroma's cloud embedding service. Splade (Sparse Lexical and Expansion) embeddings are particularly effective for information retrieval tasks, combining the benefits of sparse representations with learned relevance. | ||
|
||
## Installation | ||
|
||
```bash | ||
npm install @chroma-core/chroma-cloud-splade | ||
``` | ||
|
||
## Usage | ||
|
||
```typescript | ||
import { ChromaClient } from "chromadb"; | ||
import { | ||
ChromaCloudSpladeEmbeddingFunction, | ||
ChromaCloudSpladeEmbeddingModel, | ||
} from "@chroma-core/chroma-cloud-splade"; | ||
|
||
// Initialize the embedder | ||
const embedder = new ChromaCloudSpladeEmbeddingFunction({ | ||
model: ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1, | ||
apiKeyEnvVar: "CHROMA_API_KEY", | ||
}); | ||
|
||
## Configuration | ||
|
||
Set your Chroma API key as an environment variable: | ||
|
||
```bash | ||
export CHROMA_API_KEY=your-api-key | ||
``` | ||
|
||
Get your API key from [Chroma's dashboard](https://trychroma.com/). | ||
|
||
## Configuration Options | ||
|
||
- **model**: Model to use for sparse embeddings (default: `SPLADE_PP_EN_V1`) | ||
- **apiKeyEnvVar**: Environment variable name for API key (default: `CHROMA_API_KEY`) | ||
|
||
## Supported Models | ||
|
||
- `prithivida/Splade_PP_en_v1` - Splade++ English v1 model optimized for information retrieval |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import type { Config } from "jest"; | ||
|
||
const config: Config = { | ||
preset: "ts-jest", | ||
testEnvironment: "node", | ||
testMatch: ["**/*.test.ts"], | ||
transform: { | ||
"^.+\\.tsx?$": [ | ||
"ts-jest", | ||
{ | ||
useESM: true, | ||
}, | ||
], | ||
}, | ||
extensionsToTreatAsEsm: [".ts"], | ||
moduleNameMapper: { | ||
"^(\\.{1,2}/.*)\\.js$": "$1", | ||
}, | ||
setupFiles: ["./jest.setup.ts"], | ||
}; | ||
|
||
export default config; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import * as dotenv from "dotenv"; | ||
|
||
dotenv.config({ path: "../../../.env" }); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
{ | ||
"name": "@chroma-core/chroma-cloud-splade", | ||
"version": "0.1.7", | ||
"private": false, | ||
"description": "Chroma Cloud Splade sparse embedding function", | ||
"main": "dist/cjs/chroma-cloud-splade.cjs", | ||
"types": "dist/chroma-cloud-splade.d.ts", | ||
"module": "dist/chroma-cloud-splade.legacy-esm.js", | ||
"type": "module", | ||
"exports": { | ||
".": { | ||
"import": { | ||
"types": "./dist/chroma-cloud-splade.d.ts", | ||
"default": "./dist/chroma-cloud-splade.mjs" | ||
}, | ||
"require": { | ||
"types": "./dist/cjs/chroma-cloud-splade.d.cts", | ||
"default": "./dist/cjs/chroma-cloud-splade.cjs" | ||
} | ||
} | ||
}, | ||
"files": [ | ||
"src", | ||
"dist" | ||
], | ||
"scripts": { | ||
"clean": "rimraf dist", | ||
"prebuild": "rimraf dist", | ||
"build": "tsup", | ||
"watch": "tsup --watch", | ||
"test": "jest" | ||
}, | ||
"devDependencies": { | ||
"@jest/globals": "^29.7.0", | ||
"dotenv": "^16.3.1", | ||
"jest": "^29.7.0", | ||
"rimraf": "^5.0.0", | ||
"ts-jest": "^29.1.2", | ||
"ts-node": "^10.9.2", | ||
"tsup": "^8.3.5" | ||
}, | ||
"peerDependencies": { | ||
"chromadb": "workspace:^" | ||
}, | ||
"dependencies": { | ||
"@chroma-core/ai-embeddings-common": "workspace:^" | ||
}, | ||
"engines": { | ||
"node": ">=20" | ||
}, | ||
"publishConfig": { | ||
"access": "public" | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.