diff --git a/aisuite/client.py b/aisuite/client.py index ca8b326c..d96bd3dd 100644 --- a/aisuite/client.py +++ b/aisuite/client.py @@ -1,8 +1,7 @@ -from .provider import ProviderFactory +from .provider import ProviderFactory, Provider import os from .utils.tools import Tools - class Client: def __init__(self, provider_configs: dict = {}): """ @@ -222,7 +221,7 @@ def create(self, model: str, messages: list, **kwargs): provider_key, config ) - provider = self.client.providers.get(provider_key) + provider: Provider = self.client.providers.get(provider_key) if not provider: raise ValueError(f"Could not load provider for '{provider_key}'.")