|
1 | | -import os |
2 | 1 | import json |
| 2 | +import os |
3 | 3 | import warnings |
4 | 4 | from enum import Enum |
5 | | -import requests |
6 | | -from requests import HTTPError |
7 | | -from typing import Any, Dict, Iterable, Optional, Type, List, Sized, Tuple |
8 | | - |
9 | | -from confection import SimpleFrozenDict |
| 5 | +from typing import Any, Dict, Iterable, List, Optional |
10 | 6 |
|
11 | | -from ...registry import registry |
12 | | - |
13 | | -try: |
14 | | - import boto3 |
15 | | - import botocore |
16 | | - from botocore.config import Config |
17 | | -except ImportError as err: |
18 | | - print("To use Bedrock, you need to install boto3. Use `pip install boto3` ") |
19 | | - raise err |
20 | 7 |
|
21 | 8 | class Models(str, Enum): |
22 | 9 | # Completion models |
23 | 10 | TITAN_EXPRESS = "amazon.titan-text-express-v1" |
24 | 11 | TITAN_LITE = "amazon.titan-text-lite-v1" |
25 | 12 |
|
26 | | -class Bedrock(): |
| 13 | + |
| 14 | +class Bedrock: |
27 | 15 | def __init__( |
28 | | - self, |
29 | | - model_id: str, |
30 | | - region: str, |
31 | | - config: Dict[Any, Any], |
32 | | - max_retries: int = 5 |
| 16 | + self, model_id: str, region: str, config: Dict[Any, Any], max_retries: int = 5 |
33 | 17 | ): |
34 | | - |
35 | 18 | self._region = region |
36 | 19 | self._model_id = model_id |
37 | 20 | self._config = config |
38 | 21 | self._max_retries = max_retries |
39 | | - |
40 | | - # @property |
41 | | - def get_session(self): |
| 22 | + |
| 23 | + def get_session_kwargs(self) -> Dict[str, Optional[str]]: |
42 | 24 |
|
43 | 25 | # Fetch and check the credentials |
44 | | - profile = os.getenv("AWS_PROFILE") if not None else "" |
| 26 | + profile = os.getenv("AWS_PROFILE") if not None else "" |
45 | 27 | secret_key_id = os.getenv("AWS_ACCESS_KEY_ID") |
46 | 28 | secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") |
47 | 29 | session_token = os.getenv("AWS_SESSION_TOKEN") |
48 | 30 |
|
49 | 31 | if profile is None: |
50 | 32 | warnings.warn( |
51 | 33 | "Could not find the AWS_PROFILE to access the Amazon Bedrock . Ensure you have an AWS_PROFILE " |
52 | | - "set up by making it available as an environment variable 'AWS_PROFILE'." |
53 | | - ) |
| 34 | + "set up by making it available as an environment variable AWS_PROFILE." |
| 35 | + ) |
54 | 36 |
|
55 | 37 | if secret_key_id is None: |
56 | 38 | warnings.warn( |
57 | 39 | "Could not find the AWS_ACCESS_KEY_ID to access the Amazon Bedrock . Ensure you have an AWS_ACCESS_KEY_ID " |
58 | | - "set up by making it available as an environment variable 'AWS_ACCESS_KEY_ID'." |
| 40 | + "set up by making it available as an environment variable AWS_ACCESS_KEY_ID." |
59 | 41 | ) |
| 42 | + |
60 | 43 | if secret_access_key is None: |
61 | 44 | warnings.warn( |
62 | 45 | "Could not find the AWS_SECRET_ACCESS_KEY to access the Amazon Bedrock . Ensure you have an AWS_SECRET_ACCESS_KEY " |
63 | | - "set up by making it available as an environment variable 'AWS_SECRET_ACCESS_KEY'." |
| 46 | + "set up by making it available as an environment variable AWS_SECRET_ACCESS_KEY." |
64 | 47 | ) |
| 48 | + |
65 | 49 | if session_token is None: |
66 | 50 | warnings.warn( |
67 | 51 | "Could not find the AWS_SESSION_TOKEN to access the Amazon Bedrock . Ensure you have an AWS_SESSION_TOKEN " |
68 | | - "set up by making it available as an environment variable 'AWS_SESSION_TOKEN'." |
| 52 | + "set up by making it available as an environment variable AWS_SESSION_TOKEN." |
69 | 53 | ) |
70 | 54 |
|
71 | 55 | assert secret_key_id is not None |
72 | 56 | assert secret_access_key is not None |
73 | 57 | assert session_token is not None |
74 | | - |
75 | | - session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token} |
76 | | - bedrock = boto3.Session(**session_kwargs) |
77 | | - return bedrock |
78 | 58 |
|
79 | | - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: |
| 59 | + session_kwargs = { |
| 60 | + "profile_name": profile, |
| 61 | + "region_name": self._region, |
| 62 | + "aws_access_key_id": secret_key_id, |
| 63 | + "aws_secret_access_key": secret_access_key, |
| 64 | + "aws_session_token": session_token, |
| 65 | + } |
| 66 | + return session_kwargs |
| 67 | + |
| 68 | + def __call__(self, prompts: Iterable[str]) -> Iterable[str]: |
80 | 69 | api_responses: List[str] = [] |
81 | 70 | prompts = list(prompts) |
82 | | - api_config = Config(retries = dict(max_attempts = self._max_retries)) |
83 | 71 |
|
84 | | - def _request(json_data: Dict[str, Any]) -> str: |
85 | | - session = self.get_session() |
86 | | - print("Session:", session) |
| 72 | + def _request(json_data: str) -> str: |
| 73 | + try: |
| 74 | + import boto3 |
| 75 | + except ImportError as err: |
| 76 | + warnings.warn( |
| 77 | + "To use Bedrock, you need to install boto3. Use pip install boto3 " |
| 78 | + ) |
| 79 | + raise err |
| 80 | + from botocore.config import Config |
| 81 | + |
| 82 | + session_kwargs = self.get_session_kwargs() |
| 83 | + session = boto3.Session(**session_kwargs) |
| 84 | + api_config = Config(retries=dict(max_attempts=self._max_retries)) |
87 | 85 | bedrock = session.client(service_name="bedrock-runtime", config=api_config) |
88 | | - accept = 'application/json' |
89 | | - contentType = 'application/json' |
90 | | - r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType) |
91 | | - responses = json.loads(r['body'].read().decode())['results'][0]['outputText'] |
| 86 | + accept = "application/json" |
| 87 | + contentType = "application/json" |
| 88 | + r = bedrock.invoke_model( |
| 89 | + body=json_data, |
| 90 | + modelId=self._model_id, |
| 91 | + accept=accept, |
| 92 | + contentType=contentType, |
| 93 | + ) |
| 94 | + responses = json.loads(r["body"].read().decode())["results"][0][ |
| 95 | + "outputText" |
| 96 | + ] |
92 | 97 | return responses |
93 | 98 |
|
94 | 99 | for prompt in prompts: |
95 | 100 | if self._model_id in [Models.TITAN_LITE, Models.TITAN_EXPRESS]: |
96 | | - responses = _request(json.dumps({"inputText": prompt, "textGenerationConfig":self._config})) |
97 | | - if "error" in responses: |
98 | | - return responses["error"] |
| 101 | + responses = _request( |
| 102 | + json.dumps( |
| 103 | + {"inputText": prompt, "textGenerationConfig": self._config} |
| 104 | + ) |
| 105 | + ) |
99 | 106 |
|
100 | 107 | api_responses.append(responses) |
101 | 108 |
|
|
0 commit comments