diff --git a/README.md b/README.md index f391df541c..bbb7438ff0 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ We built Pydantic AI with one simple aim: to bring that FastAPI feeling to GenAI [Pydantic Validation](https://docs.pydantic.dev/latest/) is the validation layer of the OpenAI SDK, the Google ADK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more. _Why use the derivative when you can go straight to the source?_ :smiley: 2. **Model-agnostic**: -Supports virtually every [model](https://ai.pydantic.dev/models/overview) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius, OVHcloud, Alibaba Cloud, and Outlines. If your favorite model or provider is not listed, you can easily implement a [custom model](https://ai.pydantic.dev/models/overview#custom-models). +Supports virtually every [model](https://ai.pydantic.dev/models/overview) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius, OVHcloud, Alibaba Cloud, SambaNova, and Outlines. If your favorite model or provider is not listed, you can easily implement a [custom model](https://ai.pydantic.dev/models/overview#custom-models). 3. **Seamless Observability**: Tightly [integrates](https://ai.pydantic.dev/logfire) with [Pydantic Logfire](https://pydantic.dev/logfire), our general-purpose OpenTelemetry observability platform, for real-time debugging, evals-based performance monitoring, and behavior, tracing, and cost tracking. 
If you already have an observability platform that supports OTel, you can [use that too](https://ai.pydantic.dev/logfire#alternative-observability-backends). diff --git a/docs/api/providers.md b/docs/api/providers.md index f5b23ed8d7..97a12f4cb0 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -51,3 +51,5 @@ ::: pydantic_ai.providers.ovhcloud.OVHcloudProvider ::: pydantic_ai.providers.alibaba.AlibabaProvider + +::: pydantic_ai.providers.sambanova.SambaNovaProvider diff --git a/docs/index.md b/docs/index.md index f73307db87..333226e159 100644 --- a/docs/index.md +++ b/docs/index.md @@ -18,7 +18,7 @@ We built Pydantic AI with one simple aim: to bring that FastAPI feeling to GenAI [Pydantic Validation](https://docs.pydantic.dev/latest/) is the validation layer of the OpenAI SDK, the Google ADK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more. _Why use the derivative when you can go straight to the source?_ :smiley: 2. **Model-agnostic**: -Supports virtually every [model](models/overview.md) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius, OVHcloud, Alibaba Cloud, and Outlines. If your favorite model or provider is not listed, you can easily implement a [custom model](models/overview.md#custom-models). +Supports virtually every [model](models/overview.md) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius, OVHcloud, Alibaba Cloud, SambaNova, and Outlines. If your favorite model or provider is not listed, you can easily implement a [custom model](models/overview.md#custom-models). 3. 
**Seamless Observability**: Tightly [integrates](logfire.md) with [Pydantic Logfire](https://pydantic.dev/logfire), our general-purpose OpenTelemetry observability platform, for real-time debugging, evals-based performance monitoring, and behavior, tracing, and cost tracking. If you already have an observability platform that supports OTel, you can [use that too](logfire.md#alternative-observability-backends). diff --git a/docs/models/openai.md b/docs/models/openai.md index 821bc5ce18..220a0bd2e2 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -739,3 +739,57 @@ result = agent.run_sync('What is the capital of France?') print(result.output) #> The capital of France is Paris. ``` + +### SambaNova + +To use [SambaNova Cloud](https://cloud.sambanova.ai/), you need to obtain an API key from the [SambaNova Cloud dashboard](https://cloud.sambanova.ai/dashboard). + +SambaNova provides access to multiple model families including Meta Llama, DeepSeek, Qwen, and Mistral models with fast inference speeds. + +You can set the `SAMBANOVA_API_KEY` environment variable and use [`SambaNovaProvider`][pydantic_ai.providers.sambanova.SambaNovaProvider] by name: + +```python +from pydantic_ai import Agent + +agent = Agent('sambanova:Meta-Llama-3.1-8B-Instruct') +result = agent.run_sync('What is the capital of France?') +print(result.output) +#> The capital of France is Paris. +``` + +Or initialise the model and provider directly: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.sambanova import SambaNovaProvider + +model = OpenAIChatModel( + 'Meta-Llama-3.1-8B-Instruct', + provider=SambaNovaProvider(api_key='your-api-key'), +) +agent = Agent(model) +result = agent.run_sync('What is the capital of France?') +print(result.output) +#> The capital of France is Paris. 
+``` + +For a complete list of available models, see the [SambaNova supported models documentation](https://docs.sambanova.ai/docs/en/models/sambacloud-models). + +You can customize the base URL if needed: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.sambanova import SambaNovaProvider + +model = OpenAIChatModel( + 'DeepSeek-R1-0528', + provider=SambaNovaProvider( + api_key='your-api-key', + base_url='https://custom.endpoint.com/v1', + ), +) +agent = Agent(model) +... +``` diff --git a/docs/models/overview.md b/docs/models/overview.md index c7fe46993b..04e7a77014 100644 --- a/docs/models/overview.md +++ b/docs/models/overview.md @@ -30,6 +30,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used - [Ollama](openai.md#ollama) - [OVHcloud AI Endpoints](openai.md#ovhcloud-ai-endpoints) - [Perplexity](openai.md#perplexity) +- [SambaNova](openai.md#sambanova) - [Together AI](openai.md#together-ai) - [Vercel AI Gateway](openai.md#vercel-ai-gateway) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index bcad0a0bf0..7b7657f407 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -525,35 +525,37 @@ OpenAIChatCompatibleProvider = TypeAliasType( 'OpenAIChatCompatibleProvider', Literal[ + 'alibaba', 'azure', - 'deepseek', 'cerebras', + 'deepseek', 'fireworks', 'github', 'grok', 'heroku', + 'litellm', 'moonshotai', + 'nebius', 'ollama', 'openrouter', + 'ovhcloud', + 'sambanova', 'together', 'vercel', - 'litellm', - 'nebius', - 'ovhcloud', - 'alibaba', ], ) OpenAIResponsesCompatibleProvider = TypeAliasType( 'OpenAIResponsesCompatibleProvider', Literal[ - 'deepseek', 'azure', - 'openrouter', - 'grok', + 'deepseek', 'fireworks', - 'together', + 'grok', 'nebius', + 'openrouter', 'ovhcloud', + 'sambanova', + 'together', ], ) diff --git 
a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index eb29e341cc..0bf498b5b2 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -149,6 +149,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901 from .alibaba import AlibabaProvider return AlibabaProvider + elif provider == 'sambanova': + from .sambanova import SambaNovaProvider + + return SambaNovaProvider elif provider == 'outlines': from .outlines import OutlinesProvider diff --git a/pydantic_ai_slim/pydantic_ai/providers/sambanova.py b/pydantic_ai_slim/pydantic_ai/providers/sambanova.py new file mode 100644 index 0000000000..da5f5cc342 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/sambanova.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import os + +import httpx +from openai import AsyncOpenAI + +from pydantic_ai import ModelProfile +from pydantic_ai.exceptions import UserError +from pydantic_ai.models import cached_async_http_client +from pydantic_ai.profiles.deepseek import deepseek_model_profile +from pydantic_ai.profiles.meta import meta_model_profile +from pydantic_ai.profiles.mistral import mistral_model_profile +from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile +from pydantic_ai.profiles.qwen import qwen_model_profile +from pydantic_ai.providers import Provider + +try: + from openai import AsyncOpenAI +except ImportError as _import_error: # pragma: no cover + raise ImportError( + 'Please install the `openai` package to use the SambaNova provider, ' + 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' + ) from _import_error + +__all__ = ['SambaNovaProvider'] + + +class SambaNovaProvider(Provider[AsyncOpenAI]): + """Provider for SambaNova AI models. + + SambaNova uses an OpenAI-compatible API. 
+ """ + + @property + def name(self) -> str: + """Return the provider name.""" + return 'sambanova' + + @property + def base_url(self) -> str: + """Return the base URL.""" + return self._base_url + + @property + def client(self) -> AsyncOpenAI: + """Return the AsyncOpenAI client.""" + return self._client + + def model_profile(self, model_name: str) -> ModelProfile | None: + """Get model profile for SambaNova models. + + SambaNova serves models from multiple families including Meta Llama, + DeepSeek, Qwen, and Mistral. Model profiles are matched based on + model name prefixes. + """ + prefix_to_profile = { + 'deepseek-': deepseek_model_profile, + 'meta-llama-': meta_model_profile, + 'llama-': meta_model_profile, + 'qwen': qwen_model_profile, + 'mistral': mistral_model_profile, + } + + profile = None + model_name_lower = model_name.lower() + + for prefix, profile_func in prefix_to_profile.items(): + if model_name_lower.startswith(prefix): + profile = profile_func(model_name) + break + + # Wrap into OpenAIModelProfile since SambaNova is OpenAI-compatible + return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) + + def __init__( + self, + *, + api_key: str | None = None, + base_url: str | None = None, + openai_client: AsyncOpenAI | None = None, + http_client: httpx.AsyncClient | None = None, + ) -> None: + """Initialize SambaNova provider. + + Args: + api_key: SambaNova API key. If not provided, reads from SAMBANOVA_API_KEY env var. + base_url: Custom API base URL. 
If not provided, reads from SAMBANOVA_BASE_URL env var, falling back to https://api.sambanova.ai/v1 + openai_client: Optional pre-configured OpenAI client + http_client: Optional custom httpx.AsyncClient for making HTTP requests + + Raises: + UserError: If API key is not provided and SAMBANOVA_API_KEY env var is not set + """ + if openai_client is not None: + self._client = openai_client + self._base_url = str(openai_client.base_url) + else: + # Get API key from parameter or environment + api_key = api_key or os.getenv('SAMBANOVA_API_KEY') + if not api_key: + raise UserError( + 'Set the `SAMBANOVA_API_KEY` environment variable or pass it via ' + '`SambaNovaProvider(api_key=...)` to use the SambaNova provider.' + ) + + # Set base URL: explicit argument, then SAMBANOVA_BASE_URL, then the default SambaNova endpoint + self._base_url = base_url or os.getenv('SAMBANOVA_BASE_URL', 'https://api.sambanova.ai/v1') + + # Create http client and AsyncOpenAI client + http_client = http_client or cached_async_http_client(provider='sambanova') + self._client = AsyncOpenAI(base_url=self._base_url, api_key=api_key, http_client=http_client) diff --git a/tests/providers/test_sambanova_provider.py b/tests/providers/test_sambanova_provider.py new file mode 100644 index 0000000000..f96d14114f --- /dev/null +++ b/tests/providers/test_sambanova_provider.py @@ -0,0 +1,116 @@ +import httpx +import pytest + +from pydantic_ai.exceptions import UserError +from pydantic_ai.profiles.openai import OpenAIModelProfile + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + import openai + + from pydantic_ai.providers import infer_provider + from pydantic_ai.providers.sambanova import SambaNovaProvider + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed') + + +def test_sambanova_provider_init(): + provider = SambaNovaProvider(api_key='test-key') + assert provider.name == 'sambanova' + assert provider.base_url == 'https://api.sambanova.ai/v1' + assert isinstance(provider.client, openai.AsyncOpenAI) + assert provider.client.api_key 
== 'test-key' + + +def test_sambanova_provider_env_key(env: TestEnv): + env.set('SAMBANOVA_API_KEY', 'env-key') + provider = SambaNovaProvider() + assert provider.client.api_key == 'env-key' + + +def test_sambanova_provider_missing_key(env: TestEnv): + env.remove('SAMBANOVA_API_KEY') + with pytest.raises(UserError, match='Set the `SAMBANOVA_API_KEY`'): + SambaNovaProvider() + + +def test_infer_provider(env: TestEnv): + # infer_provider instantiates the class, so we need an env var or it raises UserError + env.set('SAMBANOVA_API_KEY', 'key') + provider = infer_provider('sambanova') + assert isinstance(provider, SambaNovaProvider) + + +def test_meta_llama_profile(): + provider = SambaNovaProvider(api_key='key') + # Meta Llama model -> expect meta profile wrapped in OpenAI compatibility + profile = provider.model_profile('Meta-Llama-3.1-8B-Instruct') + assert isinstance(profile, OpenAIModelProfile) + assert profile is not None + + +def test_deepseek_profile(): + provider = SambaNovaProvider(api_key='key') + # DeepSeek model -> expect deepseek profile wrapped in OpenAI compatibility + profile = provider.model_profile('DeepSeek-R1-0528') + assert isinstance(profile, OpenAIModelProfile) + assert profile is not None + + +def test_qwen_profile(): + provider = SambaNovaProvider(api_key='key') + # Qwen model -> expect qwen profile wrapped in OpenAI compatibility + profile = provider.model_profile('Qwen3-32B') + assert isinstance(profile, OpenAIModelProfile) + assert profile is not None + + +def test_llama4_profile(): + provider = SambaNovaProvider(api_key='key') + # Llama 4 model -> expect meta profile wrapped in OpenAI compatibility + profile = provider.model_profile('Llama-4-Maverick-17B-128E-Instruct') + assert isinstance(profile, OpenAIModelProfile) + assert profile is not None + + +def test_mistral_profile(): + provider = SambaNovaProvider(api_key='key') + # 'E5-Mistral-7B-Instruct' does not start with 'mistral', so no family prefix matches -> expect the bare OpenAI compatibility wrapper + profile = 
provider.model_profile('E5-Mistral-7B-Instruct') + assert isinstance(profile, OpenAIModelProfile) + assert profile is not None + + +def test_unknown_model_profile(): + provider = SambaNovaProvider(api_key='key') + # Unknown model -> should return OpenAI compatibility wrapper with None base profile + profile = provider.model_profile('unknown-model') + assert isinstance(profile, OpenAIModelProfile) + + +def test_sambanova_provider_with_openai_client(): + client = openai.AsyncOpenAI(api_key='foo', base_url='https://api.sambanova.ai/v1') + provider = SambaNovaProvider(openai_client=client) + assert provider.client is client + + +def test_sambanova_provider_with_http_client(): + http_client = httpx.AsyncClient() + provider = SambaNovaProvider(api_key='foo', http_client=http_client) + assert provider.client.api_key == 'foo' + # The line `self._client = AsyncOpenAI(..., http_client=http_client)` is executed, + # which is enough for coverage. + + +def test_sambanova_provider_custom_base_url(): + provider = SambaNovaProvider(api_key='test-key', base_url='https://custom.endpoint.com/v1') + assert provider.base_url == 'https://custom.endpoint.com/v1' + assert str(provider.client.base_url).rstrip('/') == 'https://custom.endpoint.com/v1' + + +def test_sambanova_provider_env_base_url(env: TestEnv): + env.set('SAMBANOVA_API_KEY', 'key') + env.set('SAMBANOVA_BASE_URL', 'https://env.endpoint.com/v1') + provider = SambaNovaProvider() + assert provider.base_url == 'https://env.endpoint.com/v1' diff --git a/tests/test_examples.py b/tests/test_examples.py index c6146ea406..b86bfe74a0 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -184,6 +184,7 @@ def print(self, *args: Any, **kwargs: Any) -> None: env.set('DEEPSEEK_API_KEY', 'testing') env.set('OVHCLOUD_API_KEY', 'testing') env.set('ALIBABA_API_KEY', 'testing') + env.set('SAMBANOVA_API_KEY', 'testing') env.set('PYDANTIC_AI_GATEWAY_API_KEY', 'testing') prefix_settings = example.prefix_settings()
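
The prefix matching in `SambaNovaProvider.model_profile` can be sketched in isolation. This is a minimal sketch, not the provider's implementation: plain family labels stand in for the real profile functions, and `match_family` is a hypothetical helper name introduced only for illustration. It assumes the same prefix table and the same lowercase `startswith` check shown in the diff.

```python
from __future__ import annotations

# Same prefixes as in the diff; the string labels are stand-ins
# (an assumption) for the real profile functions.
PREFIX_TO_FAMILY = {
    'deepseek-': 'deepseek',
    'meta-llama-': 'meta',
    'llama-': 'meta',
    'qwen': 'qwen',
    'mistral': 'mistral',
}


def match_family(model_name: str) -> str | None:
    """Return the family whose prefix starts the lowercased name, else None."""
    name = model_name.lower()
    for prefix, family in PREFIX_TO_FAMILY.items():
        if name.startswith(prefix):
            return family
    return None


# Model names taken from the tests in the diff.
for name in (
    'Meta-Llama-3.1-8B-Instruct',
    'DeepSeek-R1-0528',
    'Qwen3-32B',
    'Llama-4-Maverick-17B-128E-Instruct',
    'E5-Mistral-7B-Instruct',
    'unknown-model',
):
    print(f'{name}: {match_family(name)}')
```

Because the check is `startswith`, a name such as `E5-Mistral-7B-Instruct` matches no prefix, so for that model the provider would return only the OpenAI-compatible wrapper with no family profile merged in.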