from .client import Client, Chat, Completions
from .base_client import BaseClient
from .provider import ProviderFactory
from .tool_runner import ToolRunner


class AsyncClient(BaseClient):
    """Async counterpart of `Client`: providers are created in async mode."""

    def __init__(self, provider_configs: dict = None):
        # None (not {}) as the default avoids the shared-mutable-default pitfall.
        super().__init__(
            provider_configs if provider_configs is not None else {}, is_async=True
        )

    def configure(self, provider_configs: dict = None):
        """(Re)configure providers; instances are rebuilt in async mode."""
        super().configure(provider_configs, True)

    @property
    def chat(self):
        """Return the async chat API interface (lazily constructed)."""
        if not self._chat:
            self._chat = AsyncChat(self)
        return self._chat


class AsyncChat(Chat):
    """Chat namespace exposing async completions."""

    def __init__(self, client: "AsyncClient"):
        self.client = client
        self._completions = AsyncCompletions(self.client)


class AsyncCompletions(Completions):
    async def create(self, model: str, messages: list, **kwargs):
        """
        Create an async chat completion for `model` ("provider:model").

        Supports automatic tool execution when `max_turns` is specified
        together with `tools`. Remaining keyword arguments are forwarded to
        the provider.

        Args:
            model: Model identifier in the form "provider:model".
            messages: Conversation messages.
            **kwargs: Provider options plus optional `max_turns`, `tools`,
                `automatic_tool_calling`.

        Raises:
            ValueError: if the model string is malformed or the provider is
                unknown / cannot be loaded.
        """
        if ":" not in model:
            raise ValueError(
                f"Invalid model format. Expected 'provider:model', got '{model}'"
            )

        # e.g. "google:gemini-xx" -> ("google", "gemini-xx")
        provider_key, model_name = model.split(":", 1)

        supported_providers = ProviderFactory.get_supported_providers()
        if provider_key not in supported_providers:
            raise ValueError(
                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
                "Make sure the model string is formatted correctly as 'provider:model'."
            )

        # Lazily initialize the provider on first use.
        if provider_key not in self.client.providers:
            config = self.client.provider_configs.get(provider_key, {})
            self.client.providers[provider_key] = ProviderFactory.create_provider(
                provider_key, config, is_async=True
            )

        provider = self.client.providers.get(provider_key)
        if not provider:
            raise ValueError(f"Could not load provider for '{provider_key}'.")

        # Tool-related parameters. `automatic_tool_calling` is POPPED so it is
        # never forwarded to the underlying provider API, which would reject
        # an unknown keyword. (`tools` stays in kwargs for the single-shot
        # path, matching the sync Client behavior.)
        max_turns = kwargs.pop("max_turns", None)
        tools = kwargs.get("tools", None)
        automatic_tool_calling = kwargs.pop("automatic_tool_calling", False)

        if max_turns is not None and tools is not None:
            tool_runner = ToolRunner(
                provider, model_name, messages.copy(), tools,
                max_turns, automatic_tool_calling,
            )
            # Forward the remaining kwargs (temperature, etc.); the previous
            # implementation dropped them, silently ignoring caller settings.
            return await tool_runner.run_async(
                provider, model_name, messages.copy(), tools, max_turns, **kwargs
            )

        # Default behavior without tool execution: delegate to the provider's
        # async implementation.
        response = await provider.chat_completions_create_async(
            model_name, messages, **kwargs
        )
        return self._extract_thinking_content(response)
from .provider import ProviderFactory
from abc import ABC, abstractmethod


