Skip to content

Commit 00f029e

Browse files
Format files
1 parent 8363b22 commit 00f029e

File tree

3 files changed

+82
-64
lines changed

3 files changed

+82
-64
lines changed

spacy_llm/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .bedrock import titan_express, titan_lite
12
from .hf import dolly_hf, openllama_hf, stablelm_hf
23
from .langchain import query_langchain
34
from .rest import anthropic, cohere, noop, openai, palm
@@ -12,4 +13,6 @@
1213
"openllama_hf",
1314
"palm",
1415
"query_langchain",
16+
"titan_lite",
17+
"titan_express",
1518
]

spacy_llm/models/bedrock/model.py

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,108 @@
1-
import os
21
import json
2+
import os
33
import warnings
44
from enum import Enum
5-
import requests
6-
from requests import HTTPError
7-
from typing import Any, Dict, Iterable, Optional, Type, List, Sized, Tuple
8-
9-
from confection import SimpleFrozenDict
5+
from typing import Any, Dict, Iterable, List, Optional
106

11-
from ...registry import registry
12-
13-
try:
14-
import boto3
15-
import botocore
16-
from botocore.config import Config
17-
except ImportError as err:
18-
print("To use Bedrock, you need to install boto3. Use `pip install boto3` ")
19-
raise err
207

218
class Models(str, Enum):
229
# Completion models
2310
TITAN_EXPRESS = "amazon.titan-text-express-v1"
2411
TITAN_LITE = "amazon.titan-text-lite-v1"
2512

26-
class Bedrock():
13+
14+
class Bedrock:
2715
def __init__(
28-
self,
29-
model_id: str,
30-
region: str,
31-
config: Dict[Any, Any],
32-
max_retries: int = 5
16+
self, model_id: str, region: str, config: Dict[Any, Any], max_retries: int = 5
3317
):
34-
3518
self._region = region
3619
self._model_id = model_id
3720
self._config = config
3821
self._max_retries = max_retries
39-
40-
# @property
41-
def get_session(self):
22+
23+
def get_session_kwargs(self) -> Dict[str, Optional[str]]:
4224

4325
# Fetch and check the credentials
44-
profile = os.getenv("AWS_PROFILE") if not None else ""
26+
profile = os.getenv("AWS_PROFILE") if not None else ""
4527
secret_key_id = os.getenv("AWS_ACCESS_KEY_ID")
4628
secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
4729
session_token = os.getenv("AWS_SESSION_TOKEN")
4830

4931
if profile is None:
5032
warnings.warn(
5133
"Could not find the AWS_PROFILE to access the Amazon Bedrock . Ensure you have an AWS_PROFILE "
52-
"set up by making it available as an environment variable 'AWS_PROFILE'."
53-
)
34+
"set up by making it available as an environment variable AWS_PROFILE."
35+
)
5436

5537
if secret_key_id is None:
5638
warnings.warn(
5739
"Could not find the AWS_ACCESS_KEY_ID to access the Amazon Bedrock . Ensure you have an AWS_ACCESS_KEY_ID "
58-
"set up by making it available as an environment variable 'AWS_ACCESS_KEY_ID'."
40+
"set up by making it available as an environment variable AWS_ACCESS_KEY_ID."
5941
)
42+
6043
if secret_access_key is None:
6144
warnings.warn(
6245
"Could not find the AWS_SECRET_ACCESS_KEY to access the Amazon Bedrock . Ensure you have an AWS_SECRET_ACCESS_KEY "
63-
"set up by making it available as an environment variable 'AWS_SECRET_ACCESS_KEY'."
46+
"set up by making it available as an environment variable AWS_SECRET_ACCESS_KEY."
6447
)
48+
6549
if session_token is None:
6650
warnings.warn(
6751
"Could not find the AWS_SESSION_TOKEN to access the Amazon Bedrock . Ensure you have an AWS_SESSION_TOKEN "
68-
"set up by making it available as an environment variable 'AWS_SESSION_TOKEN'."
52+
"set up by making it available as an environment variable AWS_SESSION_TOKEN."
6953
)
7054

7155
assert secret_key_id is not None
7256
assert secret_access_key is not None
7357
assert session_token is not None
74-
75-
session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token}
76-
bedrock = boto3.Session(**session_kwargs)
77-
return bedrock
7858

79-
def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
59+
session_kwargs = {
60+
"profile_name": profile,
61+
"region_name": self._region,
62+
"aws_access_key_id": secret_key_id,
63+
"aws_secret_access_key": secret_access_key,
64+
"aws_session_token": session_token,
65+
}
66+
return session_kwargs
67+
68+
def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
8069
api_responses: List[str] = []
8170
prompts = list(prompts)
82-
api_config = Config(retries = dict(max_attempts = self._max_retries))
8371

