diff --git a/src/api/auth.py b/src/api/auth.py index 1a64653d..9c274e7d 100644 --- a/src/api/auth.py +++ b/src/api/auth.py @@ -2,11 +2,11 @@ import os from typing import Annotated -import boto3 from botocore.exceptions import ClientError from fastapi import Depends, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from api.models.bedrock import client_manager from api.setting import DEFAULT_API_KEYS api_key_param = os.environ.get("API_KEY_PARAM_NAME") @@ -15,10 +15,10 @@ if api_key_param: # For backward compatibility. # Please now use secrets manager instead. - ssm = boto3.client("ssm") + ssm = client_manager.get_client("ssm") api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"]["Value"] elif api_key_secret_arn: - sm = boto3.client("secretsmanager") + sm = client_manager.get_client("secretsmanager") try: response = sm.get_secret_value(SecretId=api_key_secret_arn) if "SecretString" in response: @@ -41,4 +41,4 @@ def api_key_auth( credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)], ): if credentials.credentials != api_key: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key") + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key") \ No newline at end of file diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 270a6bb0..902eca07 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -3,8 +3,9 @@ import logging import re import time +import threading from abc import ABC -from typing import AsyncIterable, Iterable, Literal +from typing import AsyncIterable, Dict, Iterable, Literal, Optional import boto3 import numpy as np @@ -42,18 +43,64 @@ logger = logging.getLogger(__name__) -config = Config(connect_timeout=60, read_timeout=120, retries={"max_attempts": 1}) -bedrock_runtime = boto3.client( - service_name="bedrock-runtime", - region_name=AWS_REGION, - config=config, -) -bedrock_client = boto3.client( - service_name="bedrock", - region_name=AWS_REGION, - config=config, -) +class BedrockClientManager: + """ + Singleton class to manage AWS Bedrock client connections with connection pooling. + """ + _instance = None + _lock = threading.RLock() + _clients: Dict[str, any] = {} + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super(BedrockClientManager, cls).__new__(cls) + return cls._instance + + def get_client(self, service_name: str, region_name: Optional[str] = None) -> any: + """ + Get or create a boto3 client with connection pooling. + + Args: + service_name: The AWS service name (e.g., "bedrock-runtime", "bedrock") + region_name: AWS region. If not specified, uses the default region. + + Returns: + A boto3 client instance + """ + region = region_name or AWS_REGION + key = f"{service_name}:{region}" + + with self._lock: + if key not in self._clients: + logger.debug(f"Creating new boto3 client for {service_name} in {region}") + + config = Config( + connect_timeout=60, + read_timeout=120, + retries={"max_attempts": 2, "mode": "adaptive"}, + max_pool_connections=50, # Improve connection reuse + tcp_keepalive=True # Keep connections alive + ) + + client = boto3.client( + service_name=service_name, + region_name=region, + config=config + ) + + self._clients[key] = client + + return self._clients[key] + + +# Create the client manager instance +client_manager = BedrockClientManager() + +# Get clients with connection pooling +bedrock_runtime = client_manager.get_client("bedrock-runtime") +bedrock_client = client_manager.get_client("bedrock") def get_inference_region_prefix(): @@ -865,4 +912,4 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: raise HTTPException( status_code=400, detail="Unsupported embedding model id " + model_id, - ) + ) \ No newline at end of file