
feat: filter model list by permissions #123


Open

wants to merge 1 commit into main
4 changes: 4 additions & 0 deletions deployment/BedrockProxy.template
@@ -139,6 +139,10 @@ Resources:
Properties:
PolicyDocument:
Statement:
- Action:
- iam:SimulatePrincipalPolicy
Effect: Allow
Resource: "*"
- Action:
- bedrock:ListFoundationModels
- bedrock:ListInferenceProfiles
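
The new statement authorizes the policy-simulation call the proxy now makes at runtime; the wildcard resource is needed because the resource of iam:SimulatePrincipalPolicy is the principal being simulated. A minimal boto3 sketch of that call, using hypothetical ARNs:

    import boto3

    iam = boto3.client("iam")
    resp = iam.simulate_principal_policy(
        PolicySourceArn="arn:aws:iam::123456789012:role/BedrockProxyRole",  # hypothetical role
        ActionNames=["bedrock:InvokeModel"],
        ResourceArns=["arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1"],
    )
    print(resp["EvaluationResults"][0]["EvalDecision"])  # "allowed" or "implicitDeny"
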
4 changes: 4 additions & 0 deletions deployment/BedrockProxyFargate.template
@@ -181,6 +181,10 @@ Resources:
Properties:
PolicyDocument:
Statement:
- Action:
- iam:SimulatePrincipalPolicy
Effect: Allow
Resource: "*"
- Action:
- bedrock:ListFoundationModels
- bedrock:ListInferenceProfiles
140 changes: 134 additions & 6 deletions src/api/models/bedrock.py
@@ -1,16 +1,18 @@
import base64
import json
import logging
import random
import re
import time
from abc import ABC
from typing import AsyncIterable, Iterable, Literal
from typing import AsyncIterable, Dict, Iterable, List, Literal

import boto3
import numpy as np
import requests
import tiktoken
from botocore.config import Config
from botocore.exceptions import ClientError
from fastapi import HTTPException
from starlette.concurrency import run_in_threadpool

@@ -53,6 +55,10 @@
config=config,
)

# Create IAM and STS clients for permission checking
iam_client = boto3.client('iam', region_name=AWS_REGION, config=config)
sts_client = boto3.client('sts', region_name=AWS_REGION, config=config)


def get_inference_region_prefix():
if AWS_REGION.startswith("ap-"):
@@ -74,15 +80,106 @@ def get_inference_region_prefix():
ENCODER = tiktoken.get_encoding("cl100k_base")


def check_model_permissions(principal_arn, model_arns):
"""Check permissions for multiple models in a single API call.

Args:
principal_arn: The ARN of the principal (user/role)
model_arns: List of model ARNs to check

Returns:
dict: Mapping of model ARNs to boolean permission status
"""
# Convert assumed role ARN to IAM role ARN if needed
if ":assumed-role/" in principal_arn:
parts = principal_arn.split(':')
account_id = parts[4]
role_session = parts[5].split('/')
role_name = role_session[1]
principal_arn = f"arn:aws:iam::{account_id}:role/{role_name}"

# Initialize results dictionary
permissions = {arn: False for arn in model_arns}

try:
logger.info(f"Checking permissions for {len(model_arns)} models")

# Both actions needed for streaming models
actions = ["bedrock:InvokeModel", "bedrock:InvokeModelWithResponseStream"]

# Implement retry with exponential backoff
max_retries = 5
retry_delay = 1

for attempt in range(max_retries):
try:
response = iam_client.simulate_principal_policy(
PolicySourceArn=principal_arn,
ActionNames=actions,
ResourceArns=model_arns
)
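# The simulator's response is shaped roughly like (abridged):
#   {"EvaluationResults": [{"EvalActionName": "bedrock:InvokeModel",
#     "ResourceSpecificResults": [{"EvalResourceName": "arn:aws:bedrock:...",
#       "EvalResourceDecision": "allowed"}]}]}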

# Process resource-specific results directly
for result in response['EvaluationResults']:
for resource_result in result['ResourceSpecificResults']:
resource = resource_result['EvalResourceName']
if resource in permissions and resource_result['EvalResourceDecision'] == 'allowed':
permissions[resource] = True

# Successfully got and processed the response
break

except ClientError as e:
error_code = e.response['Error']['Code']

if error_code == 'Throttling' and attempt < max_retries - 1:
# Exponential backoff with jitter
sleep_time = retry_delay * (2 ** attempt) + random.uniform(0, 1)
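# e.g. with retry_delay=1: attempt 0 waits ~1-2 s, attempt 1 ~2-3 s, attempt 2 ~4-5 s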
logger.warning(f"Throttling detected, retrying in {sleep_time:.2f} seconds...")
time.sleep(sleep_time)
else:
# Either not a throttling error or we've exhausted retries
raise

return permissions

except Exception as e:
logger.error(f"Error checking model permissions: {str(e)}", exc_info=True)
# If we can't check permissions, default to not allowing any
return permissions
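
# A minimal usage sketch (hypothetical ARNs, for illustration only):
#   check_model_permissions(
#       "arn:aws:iam::123456789012:role/BedrockProxyRole",
#       ["arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1"],
#   )
#   -> {"arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1": True}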


def list_bedrock_models() -> dict:
"""Automatically getting a list of supported models.
"""Automatically getting a list of supported models that the user has permission to invoke.

Returns a model list that combines:
- - ON_DEMAND models.
+ - ON_DEMAND models
- Cross-Region Inference Profiles (if enabled via Env)
"""
model_list = {}
try:
# Get current role/user ARN
try:
caller_identity = sts_client.get_caller_identity()
principal_arn = caller_identity['Arn']
account_id = caller_identity['Account']
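# e.g. (hypothetical values): Arn is "arn:aws:sts::123456789012:assumed-role/BedrockProxyRole/task",
# Account is "123456789012"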
except Exception as e:
logger.error(f"Unable to get caller identity: {str(e)}")
# If we can't get the caller identity, return default model
model_list[DEFAULT_MODEL] = {"modalities": ["TEXT", "IMAGE"]}
return model_list

# First, collect all potential models
potential_models = {}
model_arns = []

# Get inference profiles if enabled
profile_list = []
if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
@@ -105,19 +202,50 @@ def list_bedrock_models() -> dict:
input_modalities = model["inputModalities"]
# Add on-demand model list
if "ON_DEMAND" in inference_types:
- model_list[model_id] = {"modalities": input_modalities}
+ model_arn = f"arn:aws:bedrock:{AWS_REGION}::foundation-model/{model_id}"
potential_models[model_arn] = {
"id": model_id,
"modalities": input_modalities
}
model_arns.append(model_arn)

# Add cross-region inference model list.
profile_id = cr_inference_prefix + "." + model_id
if profile_id in profile_list:
- model_list[profile_id] = {"modalities": input_modalities}
+ profile_arn = f"arn:aws:bedrock:{AWS_REGION}:{account_id}:inference-profile/{profile_id}"
potential_models[profile_arn] = {
"id": profile_id,
"modalities": input_modalities
}
model_arns.append(profile_arn)
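# Note the two ARN shapes (hypothetical region/account): foundation-model ARNs have an
# empty account field (arn:aws:bedrock:us-east-1::foundation-model/<model_id>), while
# inference profiles are account-scoped
# (arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.<model_id>)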

# Check permissions for all models in batches
if model_arns:
# Split into batches if needed (API might have limits on number of resources)
batch_size = 20 # Adjust based on API limits
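# e.g. 45 unique ARNs with batch_size=20 -> three simulator calls (20, 20, and 5 ARNs)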
permissions = {}

deduped = list(set(model_arns))

for i in range(0, len(deduped), batch_size):
batch = deduped[i:i+batch_size]
batch_permissions = check_model_permissions(principal_arn, batch)
permissions.update(batch_permissions)

# Build final model list based on permissions
for model_arn, has_permission in permissions.items():
if has_permission and model_arn in potential_models:
model_info = potential_models[model_arn]
model_list[model_info["id"]] = {"modalities": model_info["modalities"]}
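# e.g. a surviving entry (hypothetical):
#   {"anthropic.claude-3-sonnet-20240229-v1:0": {"modalities": ["TEXT", "IMAGE"]}}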

except Exception as e:
logger.error(f"Unable to list models: {str(e)}")

if not model_list:
- # In case stack not updated.
+ # In case stack not updated or permissions check failed
model_list[DEFAULT_MODEL] = {"modalities": ["TEXT", "IMAGE"]}

logger.info("final model_list:" + repr(model_list))

return model_list
