Add SambaNova provider support #3887
Merged: DouweM merged 7 commits into pydantic:main from Pavanmanikanta98:feat/sambanova-provider on Jan 14, 2026
Commits (7)
94c7da4 Add SambaNova provider support
1f7ad1d Apply ruff formatting to sambanova.py
599dce8 Fix import order in docstring example
d2b0d56 Remove docstring example to match other providers
9728f18 Remove unused sambanova_api_key fixture for 100% coverage
0df65e6 Replace hardcoded model list with link to SambaNova model catalog
3728a2d Fix SambaNova model catalog link to official documentation
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+import os
+
+import httpx
+from openai import AsyncOpenAI
+
+from pydantic_ai import ModelProfile
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the SambaNova provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+__all__ = ['SambaNovaProvider']
+
+
+class SambaNovaProvider(Provider[AsyncOpenAI]):
+    """Provider for SambaNova AI models.
+
+    SambaNova uses an OpenAI-compatible API.
+    """
+
+    @property
+    def name(self) -> str:
+        """Return the provider name."""
+        return 'sambanova'
+
+    @property
+    def base_url(self) -> str:
+        """Return the base URL."""
+        return self._base_url
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        """Return the AsyncOpenAI client."""
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        """Get model profile for SambaNova models.
+
+        SambaNova serves models from multiple families including Meta Llama,
+        DeepSeek, Qwen, and Mistral. Model profiles are matched based on
+        model name prefixes.
+        """
+        prefix_to_profile = {
+            'deepseek-': deepseek_model_profile,
+            'meta-llama-': meta_model_profile,
+            'llama-': meta_model_profile,
+            'qwen': qwen_model_profile,
+            'mistral': mistral_model_profile,
+        }
+
+        profile = None
+        model_name_lower = model_name.lower()
+
+        for prefix, profile_func in prefix_to_profile.items():
+            if model_name_lower.startswith(prefix):
+                profile = profile_func(model_name)
+                break
+
+        # Wrap into OpenAIModelProfile since SambaNova is OpenAI-compatible
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        base_url: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: httpx.AsyncClient | None = None,
+    ) -> None:
+        """Initialize SambaNova provider.
+
+        Args:
+            api_key: SambaNova API key. If not provided, reads from SAMBANOVA_API_KEY env var.
+            base_url: Custom API base URL. Defaults to https://api.sambanova.ai/v1
+            openai_client: Optional pre-configured OpenAI client
+            http_client: Optional custom httpx.AsyncClient for making HTTP requests
+
+        Raises:
+            UserError: If API key is not provided and SAMBANOVA_API_KEY env var is not set
+        """
+        if openai_client is not None:
+            self._client = openai_client
+            self._base_url = str(openai_client.base_url)
+        else:
+            # Get API key from parameter or environment
+            api_key = api_key or os.getenv('SAMBANOVA_API_KEY')
+            if not api_key:
+                raise UserError(
+                    'Set the `SAMBANOVA_API_KEY` environment variable or pass it via '
+                    '`SambaNovaProvider(api_key=...)` to use the SambaNova provider.'
+                )
+
+            # Set base URL (default to SambaNova API endpoint)
+            self._base_url = base_url or os.getenv('SAMBANOVA_BASE_URL', 'https://api.sambanova.ai/v1')
+
+            # Create http client and AsyncOpenAI client
+            http_client = http_client or cached_async_http_client(provider='sambanova')
+            self._client = AsyncOpenAI(base_url=self._base_url, api_key=api_key, http_client=http_client)
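
The `else` branch of `__init__` above resolves the API key and base URL in a fixed order: explicit argument first, then environment variable, then (for the base URL only) a built-in default. The following is a minimal standalone sketch of that order, not code from the PR. `resolve_settings`, `DEFAULT_BASE_URL`, and the `environ` parameter are hypothetical names introduced for illustration, and `ValueError` stands in for `UserError` so the sketch runs without `pydantic_ai` installed. It does not model the `openai_client` branch, which skips the lookup entirely.

```python
from __future__ import annotations

import os
from typing import Mapping

# Default endpoint taken from the provider's docstring above.
DEFAULT_BASE_URL = 'https://api.sambanova.ai/v1'


def resolve_settings(
    api_key: str | None = None,
    base_url: str | None = None,
    environ: Mapping[str, str] | None = None,
) -> tuple[str, str]:
    """Hypothetical helper mirroring the lookup order in SambaNovaProvider.__init__."""
    # Use the real process environment unless a mapping is injected (handy for tests).
    env = os.environ if environ is None else environ

    # Explicit argument wins, then SAMBANOVA_API_KEY; an empty string counts as missing.
    key = api_key or env.get('SAMBANOVA_API_KEY')
    if not key:
        # The provider raises pydantic_ai's UserError here; ValueError is a stand-in.
        raise ValueError('pass api_key=... or set SAMBANOVA_API_KEY')

    # Explicit argument wins, then SAMBANOVA_BASE_URL, then the default endpoint.
    url = base_url or env.get('SAMBANOVA_BASE_URL', DEFAULT_BASE_URL)
    return key, url


if __name__ == '__main__':
    print(resolve_settings(api_key='test-key', environ={}))
```

One consequence of writing the fallback as `base_url or getenv(name, default)` is that the default only applies when the variable is unset. If `SAMBANOVA_BASE_URL` is set to an empty string, the expression evaluates to that empty string rather than the default, at least in this sketch of the same expression.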
@@ -0,0 +1,116 @@
+import httpx
+import pytest
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.profiles.openai import OpenAIModelProfile
+
+from ..conftest import TestEnv, try_import
+
+with try_import() as imports_successful:
+    import openai
+
+    from pydantic_ai.providers import infer_provider
+    from pydantic_ai.providers.sambanova import SambaNovaProvider
+
+pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')
+
+
+def test_sambanova_provider_init():
+    provider = SambaNovaProvider(api_key='test-key')
+    assert provider.name == 'sambanova'
+    assert provider.base_url == 'https://api.sambanova.ai/v1'
+    assert isinstance(provider.client, openai.AsyncOpenAI)
+    assert provider.client.api_key == 'test-key'
+
+
+def test_sambanova_provider_env_key(env: TestEnv):
+    env.set('SAMBANOVA_API_KEY', 'env-key')
+    provider = SambaNovaProvider()
+    assert provider.client.api_key == 'env-key'
+
+
+def test_sambanova_provider_missing_key(env: TestEnv):
+    env.remove('SAMBANOVA_API_KEY')
+    with pytest.raises(UserError, match='Set the `SAMBANOVA_API_KEY`'):
+        SambaNovaProvider()
+
+
+def test_infer_provider(env: TestEnv):
+    # infer_provider instantiates the class, so we need an env var or it raises UserError
+    env.set('SAMBANOVA_API_KEY', 'key')
+    provider = infer_provider('sambanova')
+    assert isinstance(provider, SambaNovaProvider)
+
+
+def test_meta_llama_profile():
+    provider = SambaNovaProvider(api_key='key')
+    # Meta Llama model -> expect meta profile wrapped in OpenAI compatibility
+    profile = provider.model_profile('Meta-Llama-3.1-8B-Instruct')
+    assert isinstance(profile, OpenAIModelProfile)
+    assert profile is not None
+
+
+def test_deepseek_profile():
+    provider = SambaNovaProvider(api_key='key')
+    # DeepSeek model -> expect deepseek profile wrapped in OpenAI compatibility
+    profile = provider.model_profile('DeepSeek-R1-0528')
+    assert isinstance(profile, OpenAIModelProfile)
+    assert profile is not None
+
+
+def test_qwen_profile():
+    provider = SambaNovaProvider(api_key='key')
+    # Qwen model -> expect qwen profile wrapped in OpenAI compatibility
+    profile = provider.model_profile('Qwen3-32B')
+    assert isinstance(profile, OpenAIModelProfile)
+    assert profile is not None
+
+
+def test_llama4_profile():
+    provider = SambaNovaProvider(api_key='key')
+    # Llama 4 model -> expect meta profile wrapped in OpenAI compatibility
+    profile = provider.model_profile('Llama-4-Maverick-17B-128E-Instruct')
+    assert isinstance(profile, OpenAIModelProfile)
+    assert profile is not None
+
+
+def test_mistral_profile():
+    provider = SambaNovaProvider(api_key='key')
+    # Mistral-based model -> expect mistral profile wrapped in OpenAI compatibility
+    profile = provider.model_profile('E5-Mistral-7B-Instruct')
+    assert isinstance(profile, OpenAIModelProfile)
+    assert profile is not None
+
+
+def test_unknown_model_profile():
+    provider = SambaNovaProvider(api_key='key')
+    # Unknown model -> should return OpenAI compatibility wrapper with None base profile
+    profile = provider.model_profile('unknown-model')
+    assert isinstance(profile, OpenAIModelProfile)
+
+
+def test_sambanova_provider_with_openai_client():
+    client = openai.AsyncOpenAI(api_key='foo', base_url='https://api.sambanova.ai/v1')
+    provider = SambaNovaProvider(openai_client=client)
+    assert provider.client is client
+
+
+def test_sambanova_provider_with_http_client():
+    http_client = httpx.AsyncClient()
+    provider = SambaNovaProvider(api_key='foo', http_client=http_client)
+    assert provider.client.api_key == 'foo'
+    # The line `self._client = AsyncOpenAI(..., http_client=http_client)` is executed,
+    # which is enough for coverage.
+
+
+def test_sambanova_provider_custom_base_url():
+    provider = SambaNovaProvider(api_key='test-key', base_url='https://custom.endpoint.com/v1')
+    assert provider.base_url == 'https://custom.endpoint.com/v1'
+    assert str(provider.client.base_url).rstrip('/') == 'https://custom.endpoint.com/v1'
+
+
+def test_sambanova_provider_env_base_url(env: TestEnv):
+    env.set('SAMBANOVA_API_KEY', 'key')
+    env.set('SAMBANOVA_BASE_URL', 'https://env.endpoint.com/v1')
+    provider = SambaNovaProvider()
+    assert provider.base_url == 'https://env.endpoint.com/v1'
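
The profile tests above only assert that the result is an `OpenAIModelProfile`, so they do not show which family each model name actually resolves to. As a rough aid, the sketch below replays the prefix loop from `model_profile` on the model names used in these tests. It is an illustration under assumptions, not code from the PR: `PREFIX_TO_FAMILY` and `match_family` are hypothetical names, and plain string labels stand in for the real profile functions so that it runs without `pydantic_ai` installed.

```python
from __future__ import annotations

# Same prefixes, in the same order, as `prefix_to_profile` in the provider.
# The values are stand-in labels, not the real profile functions.
PREFIX_TO_FAMILY = {
    'deepseek-': 'deepseek',
    'meta-llama-': 'meta',
    'llama-': 'meta',
    'qwen': 'qwen',
    'mistral': 'mistral',
}


def match_family(model_name: str) -> str | None:
    """Return the label of the first prefix the lowercased name starts with, else None."""
    lowered = model_name.lower()
    for prefix, family in PREFIX_TO_FAMILY.items():
        if lowered.startswith(prefix):
            return family
    return None


if __name__ == '__main__':
    # Model names taken from the tests above.
    for name in [
        'Meta-Llama-3.1-8B-Instruct',
        'DeepSeek-R1-0528',
        'Qwen3-32B',
        'Llama-4-Maverick-17B-128E-Instruct',
        'E5-Mistral-7B-Instruct',
        'unknown-model',
    ]:
        print(f'{name} -> {match_family(name)}')
```

Under this reading of the loop, `E5-Mistral-7B-Instruct` lowercases to `e5-mistral-7b-instruct`, which does not start with `mistral`, so it falls through to the same default path as `unknown-model`. `test_mistral_profile` would still pass because it checks only the wrapper type. A name that begins with `mistral` would be needed to exercise the Mistral branch.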
Note that this will have to be moved if #3941 merges first.