
Commit c5ec2ff

Add Anthropic in Bedrock

1 parent fe38a95 · commit c5ec2ff

2 files changed: +33 −1 lines changed

spacy_llm/models/rest/bedrock/model.py

Lines changed: 27 additions & 1 deletion
@@ -14,6 +14,8 @@ class Models(str, Enum):
     AI21_JURASSIC_ULTRA = "ai21.j2-ultra-v1"
     AI21_JURASSIC_MID = "ai21.j2-mid-v1"
     COHERE_COMMAND = "cohere.command-text-v14"
+    ANTHROPIC_CLAUDE = "anthropic.claude-v2"
+    ANTHROPIC_CLAUDE_INSTANT = "anthropic.claude-instant-v1"
 
 
 TITAN_PARAMS = ["maxTokenCount", "stopSequences", "temperature", "topP"]
@@ -26,6 +28,13 @@ class Models(str, Enum):
     "frequencyPenalty",
 ]
 COHERE_PARAMS = ["max_tokens", "temperature"]
+ANTHROPIC_PARAMS = [
+    "max_tokens_to_sample",
+    "temperature",
+    "top_k",
+    "top_p",
+    "stop_sequences",
+]
 
 
 class Bedrock(REST):
@@ -49,6 +58,8 @@ def __init__(
             config_params = AI21_JURASSIC_PARAMS
         if self._model_id in [Models.COHERE_COMMAND]:
             config_params = COHERE_PARAMS
+        if self._model_id in [Models.ANTHROPIC_CLAUDE_INSTANT, Models.ANTHROPIC_CLAUDE]:
+            config_params = ANTHROPIC_PARAMS
 
         for i in config_params:
             self._config[i] = config[i]
@@ -149,6 +160,11 @@ def _request(json_data: str) -> str:
                 responses = json.loads(r["body"].read().decode())["generations"][0][
                     "text"
                 ]
+            elif self._model_id in [
+                Models.ANTHROPIC_CLAUDE_INSTANT,
+                Models.ANTHROPIC_CLAUDE,
+            ]:
+                responses = json.loads(r["body"].read().decode())["completion"]
 
             return responses

@@ -166,7 +182,15 @@ def _request(json_data: str) -> str:
                 responses = _request(json.dumps({"prompt": prompt, **self._config}))
             elif self._model_id in [Models.COHERE_COMMAND]:
                 responses = _request(json.dumps({"prompt": prompt, **self._config}))
-
+            elif self._model_id in [
+                Models.ANTHROPIC_CLAUDE_INSTANT,
+                Models.ANTHROPIC_CLAUDE,
+            ]:
+                responses = _request(
+                    json.dumps(
+                        {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **self._config}
+                    )
+                )
             api_responses.append(responses)
 
         return api_responses
@@ -195,4 +219,6 @@ def get_model_names(self) -> Tuple[str, ...]:
             "ai21.j2-ultra-v1",
             "ai21.j2-mid-v1",
             "cohere.command-text-v14",
+            "anthropic.claude-v2",
+            "anthropic.claude-instant-v1",
         )
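
For context, the new Anthropic branch formats the prompt in Claude's "Human:/Assistant:" style and sends it together with the parameters listed in ANTHROPIC_PARAMS. Below is a minimal standalone sketch of that request body, using only the prompt formatting and parameter names visible in the diff; the prompt text and parameter values are illustrative, not part of the change.

import json

# Illustrative prompt and config; the keys mirror ANTHROPIC_PARAMS above and
# the defaults added in registry.py below.
prompt = "Summarize this sentence in five words."
config = {
    "max_tokens_to_sample": 300,
    "temperature": 0.0,
    "top_k": 250,
    "top_p": 1,
    "stop_sequences": [],
}

# Mirrors the json.dumps(...) call added in __call__ above.
json_data = json.dumps(
    {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **config}
)
print(json_data)

# The Bedrock response body is then decoded as JSON and the generated text is
# read from its "completion" field, per the change in _request above.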

spacy_llm/models/rest/bedrock/registry.py

Lines changed: 6 additions & 0 deletions
@@ -9,10 +9,12 @@
 _DEFAULT_TEMPERATURE: float = 0.0
 _DEFAULT_MAX_TOKEN_COUNT: int = 512
 _DEFAULT_TOP_P: int = 1
+_DEFAULT_TOP_K: int = 250
 _DEFAULT_STOP_SEQUENCES: List[str] = []
 _DEFAULT_COUNT_PENALTY: Dict[str, Any] = {"scale": 0}
 _DEFAULT_PRESENCE_PENALTY: Dict[str, Any] = {"scale": 0}
 _DEFAULT_FREQUENCY_PENALTY: Dict[str, Any] = {"scale": 0}
+_DEFAULT_MAX_TOKEN_TO_SAMPLE: int = 300
 
 
 @registry.llm_models("spacy.Bedrock.v1")
@@ -33,6 +35,10 @@ def bedrock(
         stop_sequences=_DEFAULT_STOP_SEQUENCES,
         # Params for Cohere models
         max_tokens=_DEFAULT_MAX_TOKEN_COUNT,
+        # Params for Anthropic models
+        max_tokens_to_sample=_DEFAULT_MAX_TOKEN_TO_SAMPLE,
+        top_k=_DEFAULT_TOP_K,
+        top_p=_DEFAULT_TOP_P,
     ),
     max_tries: int = _DEFAULT_RETRIES,
 ) -> Callable[[Iterable[str]], Iterable[str]]:
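
With these defaults registered, the new Claude models can be selected through the existing spacy.Bedrock.v1 entry point. A minimal usage sketch follows, assuming spacy and spacy-llm are installed and AWS credentials with Bedrock access are configured; the config keys such as region, the model_id value, and the NER task shown here are illustrative choices, not mandated by this commit.

import spacy

# Hypothetical pipeline config; region and model_id must match a Bedrock
# region/model for which access has been granted.
nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.NER.v2", "labels": ["PERSON", "ORG"]},
        "model": {
            "@llm_models": "spacy.Bedrock.v1",
            "model_id": "anthropic.claude-v2",
            "region": "us-east-1",
        },
    },
)

doc = nlp("Anthropic's Claude models are now available through Amazon Bedrock.")
print([(ent.text, ent.label_) for ent in doc.ents])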