84-
def _request(json_data: Dict[str, Any]) -> str:
85-
session = self.get_session()
86-
print("Session:", session)
72+
def _request(json_data: str) -> str:
73+
try:
74+
import boto3
75+
except ImportError as err:
76+
warnings.warn(
77+
"To use Bedrock, you need to install boto3. Use pip install boto3 "
78+
)
79+
raise err
80+
from botocore.config import Config
81+
82+
session_kwargs = self.get_session_kwargs()
83+
session = boto3.Session(**session_kwargs)
84+
api_config = Config(retries=dict(max_attempts=self._max_retries))
8785
bedrock = session.client(service_name="bedrock-runtime", config=api_config)
88-
accept = 'application/json'
89-
contentType = 'application/json'
90-
r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType)
91-
responses = json.loads(r['body'].read().decode())['results'][0]['outputText']
86+
accept = "application/json"
87+
contentType = "application/json"
88+
r = bedrock.invoke_model(
89+
body=json_data,
90+
modelId=self._model_id,
91+
accept=accept,
92+
contentType=contentType,
93+
)
94+
responses = json.loads(r["body"].read().decode())["results"][0][
95+
"outputText"
96+
]
9297
return responses
9398

9499
for prompt in prompts:
95100
if self._model_id in [Models.TITAN_LITE, Models.TITAN_EXPRESS]:
96-
responses = _request(json.dumps({"inputText": prompt, "textGenerationConfig":self._config}))
97-
if "error" in responses:
98-
return responses["error"]
101+
responses = _request(
102+
json.dumps(
103+
{"inputText": prompt, "textGenerationConfig": self._config}
104+
)
105+
)
99106

100107
api_responses.append(responses)
101108

spacy_llm/models/bedrock/registry.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,50 @@
1-
from typing import Any, Callable, Dict, Iterable
1+
from typing import Any, Callable, Dict, Iterable, List
22

33
from confection import SimpleFrozenDict
44

55
from ...registry import registry
66
from .model import Bedrock, Models
77

8-
_DEFAULT_RETRIES = 5
9-
_DEFAULT_TEMPERATURE = 0.0
10-
_DEFAULT_MAX_TOKEN_COUNT = 512
11-
_DEFAULT_TOP_P = 1
12-
_DEFAULT_STOP_SEQUENCES = []
8+
_DEFAULT_RETRIES: int = 5
9+
_DEFAULT_TEMPERATURE: float = 0.0
10+
_DEFAULT_MAX_TOKEN_COUNT: int = 512
11+
_DEFAULT_TOP_P: int = 1
12+
_DEFAULT_STOP_SEQUENCES: List[str] = []
13+
1314

1415
@registry.llm_models("spacy.Bedrock.Titan.Express.v1")
1516
def titan_express(
1617
region: str,
1718
model_id: Models = Models.TITAN_EXPRESS,
18-
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P),
19-
max_retries: int = _DEFAULT_RETRIES
19+
config: Dict[Any, Any] = SimpleFrozenDict(
20+
temperature=_DEFAULT_TEMPERATURE,
21+
maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
22+
stopSequences=_DEFAULT_STOP_SEQUENCES,
23+
topP=_DEFAULT_TOP_P,
24+
),
25+
max_retries: int = _DEFAULT_RETRIES,
2026
) -> Callable[[Iterable[str]], Iterable[str]]:
2127
"""Returns Bedrock instance for 'amazon-titan-express' model using boto3 to prompt API.
2228
model_id (ModelId): ID of the deployed model (titan-express)
2329
region (str): Specify the AWS region for the service
2430
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
2531
"""
2632
return Bedrock(
27-
model_id = model_id,
28-
region = region,
29-
config=config,
30-
max_retries=max_retries
33+
model_id=model_id, region=region, config=config, max_retries=max_retries
3134
)
3235

36+
3337
@registry.llm_models("spacy.Bedrock.Titan.Lite.v1")
3438
def titan_lite(
3539
region: str,
3640
model_id: Models = Models.TITAN_LITE,
37-
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P),
38-
max_retries: int = _DEFAULT_RETRIES
41+
config: Dict[Any, Any] = SimpleFrozenDict(
42+
temperature=_DEFAULT_TEMPERATURE,
43+
maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
44+
stopSequences=_DEFAULT_STOP_SEQUENCES,
45+
topP=_DEFAULT_TOP_P,
46+
),
47+
max_retries: int = _DEFAULT_RETRIES,
3948
) -> Callable[[Iterable[str]], Iterable[str]]:
4049
"""Returns Bedrock instance for 'amazon-titan-lite' model using boto3 to prompt API.
4150
region (str): Specify the AWS region for the service
@@ -44,9 +53,8 @@ def titan_lite(
4453
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
4554
"""
4655
return Bedrock(
47-
model_id = model_id,
48-
region = region,
56+
model_id=model_id,
57+
region=region,
4958
config=config,
5059
max_retries=max_retries,
5160
)
52-

0 commit comments

Comments
 (0)