From 8f3cf4ae37f6a109e5d5306ea47ced987b631092 Mon Sep 17 00:00:00 2001 From: Will Date: Mon, 16 May 2022 16:06:41 -0500 Subject: [PATCH 1/5] Updated requirements. --- requirements-dev.txt | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1b2b06323..377763dbc 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ +amazon-dax-client>=2.0.0,<3.0.0 pytest>=6 pytest-env pytest-mock diff --git a/setup.py b/setup.py index 825879c13..c8b55df76 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ ], extras_require={ 'signals': ['blinker>=1.3,<2.0'], + 'dax': ['amazon-dax-client>=2.0.0,<3.0.0'] }, package_data={'pynamodb': ['py.typed']}, ) From 7ef00e3ec8df76dffa205282756e70beb68445ec Mon Sep 17 00:00:00 2001 From: Will Date: Mon, 16 May 2022 16:07:13 -0500 Subject: [PATCH 2/5] Updated settings doc. --- docs/settings.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/settings.rst b/docs/settings.rst index 5d8482104..7df6907cc 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -78,6 +78,26 @@ Default: automatically constructed by boto to account for region The URL endpoint for DynamoDB. This can be used to use a local implementation of DynamoDB such as DynamoDB Local or dynalite. +dax_write_endpoints +------------------ + +Default: ``[]`` + +Connect to DAX endpoints for write operations. + +Supported Operations: PutItem, DeleteItem, UpdateItem, BatchWriteItem + + +dax_read_endpoints +------------------ + +Default: ``[]`` + +Connect to DAX endpoints for read operations. + +Supported Operations: GetItem, Scan, BatchGetItem, Query + + Overriding settings ~~~~~~~~~~~~~~~~~~~ From a1eaa75b97b335717061ca12f00f11ad54ba6180 Mon Sep 17 00:00:00 2001 From: Will Date: Mon, 16 May 2022 16:08:15 -0500 Subject: [PATCH 3/5] Added dax support. --- pynamodb/connection/base.py | 33 ++++++++++++++++++++++++++++++++- pynamodb/connection/dax.py | 33 +++++++++++++++++++++++++++++++++ pynamodb/connection/table.py | 10 ++++++++-- pynamodb/models.py | 11 ++++++++++- pynamodb/settings.py | 3 +++ 5 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 pynamodb/connection/dax.py diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index cca050f7d..51de8dcd7 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -48,6 +48,7 @@ VerboseClientError, TransactGetError, TransactWriteError) from pynamodb.expressions.condition import Condition +from pynamodb.connection.dax import DaxClient, OP_READ, OP_WRITE from pynamodb.expressions.operand import Path from pynamodb.expressions.projection import create_projection_expression from pynamodb.expressions.update import Action, Update @@ -247,11 +248,16 @@ def __init__(self, max_retry_attempts: Optional[int] = None, base_backoff_ms: Optional[int] = None, max_pool_connections: Optional[int] = None, - extra_headers: Optional[Mapping[str, str]] = None): + extra_headers: Optional[Mapping[str, str]] = None, + dax_write_endpoints: Optional[List[str]]=None, + dax_read_endpoints: Optional[List[str]]=None, + fallback_to_dynamodb: Optional[bool]=False): self._tables: Dict[str, MetaTable] = {} self.host = host self._local = local() self._client = None + self._dax_support = bool(dax_write_endpoints or dax_read_endpoints) + if region: self.region = region else: @@ -287,6 +293,19 @@ def __init__(self, else: self._extra_headers = get_settings_value('extra_headers') + if dax_write_endpoints is None: + dax_write_endpoints = get_settings_value('dax_write_endpoints') + self._dax_write_client = None if not dax_write_endpoints else DaxClient(endpoints=dax_write_endpoints, region_name=self.region) + + if dax_read_endpoints is None: + dax_read_endpoints = get_settings_value('dax_read_endpoints') + self._dax_read_client = None if not dax_read_endpoints else DaxClient(endpoints=dax_read_endpoints, region_name=self.region) + + if fallback_to_dynamodb is not None: + self._fallback_to_dynamodb = fallback_to_dynamodb + else: + self._fallback_to_dynamodb = get_settings_value('fallback_to_dynamodb') + def __repr__(self) -> str: return "Connection<{}>".format(self.client.meta.endpoint_url) @@ -354,6 +373,18 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict, settings: 1. It's faster to avoid using botocore's response parsing 2. It provides a place to monkey patch HTTP requests for unit testing """ + if self._dax_support: + from amazondax.DaxError import DaxClientError + + try: + if operation_name in OP_WRITE and self._dax_write_client: + return self._dax_write_client.dispatch(operation_name, operation_kwargs) + elif operation_name in OP_READ and self._dax_read_client: + return self._dax_read_client.dispatch(operation_name, operation_kwargs) + except DaxClientError: + if not self._fallback_to_dynamodb: + raise + operation_model = self.client._service_model.operation_model(operation_name) request_dict = self.client._convert_to_request_dict( operation_kwargs, diff --git a/pynamodb/connection/dax.py b/pynamodb/connection/dax.py new file mode 100644 index 000000000..414c4db2a --- /dev/null +++ b/pynamodb/connection/dax.py @@ -0,0 +1,33 @@ +from typing import Dict, List + +OP_WRITE = { + 'PutItem': 'put_item', + 'DeleteItem': 'delete_item', + 'UpdateItem': 'update_item', + 'BatchWriteItem': 'batch_write_item', +} + +OP_READ = { + 'GetItem': 'get_item', + 'Scan': 'scan', + 'BatchGetItem': 'batch_get_item', + 'Query': 'query', +} + +OP_NAME_TO_METHOD = OP_WRITE.copy() +OP_NAME_TO_METHOD.update(OP_READ) + + +class DaxClient(object): + + def __init__(self, endpoints: List[str], region_name: str): + from amazondax import AmazonDaxClient + + self.connection = AmazonDaxClient( + endpoints=endpoints, + region_name=region_name + ) + + def dispatch(self, operation_name: str, operation_kwargs: Dict): + method = getattr(self.connection, OP_NAME_TO_METHOD[operation_name]) + return method(**operation_kwargs) diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 183467a9f..c7df7442a 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -3,7 +3,7 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ -from typing import Any, Dict, Mapping, Optional, Sequence +from typing import Any, Dict, List, Mapping, Optional, Sequence from pynamodb.connection.base import Connection, MetaTable, OperationSettings from pynamodb.constants import DEFAULT_BILLING_MODE, KEY @@ -30,6 +30,9 @@ def __init__( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, + dax_write_endpoints: Optional[List[str]]=None, + dax_read_endpoints: Optional[List[str]]=None, + fallback_to_dynamodb: Optional[bool]=False ) -> None: self.table_name = table_name self.connection = Connection(region=region, @@ -39,7 +42,10 @@ def __init__( max_retry_attempts=max_retry_attempts, base_backoff_ms=base_backoff_ms, max_pool_connections=max_pool_connections, - extra_headers=extra_headers) + extra_headers=extra_headers, + dax_write_endpoints=dax_write_endpoints, + dax_read_endpoints=dax_read_endpoints, + fallback_to_dynamodb=fallback_to_dynamodb) if aws_access_key_id and aws_secret_access_key: self.connection.session.set_credentials(aws_access_key_id, diff --git a/pynamodb/models.py b/pynamodb/models.py index 7d5e99161..4355911ba 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -258,6 +258,12 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: setattr(attr_obj, 'aws_secret_access_key', None) if not hasattr(attr_obj, 'aws_session_token'): setattr(attr_obj, 'aws_session_token', None) + if not hasattr(attr_obj, 'dax_write_endpoints'): + setattr(attr_obj, 'dax_write_endpoints', get_settings_value('dax_write_endpoints')) + if not hasattr(attr_obj, 'dax_read_endpoints'): + setattr(attr_obj, 'dax_read_endpoints', get_settings_value('dax_read_endpoints')) + if not hasattr(attr_obj, 'fallback_to_dynamodb'): + setattr(attr_obj, 'fallback_to_dynamodb', get_settings_value('fallback_to_dynamodb')) # create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist, # so that "except Model.DoesNotExist:" would not catch other models' exceptions @@ -1072,7 +1078,10 @@ def _get_connection(cls) -> TableConnection: extra_headers=cls.Meta.extra_headers, aws_access_key_id=cls.Meta.aws_access_key_id, aws_secret_access_key=cls.Meta.aws_secret_access_key, - aws_session_token=cls.Meta.aws_session_token) + aws_session_token=cls.Meta.aws_session_token, + dax_write_endpoints=cls.Meta.dax_write_endpoints, + dax_read_endpoints=cls.Meta.dax_read_endpoints, + fallback_to_dynamodb=cls.Meta.fallback_to_dynamodb) return cls._connection @classmethod diff --git a/pynamodb/settings.py b/pynamodb/settings.py index 7283dce03..78af4b347 100644 --- a/pynamodb/settings.py +++ b/pynamodb/settings.py @@ -16,6 +16,9 @@ 'region': None, 'max_pool_connections': 10, 'extra_headers': None, + 'dax_write_endpoints': [], + 'dax_read_endpoints': [], + 'fallback_to_dynamodb': False } OVERRIDE_SETTINGS_PATH = getenv('PYNAMODB_CONFIG', '/etc/pynamodb/global_default_settings.py') From badb297684aaee6f1d10fae4cb7e4a53804bc02c Mon Sep 17 00:00:00 2001 From: Will Date: Mon, 16 May 2022 18:44:22 -0500 Subject: [PATCH 4/5] Catch scenario where dax endpoints are defined in settings. --- pynamodb/connection/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 51de8dcd7..2df5e3335 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -256,7 +256,6 @@ def __init__(self, self.host = host self._local = local() self._client = None - self._dax_support = bool(dax_write_endpoints or dax_read_endpoints) if region: self.region = region @@ -295,11 +294,13 @@ def __init__(self, if dax_write_endpoints is None: dax_write_endpoints = get_settings_value('dax_write_endpoints') - self._dax_write_client = None if not dax_write_endpoints else DaxClient(endpoints=dax_write_endpoints, region_name=self.region) if dax_read_endpoints is None: dax_read_endpoints = get_settings_value('dax_read_endpoints') + + self._dax_support = bool(dax_write_endpoints or dax_read_endpoints) self._dax_read_client = None if not dax_read_endpoints else DaxClient(endpoints=dax_read_endpoints, region_name=self.region) + self._dax_write_client = None if not dax_write_endpoints else DaxClient(endpoints=dax_write_endpoints, region_name=self.region) if fallback_to_dynamodb is not None: self._fallback_to_dynamodb = fallback_to_dynamodb From 8d501bb9f532111514cd4c8ef72e5663985d842c Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 17 May 2022 08:37:05 -0500 Subject: [PATCH 5/5] Fixed spacing on params. --- pynamodb/connection/base.py | 6 +++--- pynamodb/connection/table.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 2df5e3335..0412b30a7 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -249,9 +249,9 @@ def __init__(self, base_backoff_ms: Optional[int] = None, max_pool_connections: Optional[int] = None, extra_headers: Optional[Mapping[str, str]] = None, - dax_write_endpoints: Optional[List[str]]=None, - dax_read_endpoints: Optional[List[str]]=None, - fallback_to_dynamodb: Optional[bool]=False): + dax_write_endpoints: Optional[List[str]] = None, + dax_read_endpoints: Optional[List[str]] = None, + fallback_to_dynamodb: Optional[bool] = False): self._tables: Dict[str, MetaTable] = {} self.host = host self._local = local() diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index c7df7442a..1a96d4f0b 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -30,9 +30,9 @@ def __init__( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, - dax_write_endpoints: Optional[List[str]]=None, - dax_read_endpoints: Optional[List[str]]=None, - fallback_to_dynamodb: Optional[bool]=False + dax_write_endpoints: Optional[List[str]] = None, + dax_read_endpoints: Optional[List[str]] = None, + fallback_to_dynamodb: Optional[bool] = False ) -> None: self.table_name = table_name self.connection = Connection(region=region,