# Error codes that indicate the request was throttled and may be retried.
# IAM reports "Throttling"; "ThrottlingException" is accepted as well since
# other AWS services use that spelling.
_THROTTLING_ERROR_CODES = frozenset({"Throttling", "ThrottlingException"})

# Actions simulated for every model ARN.
_INVOKE_ACTIONS = ("bedrock:InvokeModel", "bedrock:InvokeModelWithResponseStream")


def _to_iam_principal_arn(principal_arn):
    """Return an ARN usable as ``PolicySourceArn`` for the given principal.

    An STS assumed-role ARN
    (``arn:<partition>:sts::<account>:assumed-role/<role>/<session>``) is
    rewritten to the IAM role ARN (``arn:<partition>:iam::<account>:role/<role>``).
    Any other ARN is returned unchanged.

    The partition of the input ARN is preserved so the result is also valid
    outside the standard ``aws`` partition.
    """
    if ":assumed-role/" not in principal_arn:
        return principal_arn

    # arn : partition : service : region : account : resource
    parts = principal_arn.split(":", 5)
    partition = parts[1]
    account_id = parts[4]
    role_name = parts[5].split("/")[1]
    # NOTE(review): an assumed-role ARN does not carry the IAM path of the
    # role, so a role created under a non-default path cannot be rebuilt
    # exactly from it - confirm whether such roles need to be supported.
    return f"arn:{partition}:iam::{account_id}:role/{role_name}"


def _simulate_with_backoff(request, max_retries, retry_delay):
    """Call ``simulate_principal_policy`` and retry when it is throttled.

    Sleeps ``retry_delay * 2**attempt`` seconds plus up to one second of
    random jitter between attempts. The last ``ClientError`` is re-raised when
    the attempts are exhausted or when the error is not a throttling error.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return iam_client.simulate_principal_policy(**request)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            is_last_attempt = attempt == attempts - 1
            if is_last_attempt or error_code not in _THROTTLING_ERROR_CODES:
                raise
            sleep_time = retry_delay * (2 ** attempt) + random.uniform(0, 1)
            logger.warning(
                "Throttling detected, retrying in %.2f seconds...", sleep_time
            )
            time.sleep(sleep_time)


def check_model_permissions(principal_arn, model_arns, *, max_retries=5, retry_delay=1):
    """Check permissions for multiple models in a single API call.

    The IAM policy simulator is asked whether the principal may perform
    ``bedrock:InvokeModel`` / ``bedrock:InvokeModelWithResponseStream`` on
    each model ARN. A model is reported as permitted as soon as the simulator
    returns ``allowed`` for it for at least one of those actions.

    NOTE(review): an earlier comment said both actions are needed for
    streaming models, but the check has always accepted either one. That
    behaviour is kept here - confirm which rule is intended.

    The check fails closed: if the simulation cannot be completed, the error
    is logged and every model not already confirmed stays ``False``.

    Args:
        principal_arn: The ARN of the principal (user/role). An STS
            assumed-role ARN is converted to the underlying IAM role ARN.
        model_arns: List of model ARNs to check.
        max_retries: Maximum number of attempts per API call when throttled.
        retry_delay: Base delay in seconds for the exponential backoff.

    Returns:
        dict: Mapping of model ARNs to boolean permission status. Empty when
        ``model_arns`` is empty.
    """
    model_arns = list(model_arns or [])
    if not model_arns:
        # Nothing to check; avoid a pointless (and rate-limited) API call.
        return {}

    principal_arn = _to_iam_principal_arn(principal_arn)

    # Default every model to "not permitted" so failures fail closed.
    permissions = dict.fromkeys(model_arns, False)

    request = {
        "PolicySourceArn": principal_arn,
        "ActionNames": list(_INVOKE_ACTIONS),
        "ResourceArns": model_arns,
    }

    try:
        logger.info("Checking permissions for %d models", len(model_arns))

        while True:
            response = _simulate_with_backoff(request, max_retries, retry_delay)

            for result in response.get("EvaluationResults", []):
                # Tolerate entries without per-resource details instead of
                # aborting the whole batch on a missing key.
                for resource_result in result.get("ResourceSpecificResults", []):
                    resource = resource_result.get("EvalResourceName")
                    decision = resource_result.get("EvalResourceDecision")
                    if resource in permissions and decision == "allowed":
                        permissions[resource] = True

            # Follow pagination so that no evaluation result is dropped.
            if not response.get("IsTruncated"):
                break
            request["Marker"] = response["Marker"]

    except Exception as e:
        # Best effort by design: callers treat "unknown" as "not permitted".
        logger.error("Error checking model permissions: %s", e, exc_info=True)

    return permissions
+ model_id if profile_id in profile_list: - model_list[profile_id] = {"modalities": input_modalities} + profile_arn = f"arn:aws:bedrock:{AWS_REGION}:{account_id}:inference-profile/{profile_id}" + potential_models[profile_arn] = { + "id": profile_id, + "modalities": input_modalities + } + model_arns.append(profile_arn) + + # Check permissions for all models in batches + if model_arns: + # Split into batches if needed (API might have limits on number of resources) + batch_size = 20 # Adjust based on API limits + permissions = {} + + deduped = list(set(model_arns)) + + for i in range(0, len(deduped), batch_size): + batch = deduped[i:i+batch_size] + batch_permissions = check_model_permissions(principal_arn, batch) + permissions.update(batch_permissions) + + # Build final model list based on permissions + for model_arn, has_permission in permissions.items(): + if has_permission and model_arn in potential_models: + model_info = potential_models[model_arn] + model_list[model_info["id"]] = {"modalities": model_info["modalities"]} except Exception as e: logger.error(f"Unable to list models: {str(e)}") if not model_list: - # In case stack not updated. + # In case stack not updated or permissions check failed model_list[DEFAULT_MODEL] = {"modalities": ["TEXT", "IMAGE"]} + + logger.info("final model_list:" + repr(model_list)) return model_list