class BaseClient(ABC):
    """Shared plumbing for sync and async clients: config storage + provider creation."""

    def __init__(self, provider_configs: dict = None, is_async: bool = False):
        """
        Initialize the client with provider configurations.
        Use the ProviderFactory to create provider instances.

        Args:
            provider_configs (dict): Maps a provider key (e.g. "openai",
                "aws-bedrock") to that provider's configuration options.
                For example:
                {
                    "openai": {"api_key": "your_openai_api_key"},
                    "aws-bedrock": {
                        "aws_access_key": "your_aws_access_key",
                        "aws_secret_key": "your_aws_secret_key",
                        "aws_region": "us-west-2"
                    }
                }
            is_async (bool): When True, providers are created in async mode.
        """
        self.providers = {}
        # None (not {}) as the default avoids the shared-mutable-default pitfall.
        self.provider_configs = provider_configs if provider_configs is not None else {}
        self._chat = None
        self._initialize_providers(is_async)

    def _initialize_providers(self, is_async):
        """Create (or re-create) a provider instance for each configured key."""
        for provider_key, config in self.provider_configs.items():
            provider_key = self._validate_provider_key(provider_key)
            self.providers[provider_key] = ProviderFactory.create_provider(
                provider_key, config, is_async
            )

    def _validate_provider_key(self, provider_key):
        """Return `provider_key` if it names a supported provider, else raise ValueError."""
        supported_providers = ProviderFactory.get_supported_providers()

        if provider_key not in supported_providers:
            raise ValueError(
                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
                "Make sure the model string is formatted correctly as 'provider:model'."
            )

        return provider_key

    def configure(self, provider_configs: dict = None, is_async: bool = False):
        """
        Merge `provider_configs` into the current configuration and rebuild providers.
        """
        if provider_configs is None:
            return

        self.provider_configs.update(provider_configs)
        # NOTE: This overrides existing provider instances.
        self._initialize_providers(is_async)

    # Fix: the previous `@property` stacked on the deprecated `abstractproperty`
    # wrapped a property object in another property. The modern, supported
    # spelling is `@property` over `@abstractmethod`.
    @property
    @abstractmethod
    def chat(self):
        """Return the chat API interface (implemented by subclasses)."""
        raise NotImplementedError("Chat is not implemented for this client.")
