Skip to content

Commit a0982a4

Browse files
author
Sean Smith
committed
Support imported models
Signed-off-by: Sean Smith <[email protected]>
1 parent d025efe commit a0982a4

File tree

4 files changed

+79
-15
lines changed

4 files changed

+79
-15
lines changed

deployment/BedrockProxy.template renamed to deployment/BedrockProxy.yaml

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,19 @@ Parameters:
88
Type: String
99
Default: anthropic.claude-3-sonnet-20240229-v1:0
1010
Description: The default model ID, please make sure the model ID is supported in the current region
11+
ImageUri:
12+
Type: String
13+
Default: ""
14+
Description: Specify a custom ECR image URI. If left blank, defaults to 366590864501.dkr.ecr.<region>.amazonaws.com/bedrock-proxy-api:latest for the deployment region.
15+
EnableImportedModels:
16+
Type: String
17+
Default: false
18+
AllowedValues:
19+
- true
20+
- false
21+
Description: If enabled, models imported into Bedrock will be available to use.
22+
Conditions:
23+
UseDefaultImage: !Equals [!Ref ImageUri, ""]
1124
Resources:
1225
VPCB9E5F0B4:
1326
Type: AWS::EC2::VPC
@@ -142,6 +155,7 @@ Resources:
142155
- Action:
143156
- bedrock:ListFoundationModels
144157
- bedrock:ListInferenceProfiles
158+
- bedrock:ListImportedModels
145159
Effect: Allow
146160
Resource: "*"
147161
- Action:
@@ -151,6 +165,7 @@ Resources:
151165
Resource:
152166
- arn:aws:bedrock:*::foundation-model/*
153167
- arn:aws:bedrock:*:*:inference-profile/*
168+
- arn:aws:bedrock:*:*:imported-model/*
154169
- Action:
155170
- secretsmanager:GetSecretValue
156171
- secretsmanager:DescribeSecret
@@ -167,14 +182,16 @@ Resources:
167182
Architectures:
168183
- arm64
169184
Code:
170-
ImageUri:
171-
Fn::Join:
185+
ImageUri: !If
186+
- UseDefaultImage
187+
- !Join
172188
- ""
173-
- - 366590864501.dkr.ecr.
174-
- Ref: AWS::Region
189+
- - "366590864501.dkr.ecr."
190+
- !Ref AWS::Region
175191
- "."
176-
- Ref: AWS::URLSuffix
177-
- /bedrock-proxy-api:latest
192+
- !Ref AWS::URLSuffix
193+
- "/bedrock-proxy-api:latest"
194+
- !Ref ImageUri
178195
Description: Bedrock Proxy API Handler
179196
Environment:
180197
Variables:
@@ -185,6 +202,7 @@ Resources:
185202
Ref: DefaultModelId
186203
DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
187204
ENABLE_CROSS_REGION_INFERENCE: "true"
205+
ENABLE_IMPORTED_MODELS: !Ref EnableImportedModels
188206
MemorySize: 1024
189207
PackageType: Image
190208
Role:

deployment/BedrockProxyFargate.template renamed to deployment/BedrockProxyFargate.yaml

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,19 @@ Parameters:
88
Type: String
99
Default: anthropic.claude-3-sonnet-20240229-v1:0
1010
Description: The default model ID, please make sure the model ID is supported in the current region
11+
ImageUri:
12+
Type: String
13+
Default: ""
14+
Description: Specify a custom ECR image URI. If left blank, defaults to 366590864501.dkr.ecr.<region>.amazonaws.com/bedrock-proxy-api:latest for the deployment region.
15+
EnableImportedModels:
16+
Type: String
17+
Default: false
18+
AllowedValues:
19+
- true
20+
- false
21+
Description: If enabled, models imported into Bedrock will be available to use.
22+
Conditions:
23+
UseDefaultImage: !Equals [!Ref ImageUri, ""]
1124
Resources:
1225
VPCB9E5F0B4:
1326
Type: AWS::EC2::VPC
@@ -184,6 +197,7 @@ Resources:
184197
- Action:
185198
- bedrock:ListFoundationModels
186199
- bedrock:ListInferenceProfiles
200+
- bedrock:ListImportedModels
187201
Effect: Allow
188202
Resource: "*"
189203
- Action:
@@ -193,6 +207,7 @@ Resources:
193207
Resource:
194208
- arn:aws:bedrock:*::foundation-model/*
195209
- arn:aws:bedrock:*:*:inference-profile/*
210+
- arn:aws:bedrock:*:*:imported-model/*
196211
Version: "2012-10-17"
197212
PolicyName: ProxyTaskRoleDefaultPolicy933321B8
198213
Roles:
@@ -222,15 +237,19 @@ Resources:
222237
Value: cohere.embed-multilingual-v3
223238
- Name: ENABLE_CROSS_REGION_INFERENCE
224239
Value: "true"
240+
- Name: ENABLE_IMPORTED_MODELS
241+
Value: !Ref EnableImportedModels
225242
Essential: true
226-
Image:
227-
Fn::Join:
228-
- ""
229-
- - 366590864501.dkr.ecr.
230-
- Ref: AWS::Region
231-
- "."
232-
- Ref: AWS::URLSuffix
233-
- /bedrock-proxy-api-ecs:latest
243+
Image: !If
244+
- UseDefaultImage
245+
- !Join
246+
- ""
247+
- - "366590864501.dkr.ecr."
248+
- !Ref AWS::Region
249+
- "."
250+
- !Ref AWS::URLSuffix
251+
- "/bedrock-proxy-api:latest"
252+
- !Ref ImageUri
234253
Name: proxy-api
235254
PortMappings:
236255
- ContainerPort: 80

src/api/models/bedrock.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
Embedding,
3939

4040
)
41-
from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL
41+
from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL, ENABLE_IMPORTED_MODELS
4242

4343
logger = logging.getLogger(__name__)
4444

@@ -99,6 +99,18 @@ def list_bedrock_models() -> dict:
9999
byOutputModality='TEXT'
100100
)
101101

102+
# Add imported models to the list if ENABLE_IMPORTED_MODELS is true
103+
if ENABLE_IMPORTED_MODELS:
104+
response_imported = bedrock_client.list_imported_models()
105+
print(response_imported)
106+
107+
# Add imported models to the default model list
108+
for model in response_imported['modelSummaries']:
109+
model_id = model.get('modelName')
110+
model_list[f"custom.{model_id}"] = {
111+
'modalities': ["TEXT"]
112+
}
113+
102114
for model in response['modelSummaries']:
103115
model_id = model.get('modelId', 'N/A')
104116
stream_supported = model.get('responseStreamingSupported', True)
@@ -170,6 +182,20 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
170182
if DEBUG:
171183
logger.info("Bedrock request: " + json.dumps(str(args)))
172184

185+
if args["modelId"].startswith("custom."):
186+
# For custom models, resolve the model ARN by listing imported models and matching on the model name
187+
model_name = args["modelId"].replace("custom.", "")
188+
response = bedrock_client.list_imported_models()
189+
for model in response["modelSummaries"]:
190+
if model["modelName"] == model_name:
191+
args["modelId"] = model["modelArn"]
192+
break
193+
else:
194+
raise HTTPException(
195+
status_code=404,
196+
detail=f"Custom model {model_name} not found"
197+
)
198+
173199
try:
174200
if stream:
175201
response = bedrock_runtime.converse_stream(**args)

src/api/setting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
"DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
2121
)
2222
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
23+
ENABLE_IMPORTED_MODELS = os.environ.get("ENABLE_IMPORTED_MODELS", "true").lower() != "false"

0 commit comments

Comments
 (0)