Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions app/services/provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,24 @@ async def cache_models(
def _get_adapters(self) -> dict[str, ProviderAdapter]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any usage of this function, so I think we could remove it.

Also, in the init function, we do `self.adapters = self._get_adapters()`. That assignment isn't needed either, since the member variable `adapters` is never used.

"""Get adapters from cache or create new ones"""
if not ProviderService._adapters_cache:
ProviderService._adapters_cache = ProviderAdapterFactory.get_all_adapters()
ProviderService._adapters_cache = {}
return ProviderService._adapters_cache

def _get_or_create_adapter(self, provider_name: str, base_url: str | None = None, config: dict[str, Any] | None = None) -> ProviderAdapter:
"""Get an adapter instance from cache or create a new one"""
# Create a cache key that includes provider name, base_url, and config hash
config_hash = hash(frozenset((config or {}).items()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Write a dedicated function in the cache file to create/get/invalidate this cache.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait, it's not an actual "cache". Never mind, this should be good.

cache_key = f"{provider_name}:{base_url or 'default'}:{config_hash}"

if cache_key not in self._adapters_cache:
adapter = ProviderAdapterFactory.get_adapter(provider_name, base_url, config)
self._adapters_cache[cache_key] = adapter
logger.debug(f"Created new adapter instance for {cache_key}")
else:
logger.debug(f"Using cached adapter instance for {cache_key}")

return self._adapters_cache[cache_key]

async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
"""Load all provider keys for the user synchronously, with lazy loading and caching."""
if self._keys_loaded:
Expand Down Expand Up @@ -437,7 +452,7 @@ async def _list_models_helper(
base_url = self.provider_keys[provider_name]["base_url"]
tasks.append(
_list_models_helper(
ProviderAdapterFactory.get_adapter(provider_name, base_url, config),
self._get_or_create_adapter(provider_name, base_url, config),
api_key,
provider_data,
)
Expand Down Expand Up @@ -503,8 +518,8 @@ async def process_request(
serialized_api_key_config
)

# Get the appropriate adapter
adapter = ProviderAdapterFactory.get_adapter(provider_name, base_url, config)
# Get the appropriate adapter (cached)
adapter = self._get_or_create_adapter(provider_name, base_url, config)

# Process the request through the adapter
usage_tracker_id = None
Expand Down
Loading