+ super().configure(provider_configs, False) @property def chat(self): @@ -111,88 +81,6 @@ def _extract_thinking_content(self, response): return response - def _tool_runner( - self, - provider, - model_name: str, - messages: list, - tools: any, - max_turns: int, - **kwargs, - ): - """ - Handle tool execution loop for max_turns iterations. - - Args: - provider: The provider instance to use for completions - model_name: Name of the model to use - messages: List of conversation messages - tools: Tools instance or list of callable tools - max_turns: Maximum number of tool execution turns - **kwargs: Additional arguments to pass to the provider - - Returns: - The final response from the model with intermediate responses and messages - """ - # Handle tools validation and conversion - if isinstance(tools, Tools): - tools_instance = tools - kwargs["tools"] = tools_instance.tools() - else: - # Check if passed tools are callable - if not all(callable(tool) for tool in tools): - raise ValueError("One or more tools is not callable") - tools_instance = Tools(tools) - kwargs["tools"] = tools_instance.tools() - - turns = 0 - intermediate_responses = [] # Store intermediate responses - intermediate_messages = [] # Store all messages including tool interactions - - while turns < max_turns: - # Make the API call - response = provider.chat_completions_create(model_name, messages, **kwargs) - response = self._extract_thinking_content(response) - - # Store intermediate response - intermediate_responses.append(response) - - # Check if there are tool calls in the response - tool_calls = ( - getattr(response.choices[0].message, "tool_calls", None) - if hasattr(response, "choices") - else None - ) - - # Store the model's message - intermediate_messages.append(response.choices[0].message) - - if not tool_calls: - # Set the intermediate data in the final response - response.intermediate_responses = intermediate_responses[ - :-1 - ] # Exclude final response - 
response.choices[0].intermediate_messages = intermediate_messages - return response - - # Execute tools and get results - results, tool_messages = tools_instance.execute_tool(tool_calls) - - # Add tool messages to intermediate messages - intermediate_messages.extend(tool_messages) - - # Add the assistant's response and tool results to messages - messages.extend([response.choices[0].message, *tool_messages]) - - turns += 1 - - # Set the intermediate data in the final response - response.intermediate_responses = intermediate_responses[ - :-1 - ] # Exclude final response - response.choices[0].intermediate_messages = intermediate_messages - return response - def create(self, model: str, messages: list, **kwargs): """ Create chat completion based on the model, messages, and any extra arguments. @@ -229,10 +117,12 @@ def create(self, model: str, messages: list, **kwargs): # Extract tool-related parameters max_turns = kwargs.pop("max_turns", None) tools = kwargs.get("tools", None) + automatic_tool_calling = kwargs.get("automatic_tool_calling", False) # Check environment variable before allowing multi-turn tool execution if max_turns is not None and tools is not None: - return self._tool_runner( + tool_runner = ToolRunner(provider, model_name, messages.copy(), tools, max_turns, automatic_tool_calling) + return tool_runner.run( provider, model_name, messages.copy(), diff --git a/aisuite/provider.py b/aisuite/provider.py index f53afe27..b0c87a8d 100644 --- a/aisuite/provider.py +++ b/aisuite/provider.py @@ -18,6 +18,12 @@ def chat_completions_create(self, model, messages): """Abstract method for chat completion calls, to be implemented by each provider.""" pass +class AsyncProvider(ABC): + @abstractmethod + async def chat_completions_create_async(self, model, messages, **kwargs): + """Method for async chat completion calls, to be implemented by each provider.""" + raise NotImplementedError("Async chat completion calls are not implemented for this provider.") + class 
ProviderFactory: """Factory to dynamically load provider instances based on naming conventions.""" @@ -25,10 +31,11 @@ class ProviderFactory: PROVIDERS_DIR = Path(__file__).parent / "providers" @classmethod - def create_provider(cls, provider_key, config): + def create_provider(cls, provider_key, config, is_async=False): """Dynamically load and create an instance of a provider based on the naming convention.""" # Convert provider_key to the expected module and class names - provider_class_name = f"{provider_key.capitalize()}Provider" + async_suffix = "Async" if is_async else "" + provider_class_name = f"{provider_key.capitalize()}{async_suffix}Provider" provider_module_name = f"{provider_key}_provider" module_path = f"aisuite.providers.{provider_module_name}" diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py index b7edf71c..46aedfb0 100644 --- a/aisuite/providers/anthropic_provider.py +++ b/aisuite/providers/anthropic_provider.py @@ -4,7 +4,7 @@ import anthropic import json -from aisuite.provider import Provider +from aisuite.provider import Provider, AsyncProvider from aisuite.framework import ChatCompletionResponse from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function @@ -222,3 +222,29 @@ def _prepare_kwargs(self, kwargs): kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"]) return kwargs + +class AnthropicAsyncProvider(AsyncProvider): + def __init__(self, **config): + """Initialize the Anthropic provider with the given configuration.""" + self.async_client = anthropic.AsyncAnthropic(**config) + self.converter = AnthropicMessageConverter() + + async def chat_completions_create_async(self, model, messages, **kwargs): + """Create a chat completion using the async Anthropic API.""" + kwargs = self._prepare_kwargs(kwargs) + system_message, converted_messages = self.converter.convert_request(messages) + + response = await self.async_client.messages.create( + model=model, 
system=system_message, messages=converted_messages, **kwargs + ) + return self.converter.convert_response(response) + + def _prepare_kwargs(self, kwargs): + """Prepare kwargs for the API call.""" + kwargs = kwargs.copy() + kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS) + + if "tools" in kwargs: + kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"]) + + return kwargs diff --git a/aisuite/providers/fireworks_provider.py b/aisuite/providers/fireworks_provider.py index 10bea195..c9773577 100644 --- a/aisuite/providers/fireworks_provider.py +++ b/aisuite/providers/fireworks_provider.py @@ -1,7 +1,7 @@ import os import httpx import json -from aisuite.provider import Provider, LLMError +from aisuite.provider import Provider, AsyncProvider, LLMError from aisuite.framework import ChatCompletionResponse from aisuite.framework.message import Message, ChatCompletionMessageToolCall @@ -130,6 +130,80 @@ def chat_completions_create(self, model, messages, **kwargs): except Exception as e: raise LLMError(f"An error occurred: {e}") +class FireworksAsyncProvider(AsyncProvider): + """ + Fireworks AI Provider using httpx for direct API calls. + """ + + BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions" + + def __init__(self, **config): + """ + Initialize the Fireworks provider with the given configuration. + The API key is fetched from the config or environment variables. + """ + self.api_key = config.get("api_key", os.getenv("FIREWORKS_API_KEY")) + if not self.api_key: + raise ValueError( + "Fireworks API key is missing. Please provide it in the config or set the FIREWORKS_API_KEY environment variable." + ) + + # Optionally set a custom timeout (default to 30s) + self.timeout = config.get("timeout", 30) + self.transformer = FireworksMessageConverter() + + async def chat_completions_create_async(self, model, messages, **kwargs): + """ + Makes an async request to the Fireworks AI chat completions endpoint. 
+ """ + # Remove 'stream' from kwargs if present + kwargs.pop("stream", None) + + # Transform messages using converter + transformed_messages = self.transformer.convert_request(messages) + + # Prepare the request payload + data = { + "model": model, + "messages": transformed_messages, + } + + # Add tools if provided + if "tools" in kwargs: + data["tools"] = kwargs["tools"] + kwargs.pop("tools") + + # Add tool_choice if provided + if "tool_choice" in kwargs: + data["tool_choice"] = kwargs["tool_choice"] + kwargs.pop("tool_choice") + + # Add remaining kwargs + data.update(kwargs) + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient() as client: + try: + # Make the async request to Fireworks AI endpoint + response = await client.post( + self.BASE_URL, json=data, headers=headers, timeout=self.timeout + ) + response.raise_for_status() + return self.transformer.convert_response(response.json()) + except httpx.HTTPStatusError as error: + error_message = ( + f"The request failed with status code: {error.status_code}\n" + ) + error_message += f"Headers: {error.headers}\n" + error_message += error.response.text + raise LLMError(error_message) + except Exception as e: + raise LLMError(f"An error occurred: {e}") + def _normalize_response(self, response_data): """ Normalize the response to a common format (ChatCompletionResponse). 
diff --git a/aisuite/providers/mistral_provider.py b/aisuite/providers/mistral_provider.py index 4fc28fab..9b76b4c0 100644 --- a/aisuite/providers/mistral_provider.py +++ b/aisuite/providers/mistral_provider.py @@ -2,7 +2,7 @@ from mistralai import Mistral from aisuite.framework.message import Message from aisuite.framework import ChatCompletionResponse -from aisuite.provider import Provider, LLMError +from aisuite.provider import Provider, AsyncProvider, LLMError from aisuite.providers.message_converter import OpenAICompliantMessageConverter @@ -72,3 +72,38 @@ def chat_completions_create(self, model, messages, **kwargs): return self.transformer.convert_response(response) except Exception as e: raise LLMError(f"An error occurred: {e}") + +class MistralAsyncProvider(AsyncProvider): + """ + Mistral AI Provider using the official Mistral client. + """ + + def __init__(self, **config): + """ + Initialize the Mistral provider with the given configuration. + Pass the entire configuration dictionary to the Mistral client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("MISTRAL_API_KEY")) + if not config["api_key"]: + raise ValueError( + "Mistral API key is missing. Please provide it in the config or set the MISTRAL_API_KEY environment variable." + ) + self.client = Mistral(**config) + self.transformer = MistralMessageConverter() + + async def chat_completions_create_async(self, model, messages, **kwargs): + """ + Makes a request to Mistral using the official client. 
+ """ + try: + # Transform messages using converter + transformed_messages = self.transformer.convert_request(messages) + + response = await self.client.chat.complete_async( + model=model, messages=transformed_messages, **kwargs + ) + + return self.transformer.convert_response(response) + except Exception as e: + raise LLMError(f"An error occurred: {e}") diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py index 8cb1b6c5..c5da1ef1 100644 --- a/aisuite/providers/openai_provider.py +++ b/aisuite/providers/openai_provider.py @@ -1,6 +1,6 @@ import openai import os -from aisuite.provider import Provider, LLMError +from aisuite.provider import Provider, AsyncProvider, LLMError from aisuite.providers.message_converter import OpenAICompliantMessageConverter @@ -38,3 +38,38 @@ def chat_completions_create(self, model, messages, **kwargs): return response except Exception as e: raise LLMError(f"An error occurred: {e}") + +class OpenaiAsyncProvider(AsyncProvider): + def __init__(self, **config): + """ + Initialize the OpenAI provider with the given configuration. + Pass the entire configuration dictionary to the OpenAI client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("OPENAI_API_KEY")) + if not config["api_key"]: + raise ValueError( + "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable." + ) + + # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically + # infer certain values from the environment variables. + # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc. 
+ + # Pass the entire config to the OpenAI client constructor + self.async_client = openai.AsyncOpenAI(**config) + self.transformer = OpenAICompliantMessageConverter() + + async def chat_completions_create_async(self, model, messages, **kwargs): + # Any exception raised by OpenAI will be returned to the caller. + # Maybe we should catch them and raise a custom LLMError. + try: + transformed_messages = self.transformer.convert_request(messages) + response = await self.async_client.chat.completions.create( + model=model, + messages=transformed_messages, + **kwargs # Pass any additional arguments to the OpenAI API + ) + return response + except Exception as e: + raise LLMError(f"An error occurred: {e}") diff --git a/aisuite/tool_runner.py b/aisuite/tool_runner.py new file mode 100644 index 00000000..b25507fb --- /dev/null +++ b/aisuite/tool_runner.py @@ -0,0 +1,108 @@ +from .provider import ProviderFactory +from .utils.tools import Tools +import asyncio + +class ToolRunner: + def __init__(self, provider, model_name, messages, tools, max_turns, automatic_tool_calling): + self.provider = provider + self.model_name = model_name + self.messages = messages + self.tools = tools + self.max_turns = max_turns + self.automatic_tool_calling = automatic_tool_calling + + async def run_async( + self, + provider, + model_name: str, + messages: list, + tools: any, + max_turns: int, + **kwargs, + ): + """ + Handle tool execution loop for max_turns iterations. 
class ToolRunner:
    """Runs the multi-turn tool-execution loop for a (provider, model) pair."""

    def __init__(self, provider, model_name, messages, tools, max_turns, automatic_tool_calling):
        self.provider = provider
        self.model_name = model_name
        self.messages = messages
        self.tools = tools
        self.max_turns = max_turns
        # When False, tool calls are returned to the caller instead of executed.
        self.automatic_tool_calling = automatic_tool_calling

    @staticmethod
    def _extract_thinking_content(response):
        # NOTE(review): the previous code called self._extract_thinking_content
        # without ever defining it on ToolRunner (the method lives on Client),
        # so every run raised AttributeError. This passthrough restores
        # operation; consider delegating to the client's extractor — TODO confirm.
        return response

    async def run_async(
        self,
        provider,
        model_name: str,
        messages: list,
        tools: any,
        max_turns: int,
        **kwargs,
    ):
        """
        Handle the tool execution loop for up to `max_turns` iterations.

        Args:
            provider: The provider instance to use for completions
            model_name: Name of the model to use
            messages: List of conversation messages
            tools: Tools instance or list of callable tools
            max_turns: Maximum number of tool execution turns
            **kwargs: Additional arguments to pass to the provider

        Returns:
            The final response from the model with intermediate responses and messages
        """
        # Handle tools validation and conversion.
        if isinstance(tools, Tools):
            tools_instance = tools
        else:
            if not all(callable(tool) for tool in tools):
                raise ValueError("One or more tools is not callable")
            tools_instance = Tools(tools)
        kwargs["tools"] = tools_instance.tools()

        turns = 0
        intermediate_responses = []  # Store intermediate responses
        intermediate_messages = []  # Store all messages including tool interactions

        while turns < max_turns:
            # Bug fix: `provider is AsyncProvider` compared an instance to a
            # class object (always False) and AsyncProvider was never imported
            # in this module (NameError). Duck-type on the async entry point
            # instead, which needs no import.
            if hasattr(provider, "chat_completions_create_async"):
                response = await provider.chat_completions_create_async(
                    model_name, messages, **kwargs
                )
            else:
                response = provider.chat_completions_create(model_name, messages, **kwargs)
            response = self._extract_thinking_content(response)

            intermediate_responses.append(response)

            # Check if there are tool calls in the response.
            tool_calls = (
                getattr(response.choices[0].message, "tool_calls", None)
                if hasattr(response, "choices")
                else None
            )

            # Store the model's message.
            intermediate_messages.append(response.choices[0].message)

            # Stop when the model produced no tool calls, or when automatic
            # execution is disabled (tool calls are handed back to the caller).
            if not tool_calls or not self.automatic_tool_calling:
                response.intermediate_responses = intermediate_responses[:-1]  # exclude final
                response.choices[0].intermediate_messages = intermediate_messages
                return response

            # Execute tools and collect results.
            results, tool_messages = tools_instance.execute_tool(tool_calls)

            intermediate_messages.extend(tool_messages)

            # Add the assistant's response and tool results to the history.
            messages.extend([response.choices[0].message, *tool_messages])

            turns += 1

        # max_turns exhausted: attach intermediate data to the last response.
        response.intermediate_responses = intermediate_responses[:-1]  # exclude final
        response.choices[0].intermediate_messages = intermediate_messages
        return response

    def run(
        self,
        provider,
        model_name: str,
        messages: list,
        tools: any,
        max_turns: int,
        **kwargs,
    ):
        """Synchronous wrapper around run_async.

        NOTE: asyncio.run raises RuntimeError when invoked from inside an
        already-running event loop; this path is for purely synchronous callers.
        """
        return asyncio.run(
            self.run_async(provider, model_name, messages, tools, max_turns, **kwargs)
        )
