Merged
2 changes: 1 addition & 1 deletion README.md
@@ -39,7 +39,7 @@ We built Pydantic AI with one simple aim: to bring that FastAPI feeling to GenAI
[Pydantic Validation](https://docs.pydantic.dev/latest/) is the validation layer of the OpenAI SDK, the Google ADK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more. _Why use the derivative when you can go straight to the source?_ :smiley:

2. **Model-agnostic**:
-Supports virtually every [model](https://ai.pydantic.dev/models/overview) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius, OVHcloud, Alibaba Cloud, and Outlines. If your favorite model or provider is not listed, you can easily implement a [custom model](https://ai.pydantic.dev/models/overview#custom-models).
+Supports virtually every [model](https://ai.pydantic.dev/models/overview) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius, OVHcloud, Alibaba Cloud, SambaNova, and Outlines. If your favorite model or provider is not listed, you can easily implement a [custom model](https://ai.pydantic.dev/models/overview#custom-models).

3. **Seamless Observability**:
Tightly [integrates](https://ai.pydantic.dev/logfire) with [Pydantic Logfire](https://pydantic.dev/logfire), our general-purpose OpenTelemetry observability platform, for real-time debugging, evals-based performance monitoring, and behavior, tracing, and cost tracking. If you already have an observability platform that supports OTel, you can [use that too](https://ai.pydantic.dev/logfire#alternative-observability-backends).
2 changes: 2 additions & 0 deletions docs/api/providers.md
@@ -51,3 +51,5 @@
::: pydantic_ai.providers.ovhcloud.OVHcloudProvider

::: pydantic_ai.providers.alibaba.AlibabaProvider

::: pydantic_ai.providers.sambanova.SambaNovaProvider
2 changes: 1 addition & 1 deletion docs/index.md
@@ -18,7 +18,7 @@ We built Pydantic AI with one simple aim: to bring that FastAPI feeling to GenAI
[Pydantic Validation](https://docs.pydantic.dev/latest/) is the validation layer of the OpenAI SDK, the Google ADK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more. _Why use the derivative when you can go straight to the source?_ :smiley:

2. **Model-agnostic**:
-Supports virtually every [model](models/overview.md) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius, OVHcloud, Alibaba Cloud, and Outlines. If your favorite model or provider is not listed, you can easily implement a [custom model](models/overview.md#custom-models).
+Supports virtually every [model](models/overview.md) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius, OVHcloud, Alibaba Cloud, SambaNova, and Outlines. If your favorite model or provider is not listed, you can easily implement a [custom model](models/overview.md#custom-models).

3. **Seamless Observability**:
Tightly [integrates](logfire.md) with [Pydantic Logfire](https://pydantic.dev/logfire), our general-purpose OpenTelemetry observability platform, for real-time debugging, evals-based performance monitoring, and behavior, tracing, and cost tracking. If you already have an observability platform that supports OTel, you can [use that too](logfire.md#alternative-observability-backends).
54 changes: 54 additions & 0 deletions docs/models/openai.md
@@ -739,3 +739,57 @@ result = agent.run_sync('What is the capital of France?')
print(result.output)
#> The capital of France is Paris.
```

### SambaNova
Review comment from a collaborator:

Note that this will have to be moved if #3941 merges first.


To use [SambaNova Cloud](https://cloud.sambanova.ai/), you need to obtain an API key from the [SambaNova Cloud dashboard](https://cloud.sambanova.ai/dashboard).

SambaNova provides access to multiple model families including Meta Llama, DeepSeek, Qwen, and Mistral models with fast inference speeds.

You can set the `SAMBANOVA_API_KEY` environment variable and use [`SambaNovaProvider`][pydantic_ai.providers.sambanova.SambaNovaProvider] by name:

```python
from pydantic_ai import Agent

agent = Agent('sambanova:Meta-Llama-3.1-8B-Instruct')
result = agent.run_sync('What is the capital of France?')
print(result.output)
#> The capital of France is Paris.
```

Or initialise the model and provider directly:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.sambanova import SambaNovaProvider

model = OpenAIChatModel(
    'Meta-Llama-3.1-8B-Instruct',
    provider=SambaNovaProvider(api_key='your-api-key'),
)
agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.output)
#> The capital of France is Paris.
```

For a complete list of available models, see the [SambaNova supported models documentation](https://docs.sambanova.ai/docs/en/models/sambacloud-models).

You can customize the base URL if needed:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.sambanova import SambaNovaProvider

model = OpenAIChatModel(
    'DeepSeek-R1-0528',
    provider=SambaNovaProvider(
        api_key='your-api-key',
        base_url='https://custom.endpoint.com/v1',
    ),
)
agent = Agent(model)
...
```
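
Judging from the provider code in this PR, the base URL also appears to fall back to a `SAMBANOVA_BASE_URL` environment variable when `base_url` is not passed, and both settings are skipped when a pre-configured `openai_client` is supplied. The following is a minimal, dependency-free sketch of that resolution order, not the provider itself: `resolve_sambanova_settings` and the `ValueError` are hypothetical stand-ins (the provider raises `UserError`).

```python
from __future__ import annotations

import os
from collections.abc import Mapping

DEFAULT_BASE_URL = 'https://api.sambanova.ai/v1'


def resolve_sambanova_settings(
    api_key: str | None = None,
    base_url: str | None = None,
    environ: Mapping[str, str] | None = None,
) -> tuple[str, str]:
    """Hypothetical helper mirroring the precedence used in the provider's __init__."""
    env = os.environ if environ is None else environ

    # An explicit argument wins; otherwise fall back to the environment variable.
    api_key = api_key or env.get('SAMBANOVA_API_KEY')
    if not api_key:
        # The real provider raises pydantic_ai.exceptions.UserError here.
        raise ValueError('Set SAMBANOVA_API_KEY or pass api_key=...')

    # Explicit argument, then SAMBANOVA_BASE_URL, then the default endpoint.
    base_url = base_url or env.get('SAMBANOVA_BASE_URL', DEFAULT_BASE_URL)
    return api_key, base_url


if __name__ == '__main__':
    fake_env = {'SAMBANOVA_API_KEY': 'env-key', 'SAMBANOVA_BASE_URL': 'https://env.endpoint.com/v1'}
    print(resolve_sambanova_settings(environ=fake_env))
    print(resolve_sambanova_settings(api_key='explicit', base_url='https://custom.endpoint.com/v1', environ=fake_env))
```

Passing the environment as a mapping keeps the sketch easy to exercise without touching the real process environment.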
1 change: 1 addition & 0 deletions docs/models/overview.md
@@ -30,6 +30,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used
- [Ollama](openai.md#ollama)
- [OVHcloud AI Endpoints](openai.md#ovhcloud-ai-endpoints)
- [Perplexity](openai.md#perplexity)
- [SambaNova](openai.md#sambanova)
- [Together AI](openai.md#together-ai)
- [Vercel AI Gateway](openai.md#vercel-ai-gateway)

20 changes: 11 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -525,35 +525,37 @@
 OpenAIChatCompatibleProvider = TypeAliasType(
     'OpenAIChatCompatibleProvider',
     Literal[
+        'alibaba',
         'azure',
-        'deepseek',
         'cerebras',
+        'deepseek',
         'fireworks',
         'github',
         'grok',
         'heroku',
+        'litellm',
         'moonshotai',
+        'nebius',
         'ollama',
         'openrouter',
+        'ovhcloud',
+        'sambanova',
         'together',
         'vercel',
-        'litellm',
-        'nebius',
-        'ovhcloud',
-        'alibaba',
     ],
 )
 OpenAIResponsesCompatibleProvider = TypeAliasType(
     'OpenAIResponsesCompatibleProvider',
     Literal[
-        'deepseek',
         'azure',
-        'openrouter',
-        'grok',
+        'deepseek',
         'fireworks',
-        'together',
+        'grok',
         'nebius',
+        'openrouter',
         'ovhcloud',
+        'sambanova',
+        'together',
     ],
 )
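
The hunk above interleaves removed and added entries; the 11 additions and 9 deletions are consistent with both lists being re-sorted alphabetically with `'sambanova'` added. As a rough illustration of how such a `Literal` can double as a runtime allow-list, here is a sketch under those assumptions: `ResponsesProvider` is a plain `Literal` standing in for the real `TypeAliasType`, its contents are the inferred post-change list, and `is_responses_compatible` is a hypothetical helper, not part of Pydantic AI.

```python
from typing import Literal, get_args

# Assumed post-change contents of OpenAIResponsesCompatibleProvider (inferred from the diff).
ResponsesProvider = Literal[
    'azure',
    'deepseek',
    'fireworks',
    'grok',
    'nebius',
    'openrouter',
    'ovhcloud',
    'sambanova',
    'together',
]


def is_responses_compatible(provider: str) -> bool:
    """Hypothetical check: is the provider name one of the Literal's values?"""
    return provider in get_args(ResponsesProvider)


if __name__ == '__main__':
    print(is_responses_compatible('sambanova'))
    print(is_responses_compatible('cerebras'))
```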

4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
@@ -149,6 +149,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
        from .alibaba import AlibabaProvider

        return AlibabaProvider
    elif provider == 'sambanova':
        from .sambanova import SambaNovaProvider

        return SambaNovaProvider
    elif provider == 'outlines':
        from .outlines import OutlinesProvider

112 changes: 112 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/sambanova.py
@@ -0,0 +1,112 @@
from __future__ import annotations

import os

import httpx
from openai import AsyncOpenAI

from pydantic_ai import ModelProfile
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the SambaNova provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error

__all__ = ['SambaNovaProvider']


class SambaNovaProvider(Provider[AsyncOpenAI]):
    """Provider for SambaNova AI models.

    SambaNova uses an OpenAI-compatible API.
    """

    @property
    def name(self) -> str:
        """Return the provider name."""
        return 'sambanova'

    @property
    def base_url(self) -> str:
        """Return the base URL."""
        return self._base_url

    @property
    def client(self) -> AsyncOpenAI:
        """Return the AsyncOpenAI client."""
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """Get model profile for SambaNova models.

        SambaNova serves models from multiple families including Meta Llama,
        DeepSeek, Qwen, and Mistral. Model profiles are matched based on
        model name prefixes.
        """
        prefix_to_profile = {
            'deepseek-': deepseek_model_profile,
            'meta-llama-': meta_model_profile,
            'llama-': meta_model_profile,
            'qwen': qwen_model_profile,
            'mistral': mistral_model_profile,
        }

        profile = None
        model_name_lower = model_name.lower()

        for prefix, profile_func in prefix_to_profile.items():
            if model_name_lower.startswith(prefix):
                profile = profile_func(model_name)
                break

        # Wrap into OpenAIModelProfile since SambaNova is OpenAI-compatible
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Initialize SambaNova provider.

        Args:
            api_key: SambaNova API key. If not provided, reads from SAMBANOVA_API_KEY env var.
            base_url: Custom API base URL. Defaults to https://api.sambanova.ai/v1
            openai_client: Optional pre-configured OpenAI client
            http_client: Optional custom httpx.AsyncClient for making HTTP requests

        Raises:
            UserError: If API key is not provided and SAMBANOVA_API_KEY env var is not set
        """
        if openai_client is not None:
            self._client = openai_client
            self._base_url = str(openai_client.base_url)
        else:
            # Get API key from parameter or environment
            api_key = api_key or os.getenv('SAMBANOVA_API_KEY')
            if not api_key:
                raise UserError(
                    'Set the `SAMBANOVA_API_KEY` environment variable or pass it via '
                    '`SambaNovaProvider(api_key=...)` to use the SambaNova provider.'
                )

            # Set base URL (default to SambaNova API endpoint)
            self._base_url = base_url or os.getenv('SAMBANOVA_BASE_URL', 'https://api.sambanova.ai/v1')

            # Create http client and AsyncOpenAI client
            http_client = http_client or cached_async_http_client(provider='sambanova')
            self._client = AsyncOpenAI(base_url=self._base_url, api_key=api_key, http_client=http_client)
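
The prefix matching in `model_profile` can be tried in isolation. The sketch below is a simplified stand-in, not the provider's code: family names replace the real profile functions so it runs without `pydantic_ai`, and `PREFIX_TO_FAMILY` and `match_family` are hypothetical names. It suggests that a name such as `E5-Mistral-7B-Instruct`, where the family name is not at the start, would match none of the prefixes.

```python
from __future__ import annotations

# Hypothetical stand-in for the provider's prefix table: family names replace
# the real profile functions so the sketch has no third-party dependencies.
PREFIX_TO_FAMILY = {
    'deepseek-': 'deepseek',
    'meta-llama-': 'meta',
    'llama-': 'meta',
    'qwen': 'qwen',
    'mistral': 'mistral',
}


def match_family(model_name: str) -> str | None:
    """Return the family of the first prefix that the lowercased name starts with."""
    model_name_lower = model_name.lower()
    for prefix, family in PREFIX_TO_FAMILY.items():
        if model_name_lower.startswith(prefix):
            return family
    # No prefix matched: the provider would fall back to the bare OpenAI profile.
    return None


if __name__ == '__main__':
    for name in (
        'Meta-Llama-3.1-8B-Instruct',
        'Llama-4-Maverick-17B-128E-Instruct',
        'DeepSeek-R1-0528',
        'Qwen3-32B',
        'E5-Mistral-7B-Instruct',
    ):
        print(name, '->', match_family(name))
```

If names with the family in the middle should also pick up a family profile, a substring check would be one option, at the cost of possible false positives.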
116 changes: 116 additions & 0 deletions tests/providers/test_sambanova_provider.py
@@ -0,0 +1,116 @@
import httpx
import pytest

from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.openai import OpenAIModelProfile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
    import openai

    from pydantic_ai.providers import infer_provider
    from pydantic_ai.providers.sambanova import SambaNovaProvider

pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


def test_sambanova_provider_init():
    provider = SambaNovaProvider(api_key='test-key')
    assert provider.name == 'sambanova'
    assert provider.base_url == 'https://api.sambanova.ai/v1'
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'test-key'


def test_sambanova_provider_env_key(env: TestEnv):
    env.set('SAMBANOVA_API_KEY', 'env-key')
    provider = SambaNovaProvider()
    assert provider.client.api_key == 'env-key'


def test_sambanova_provider_missing_key(env: TestEnv):
    env.remove('SAMBANOVA_API_KEY')
    with pytest.raises(UserError, match='Set the `SAMBANOVA_API_KEY`'):
        SambaNovaProvider()


def test_infer_provider(env: TestEnv):
    # infer_provider instantiates the class, so we need an env var or it raises UserError
    env.set('SAMBANOVA_API_KEY', 'key')
    provider = infer_provider('sambanova')
    assert isinstance(provider, SambaNovaProvider)


def test_meta_llama_profile():
    provider = SambaNovaProvider(api_key='key')
    # Meta Llama model -> expect meta profile wrapped in OpenAI compatibility
    profile = provider.model_profile('Meta-Llama-3.1-8B-Instruct')
    assert isinstance(profile, OpenAIModelProfile)
    assert profile is not None


def test_deepseek_profile():
    provider = SambaNovaProvider(api_key='key')
    # DeepSeek model -> expect deepseek profile wrapped in OpenAI compatibility
    profile = provider.model_profile('DeepSeek-R1-0528')
    assert isinstance(profile, OpenAIModelProfile)
    assert profile is not None


def test_qwen_profile():
    provider = SambaNovaProvider(api_key='key')
    # Qwen model -> expect qwen profile wrapped in OpenAI compatibility
    profile = provider.model_profile('Qwen3-32B')
    assert isinstance(profile, OpenAIModelProfile)
    assert profile is not None


def test_llama4_profile():
    provider = SambaNovaProvider(api_key='key')
    # Llama 4 model -> expect meta profile wrapped in OpenAI compatibility
    profile = provider.model_profile('Llama-4-Maverick-17B-128E-Instruct')
    assert isinstance(profile, OpenAIModelProfile)
    assert profile is not None


def test_mistral_profile():
    provider = SambaNovaProvider(api_key='key')
    # 'E5-Mistral-7B-Instruct' lowercases to 'e5-mistral-...', which does not start with 'mistral',
    # so no family prefix matches and only the OpenAI compatibility profile is returned
    profile = provider.model_profile('E5-Mistral-7B-Instruct')
    assert isinstance(profile, OpenAIModelProfile)
    assert profile is not None


def test_unknown_model_profile():
    provider = SambaNovaProvider(api_key='key')
    # Unknown model -> should return OpenAI compatibility wrapper with None base profile
    profile = provider.model_profile('unknown-model')
    assert isinstance(profile, OpenAIModelProfile)


def test_sambanova_provider_with_openai_client():
    client = openai.AsyncOpenAI(api_key='foo', base_url='https://api.sambanova.ai/v1')
    provider = SambaNovaProvider(openai_client=client)
    assert provider.client is client


def test_sambanova_provider_with_http_client():
    http_client = httpx.AsyncClient()
    provider = SambaNovaProvider(api_key='foo', http_client=http_client)
    assert provider.client.api_key == 'foo'
    # The line `self._client = AsyncOpenAI(..., http_client=http_client)` is executed,
    # which is enough for coverage.


def test_sambanova_provider_custom_base_url():
    provider = SambaNovaProvider(api_key='test-key', base_url='https://custom.endpoint.com/v1')
    assert provider.base_url == 'https://custom.endpoint.com/v1'
    assert str(provider.client.base_url).rstrip('/') == 'https://custom.endpoint.com/v1'


def test_sambanova_provider_env_base_url(env: TestEnv):
    env.set('SAMBANOVA_API_KEY', 'key')
    env.set('SAMBANOVA_BASE_URL', 'https://env.endpoint.com/v1')
    provider = SambaNovaProvider()
    assert provider.base_url == 'https://env.endpoint.com/v1'
1 change: 1 addition & 0 deletions tests/test_examples.py
@@ -184,6 +184,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
env.set('DEEPSEEK_API_KEY', 'testing')
env.set('OVHCLOUD_API_KEY', 'testing')
env.set('ALIBABA_API_KEY', 'testing')
env.set('SAMBANOVA_API_KEY', 'testing')
env.set('PYDANTIC_AI_GATEWAY_API_KEY', 'testing')

prefix_settings = example.prefix_settings()