-
Notifications
You must be signed in to change notification settings - Fork 1.6k
VoyageAI embeddings support #3856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
78c4473
50b1f64
a166cb7
aeb5ccc
88ab61b
bcdd6f2
324e3bd
f079fa9
8526d4b
87c7b72
093eaa5
b283b0b
edde0de
95e2f45
31d5e4a
7820d20
c35fa2c
53e08af
0171011
ca963f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,13 @@ | |
| 'cohere:embed-english-light-v3.0', | ||
| 'cohere:embed-multilingual-v3.0', | ||
| 'cohere:embed-multilingual-light-v3.0', | ||
| 'voyageai:voyage-3-large', | ||
| 'voyageai:voyage-3.5', | ||
| 'voyageai:voyage-3.5-lite', | ||
| 'voyageai:voyage-code-3', | ||
| 'voyageai:voyage-finance-2', | ||
| 'voyageai:voyage-law-2', | ||
| 'voyageai:voyage-code-2', | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ], | ||
| ) | ||
| """Known model names that can be used with the `model` parameter of [`Embedder`][pydantic_ai.embeddings.Embedder]. | ||
|
|
@@ -70,14 +77,21 @@ def infer_embedding_model( | |
| except ValueError as e: | ||
| raise ValueError('You must provide a provider prefix when specifying an embedding model name') from e | ||
|
|
||
| provider = provider_factory(provider_name) | ||
|
|
||
| model_kind = provider_name | ||
| if model_kind.startswith('gateway/'): | ||
| from ..providers.gateway import normalize_gateway_provider | ||
|
|
||
| model_kind = normalize_gateway_provider(model_kind) | ||
|
|
||
| # Handle models that don't need a provider first | ||
|
||
| if model_kind == 'sentence-transformers': | ||
| from .sentence_transformers import SentenceTransformerEmbeddingModel | ||
|
|
||
| return SentenceTransformerEmbeddingModel(model_name) | ||
|
|
||
| # For provider-based models, infer the provider | ||
| provider = provider_factory(provider_name) | ||
|
|
||
| if model_kind in ( | ||
| 'openai', | ||
| # For now, we assume that every chat and completions-compatible provider also | ||
|
|
@@ -92,10 +106,10 @@ def infer_embedding_model( | |
| from .cohere import CohereEmbeddingModel | ||
|
|
||
| return CohereEmbeddingModel(model_name, provider=provider) | ||
| elif model_kind == 'sentence-transformers': | ||
| from .sentence_transformers import SentenceTransformerEmbeddingModel | ||
| elif model_kind == 'voyageai': | ||
| from .voyageai import VoyageAIEmbeddingModel | ||
|
|
||
| return SentenceTransformerEmbeddingModel(model_name) | ||
| return VoyageAIEmbeddingModel(model_name, provider=provider) | ||
| else: | ||
| raise UserError(f'Unknown embeddings model: {model}') # pragma: no cover | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Sequence | ||
| from dataclasses import dataclass, field | ||
| from typing import Literal, cast | ||
|
|
||
| from pydantic_ai.exceptions import ModelAPIError | ||
| from pydantic_ai.providers import Provider, infer_provider | ||
| from pydantic_ai.usage import RequestUsage | ||
|
|
||
| from .base import EmbeddingModel, EmbedInputType | ||
| from .result import EmbeddingResult | ||
| from .settings import EmbeddingSettings | ||
|
|
||
| try: | ||
| from voyageai.client_async import AsyncClient | ||
| from voyageai.error import VoyageError | ||
| except ImportError as _import_error: | ||
| raise ImportError( | ||
| 'Please install `voyageai` to use the VoyageAI embeddings model, ' | ||
| 'you can use the `voyageai` optional group — `pip install "pydantic-ai-slim[voyageai]"`' | ||
| ) from _import_error | ||
|
|
||
| LatestVoyageAIEmbeddingModelNames = Literal[ | ||
| 'voyage-3-large', | ||
| 'voyage-3.5', | ||
| 'voyage-3.5-lite', | ||
| 'voyage-code-3', | ||
| 'voyage-finance-2', | ||
| 'voyage-law-2', | ||
| 'voyage-code-2', | ||
| ] | ||
| """Latest VoyageAI embedding models. | ||
| See [VoyageAI Embeddings](https://docs.voyageai.com/docs/embeddings) | ||
| for available models and their capabilities. | ||
| """ | ||
|
|
||
| VoyageAIEmbeddingModelName = str | LatestVoyageAIEmbeddingModelNames | ||
| """Possible VoyageAI embedding model names.""" | ||
|
|
||
|
|
||
class VoyageAIEmbeddingSettings(EmbeddingSettings, total=False):
    """Per-request settings understood by the VoyageAI embedding model.

    Every key of [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings] is accepted;
    keys that only apply to VoyageAI carry a `voyageai_` prefix.
    """

    # Keep the `voyageai_` prefix on every field declared here: it is what lets these
    # settings be merged with the settings of other models without key clashes.

    voyageai_truncation: bool
    """Whether inputs longer than the model's context length are truncated.

    When this key is absent the model passes `False` to the API; set it to `True`
    to have over-long inputs truncated.
    """
|
|
||
|
|
||
| _MAX_INPUT_TOKENS: dict[VoyageAIEmbeddingModelName, int] = { | ||
| 'voyage-3-large': 32000, | ||
| 'voyage-3.5': 32000, | ||
| 'voyage-3.5-lite': 32000, | ||
| 'voyage-code-3': 32000, | ||
| 'voyage-finance-2': 32000, | ||
| 'voyage-law-2': 16000, | ||
| 'voyage-code-2': 16000, | ||
| } | ||
|
|
||
|
|
||
@dataclass(init=False)
class VoyageAIEmbeddingModel(EmbeddingModel):
    """Embedding model backed by the VoyageAI embeddings API.

    VoyageAI offers retrieval-oriented embedding models, including variants
    specialised for code, finance, and legal text.

    Example:
    ```python
    from pydantic_ai.embeddings.voyageai import VoyageAIEmbeddingModel

    model = VoyageAIEmbeddingModel('voyage-3.5')
    ```
    """

    _model_name: VoyageAIEmbeddingModelName = field(repr=False)
    _provider: Provider[AsyncClient] = field(repr=False)

    def __init__(
        self,
        model_name: VoyageAIEmbeddingModelName,
        *,
        provider: Literal['voyageai'] | Provider[AsyncClient] = 'voyageai',
        settings: EmbeddingSettings | None = None,
    ):
        """Create a VoyageAI embedding model.

        Args:
            model_name: Name of the VoyageAI model to call.
                See [VoyageAI models](https://docs.voyageai.com/docs/embeddings)
                for the available options.
            provider: How to reach the API. Either:

                - `'voyageai'` (default): the provider is resolved via `infer_provider`
                - a [`VoyageAIProvider`][pydantic_ai.providers.voyageai.VoyageAIProvider] instance
                  carrying custom configuration
            settings: Default [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings]
                applied to requests made with this model.
        """
        self._model_name = model_name
        # A string is a provider name to resolve; anything else is already a provider.
        self._provider = infer_provider(provider) if isinstance(provider, str) else provider

        super().__init__(settings=settings)

    @property
    def base_url(self) -> str:
        """Base URL of the provider API this model talks to."""
        return self._provider.base_url

    @property
    def model_name(self) -> VoyageAIEmbeddingModelName:
        """Name of the embedding model."""
        return self._model_name

    @property
    def system(self) -> str:
        """Name of the provider serving this embedding model."""
        return self._provider.name

    # NOTE(review): a reviewer observed that VoyageAI does not (seem to) support counting tokens.
    async def embed(
        self,
        inputs: str | Sequence[str],
        *,
        input_type: EmbedInputType,
        settings: EmbeddingSettings | None = None,
    ) -> EmbeddingResult:
        """Embed `inputs` with the VoyageAI client and return the result.

        Args:
            inputs: A single text or a sequence of texts to embed.
            input_type: Kind of input; `'document'` is forwarded as-is and every
                other value is sent to VoyageAI as `'query'`.
            settings: Optional per-call settings, combined with the model's own
                via `prepare_embed`.

        Raises:
            ModelAPIError: If the VoyageAI client raises a `VoyageError`.
        """
        inputs, prepared = self.prepare_embed(inputs, settings)
        voyage_settings = cast(VoyageAIEmbeddingSettings, prepared)

        if input_type == 'document':
            api_input_type = 'document'
        else:
            api_input_type = 'query'

        try:
            response = await self._provider.client.embed(
                texts=list(inputs),
                model=self.model_name,
                input_type=api_input_type,
                truncation=voyage_settings.get('voyageai_truncation', False),
                output_dimension=voyage_settings.get('dimensions'),
            )
        except VoyageError as exc:
            raise ModelAPIError(model_name=self.model_name, message=str(exc)) from exc

        return EmbeddingResult(
            embeddings=response.embeddings,
            inputs=inputs,
            input_type=input_type,
            usage=_map_usage(response.total_tokens),
            model_name=self.model_name,
            provider_name=self.system,
        )

    async def max_input_tokens(self) -> int | None:
        """Return the known input-token limit for this model, or `None` if it is not listed."""
        return _MAX_INPUT_TOKENS.get(self.model_name)
|
|
||
|
|
||
def _map_usage(total_tokens: int) -> RequestUsage:
    """Build a `RequestUsage` that records `total_tokens` as input tokens."""
    usage = RequestUsage(input_tokens=total_tokens)
    return usage
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| from __future__ import annotations as _annotations | ||
|
|
||
| import os | ||
|
|
||
| from pydantic_ai.exceptions import UserError | ||
| from pydantic_ai.providers import Provider | ||
|
|
||
| try: | ||
| from voyageai.client_async import AsyncClient | ||
| except ImportError as _import_error: # pragma: no cover | ||
| raise ImportError( | ||
| 'Please install the `voyageai` package to use the VoyageAI provider, ' | ||
| 'you can use the `voyageai` optional group — `pip install "pydantic-ai-slim[voyageai]"`' | ||
| ) from _import_error | ||
|
|
||
|
|
||
class VoyageAIProvider(Provider[AsyncClient]):
    """Provider for VoyageAI API."""

    def __init__(
        self,
        *,
        api_key: str | None = None,
        voyageai_client: AsyncClient | None = None,
        base_url: str | None = None,
        max_retries: int = 0,
        timeout: float | None = None,
    ) -> None:
        """Create a new VoyageAI provider.

        Args:
            api_key: API key used for authentication. When omitted, the `VOYAGE_API_KEY`
                environment variable is used if it is set.
            voyageai_client: An existing
                [AsyncClient](https://github.com/voyage-ai/voyageai-python)
                to reuse. When given, `api_key`, `base_url`, `max_retries`, and `timeout`
                must be left at `None`/default.
            base_url: Base URL for the VoyageAI API. Defaults to `https://api.voyageai.com/v1`.
            max_retries: Maximum number of retries for failed requests.
            timeout: Request timeout in seconds.

        Raises:
            UserError: If no client is supplied and no API key can be found.
        """
        if voyageai_client is None:
            resolved_api_key = api_key or os.getenv('VOYAGE_API_KEY')
            if not resolved_api_key:
                raise UserError(
                    'Set the `VOYAGE_API_KEY` environment variable or pass it via `VoyageAIProvider(api_key=...)` '
                    'to use the VoyageAI provider.'
                )

            # An explicit argument takes precedence over `VOYAGE_BASE_URL`; if neither is
            # set, `None` is handed to the client (presumably selecting its own default).
            resolved_base_url = base_url or os.getenv('VOYAGE_BASE_URL')
            self._client = AsyncClient(
                api_key=resolved_api_key,
                max_retries=max_retries,
                timeout=timeout,
                base_url=resolved_base_url,
            )
            return

        # A ready-made client already carries its own configuration, so every other
        # connection argument must have been left untouched.
        assert api_key is None, 'Cannot provide both `voyageai_client` and `api_key`'
        assert base_url is None, 'Cannot provide both `voyageai_client` and `base_url`'
        assert max_retries == 0, 'Cannot provide both `voyageai_client` and `max_retries`'
        assert timeout is None, 'Cannot provide both `voyageai_client` and `timeout`'
        self._client = voyageai_client

    @property
    def name(self) -> str:
        """Provider name."""
        return 'voyageai'

    @property
    def base_url(self) -> str:
        """Base URL stored on the client, or the public VoyageAI endpoint when none is stored."""
        configured = self._client._params.get('base_url')  # type: ignore
        return configured or 'https://api.voyageai.com/v1'

    @property
    def client(self) -> AsyncClient:
        """The underlying VoyageAI async client."""
        return self._client
Uh oh!
There was an error while loading. Please reload this page.