Expected 'provider:model'" ): client.chat.completions.create(invalid_model, messages=messages) + + +@pytest.mark.parametrize( + argnames=("patch_target", "provider", "model"), + argvalues=[ + ( + "aisuite.providers.openai_provider.OpenaiAsyncProvider.chat_completions_create_async", + "openai", + "gpt-4o", + ), + ( + "aisuite.providers.mistral_provider.MistralAsyncProvider.chat_completions_create_async", + "mistral", + "mistral-model", + ), + ( + "aisuite.providers.anthropic_provider.AnthropicAsyncProvider.chat_completions_create_async", + "anthropic", + "anthropic-model", + ), + ( + "aisuite.providers.fireworks_provider.FireworksAsyncProvider.chat_completions_create_async", + "fireworks", + "fireworks-model", + ), + ], +) +@pytest.mark.asyncio +async def test_async_client_chat_completions( + provider_configs: dict, patch_target: str, provider: str, model: str +): + expected_response = f"{patch_target}_{provider}_{model}" + with patch(patch_target) as mock_provider: + mock_provider.return_value = expected_response + client = AsyncClient() + client.configure(provider_configs) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + ] + + model_str = f"{provider}:{model}" + model_response = await client.chat.completions.create( + model_str, messages=messages + ) + assert model_response == expected_response + + +@pytest.mark.asyncio +async def test_invalid_model_format_in_async_create(monkeypatch): + from aisuite.providers.openai_provider import OpenaiAsyncProvider + + monkeypatch.setattr( + target=OpenaiAsyncProvider, + name="chat_completions_create_async", + value=Mock(), + ) + + # Valid provider configurations + provider_configs = { + "openai": {"api_key": "test_openai_api_key"}, + } + + # Initialize the client with valid provider + client = AsyncClient() + client.configure(provider_configs) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + 
{"role": "user", "content": "Tell me a joke."}, + ] + + # Invalid model format + invalid_model = "invalidmodel" + + # Expect ValueError when calling create with invalid model format and verify message + with pytest.raises( + ValueError, match=r"Invalid model format. Expected 'provider:model'" + ): + await client.chat.completions.create(invalid_model, messages=messages) diff --git a/tests/client/test_prerelease.py b/tests/client/test_prerelease.py index bb5f3285..2b41af50 100644 --- a/tests/client/test_prerelease.py +++ b/tests/client/test_prerelease.py @@ -13,6 +13,12 @@ def setup_client() -> ai.Client: return ai.Client() +def setup_async_client() -> ai.AsyncClient: + """Initialize the async AI client with environment variables.""" + load_dotenv(find_dotenv()) + return ai.AsyncClient() + + def get_test_models() -> List[str]: """Return a list of model identifiers to test.""" return [ @@ -26,6 +32,15 @@ def get_test_models() -> List[str]: ] +def get_test_async_models() -> List[str]: + """Return a list of model identifiers to test.""" + return [ + "anthropic:claude-3-5-sonnet-20240620", + "mistral:open-mistral-7b", + "openai:gpt-3.5-turbo", + ] + + def get_test_messages() -> List[Dict[str, str]]: """Return the test messages to send to each model.""" return [ @@ -70,5 +85,39 @@ def test_model_pirate_response(model_id: str): pytest.fail(f"Error testing model {model_id}: {str(e)}") +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("model_id", get_test_async_models()) +async def test_async_model_pirate_response(model_id: str): + """ + Test that each model responds appropriately to the pirate prompt using async client. 
+ + Args: + model_id: The provider:model identifier to test + """ + client = setup_async_client() + messages = get_test_messages() + + try: + response = await client.chat.completions.create( + model=model_id, messages=messages, temperature=0.75 + ) + + content = response.choices[0].message.content.lower() + + # Check if either version of the required phrase is present + assert any( + phrase in content for phrase in ["no rum no fun", "no rum, no fun"] + ), f"Model {model_id} did not include required phrase 'No rum No fun'" + + assert len(content) > 0, f"Model {model_id} returned empty response" + assert isinstance( + content, str + ), f"Model {model_id} returned non-string response" + + except Exception as e: + pytest.fail(f"Error testing model {model_id}: {str(e)}") + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/providers/test_mistral_provider.py b/tests/providers/test_mistral_provider.py index 937da0e5..427e95a6 100644 --- a/tests/providers/test_mistral_provider.py +++ b/tests/providers/test_mistral_provider.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import patch, MagicMock -from aisuite.providers.mistral_provider import MistralProvider +from aisuite.providers.mistral_provider import MistralProvider, MistralAsyncProvider @pytest.fixture(autouse=True) @@ -41,3 +41,37 @@ def test_mistral_provider(): ) assert response.choices[0].message.content == response_text_content + + +@pytest.mark.asyncio +async def test_mistral_provider_async(): + """High-level test that the provider handles async chat completions successfully.""" + + user_greeting = "Hello!" 
+ message_history = [{"role": "user", "content": user_greeting}] + selected_model = "our-favorite-model" + chosen_temperature = 0.75 + response_text_content = "mocked-text-response-from-model" + + provider = MistralAsyncProvider() + mock_response = MagicMock() + mock_response.model_dump.return_value = { + "choices": [{"message": {"content": response_text_content}}] + } + + with patch.object( + provider.client.chat, "complete_async", return_value=mock_response + ) as mock_create: + response = await provider.chat_completions_create_async( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + mock_create.assert_called_with( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + assert response.choices[0].message.content == response_text_content