@@ -14,6 +14,8 @@ class Models(str, Enum):
1414 AI21_JURASSIC_ULTRA = "ai21.j2-ultra-v1"
1515 AI21_JURASSIC_MID = "ai21.j2-mid-v1"
1616 COHERE_COMMAND = "cohere.command-text-v14"
17+ ANTHROPIC_CLAUDE = "anthropic.claude-v2"
18+ ANTHROPIC_CLAUDE_INSTANT = "anthropic.claude-instant-v1"
1719
1820
1921TITAN_PARAMS = ["maxTokenCount" , "stopSequences" , "temperature" , "topP" ]
@@ -26,6 +28,13 @@ class Models(str, Enum):
2628 "frequencyPenalty" ,
2729]
2830COHERE_PARAMS = ["max_tokens" , "temperature" ]
31+ ANTHROPIC_PARAMS = [
32+ "max_tokens_to_sample" ,
33+ "temperature" ,
34+ "top_k" ,
35+ "top_p" ,
36+ "stop_sequences" ,
37+ ]
2938
3039
3140class Bedrock (REST ):
@@ -49,6 +58,8 @@ def __init__(
4958 config_params = AI21_JURASSIC_PARAMS
5059 if self ._model_id in [Models .COHERE_COMMAND ]:
5160 config_params = COHERE_PARAMS
61+ if self ._model_id in [Models .ANTHROPIC_CLAUDE_INSTANT , Models .ANTHROPIC_CLAUDE ]:
62+ config_params = ANTHROPIC_PARAMS
5263
5364 for i in config_params :
5465 self ._config [i ] = config [i ]
@@ -149,6 +160,11 @@ def _request(json_data: str) -> str:
149160 responses = json .loads (r ["body" ].read ().decode ())["generations" ][0 ][
150161 "text"
151162 ]
163+ elif self ._model_id in [
164+ Models .ANTHROPIC_CLAUDE_INSTANT ,
165+ Models .ANTHROPIC_CLAUDE ,
166+ ]:
167+ responses = json .loads (r ["body" ].read ().decode ())["completion" ]
152168
153169 return responses
154170
@@ -166,7 +182,15 @@ def _request(json_data: str) -> str:
166182 responses = _request (json .dumps ({"prompt" : prompt , ** self ._config }))
167183 elif self ._model_id in [Models .COHERE_COMMAND ]:
168184 responses = _request (json .dumps ({"prompt" : prompt , ** self ._config }))
169-
185+ elif self ._model_id in [
186+ Models .ANTHROPIC_CLAUDE_INSTANT ,
187+ Models .ANTHROPIC_CLAUDE ,
188+ ]:
189+ responses = _request (
190+ json .dumps (
191+ {"prompt" : f"\n \n Human: { prompt } \n \n Assistant:" , ** self ._config }
192+ )
193+ )
170194 api_responses .append (responses )
171195
172196 return api_responses
@@ -195,4 +219,6 @@ def get_model_names(self) -> Tuple[str, ...]:
195219 "ai21.j2-ultra-v1" ,
196220 "ai21.j2-mid-v1" ,
197221 "cohere.command-text-v14" ,
222+ "anthropic.claude-v2" ,
223+ "anthropic.claude-instant-v1" ,
198224 )
0 commit comments