diff --git a/pynamodb/async_util.py b/pynamodb/async_util.py new file mode 100644 index 000000000..774ab64be --- /dev/null +++ b/pynamodb/async_util.py @@ -0,0 +1,9 @@ +import asyncio +import functools + + +def wrap_secretly_sync_async_fn(async_fn): + @functools.wraps(async_fn) + def wrap(*args, **kwargs): + return asyncio.run(async_fn(*args, **kwargs)) + return wrap diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index cca050f7d..4d849ac09 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -1,6 +1,7 @@ """ Lowest level connection """ +import asyncio import json import logging import random @@ -9,7 +10,7 @@ import uuid from base64 import b64decode from threading import local -from typing import Any, Dict, List, Mapping, Optional, Sequence +from typing import Any, Dict, List, Mapping, Optional, Sequence, cast import botocore.client import botocore.exceptions @@ -17,7 +18,15 @@ from botocore.client import ClientError from botocore.hooks import first_non_none_response from botocore.exceptions import BotoCoreError -from botocore.session import get_session + + +import aiobotocore.session +from aiobotocore.config import AioConfig +from aiobotocore.client import AioBaseClient +from aiobotocore.session import get_session as get_async_session + + +from pynamodb.async_util import wrap_secretly_sync_async_fn from pynamodb.constants import ( RETURN_CONSUMED_CAPACITY_VALUES, RETURN_ITEM_COLL_METRICS_VALUES, @@ -234,7 +243,19 @@ def get_exclusive_start_key_map(self, exclusive_start_key): } -class Connection(object): +class ConnectionMeta(type): + def __init__(self, name, bases, attrs): + super().__init__(name, bases, attrs) + + for attr_name, attr_value in attrs.items(): + suffix = "_async" + if attr_name.endswith(suffix) and asyncio.iscoroutinefunction(attr_value): + wrapped_fn = wrap_secretly_sync_async_fn(attr_value) + wrapped_fn.__name__ = wrapped_fn.__name__[:-len(suffix)] + setattr(self, wrapped_fn.__name__, wrapped_fn) + + +class Connection(metaclass=ConnectionMeta): """ A higher level abstraction over botocore """ @@ -288,30 +309,33 @@ def __init__(self, self._extra_headers = get_settings_value('extra_headers') def __repr__(self) -> str: - return "Connection<{}>".format(self.client.meta.endpoint_url) - - def _sign_request(self, request): - auth = self.client._request_signer.get_auth_instance( - self.client._request_signer.signing_name, - self.client._request_signer.region_name, - self.client._request_signer.signature_version) + return "Connection<{}>".format(self.host) + + async def _sign_request(self, client, request): + auth = client._request_signer.get_auth_instance( - client._request_signer.signing_name, + client._request_signer.signing_name, + client._request_signer.region_name, + client._request_signer.signature_version) + if asyncio.iscoroutine(auth): + auth = await auth auth.add_auth(request) - def _create_prepared_request( + async def _create_prepared_request( self, + client, params: Dict, settings: OperationSettings, ) -> AWSPreparedRequest: request = create_request_object(params) - self._sign_request(request) - prepared_request = self.client._endpoint.prepare_request(request) + await self._sign_request(client, request) + prepared_request = client._endpoint.prepare_request(request) if self._extra_headers is not None: prepared_request.headers.update(self._extra_headers) if settings.extra_headers is not None: prepared_request.headers.update(settings.extra_headers) return prepared_request - def dispatch(self, operation_name: str, operation_kwargs: Dict, settings: OperationSettings = 
OperationSettings.default) -> Dict: + async def dispatch(self, operation_name: str, operation_kwargs: Dict, settings: OperationSettings = OperationSettings.default) -> Dict: """ Dispatches `operation_name` with arguments `operation_kwargs` @@ -326,7 +350,11 @@ def dispatch(self, operation_name: str, operation_kwargs: Dict, settings: Operat req_uuid = uuid.uuid4() self.send_pre_boto_callback(operation_name, req_uuid, table_name) - data = self._make_api_call(operation_name, operation_kwargs, settings) + + data = await self._make_api_call(operation_name, operation_kwargs, settings) + if asyncio.iscoroutine(data): + data = await cast(Any, data) + self.send_post_boto_callback(operation_name, req_uuid, table_name) if data and CONSUMED_CAPACITY in data: @@ -348,14 +376,15 @@ def send_pre_boto_callback(self, operation_name, req_uuid, table_name): except Exception as e: log.exception("pre_boto callback threw an exception.") - def _make_api_call(self, operation_name: str, operation_kwargs: Dict, settings: OperationSettings = OperationSettings.default) -> Dict: + async def _make_api_call(self, operation_name: str, operation_kwargs: Dict, settings: OperationSettings = OperationSettings.default) -> Dict: """ This private method is here for two reasons: 1. It's faster to avoid using botocore's response parsing 2. It provides a place to monkey patch HTTP requests for unit testing """ - operation_model = self.client._service_model.operation_model(operation_name) - request_dict = self.client._convert_to_request_dict( + client = await self.client + operation_model = client._service_model.operation_model(operation_name) + request_dict = await client._convert_to_request_dict( operation_kwargs, operation_model, ) @@ -375,24 +404,20 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict, settings: prepared_request.reset_stream() # Create a new request for each retry (including a new signature). 
- prepared_request = self._create_prepared_request(request_dict, settings) + prepared_request = await self._create_prepared_request(client, request_dict, settings) # Implement the before-send event from botocore event_name = 'before-send.dynamodb.{}'.format(operation_model.name) - event_responses = self.client._endpoint._event_emitter.emit(event_name, request=prepared_request) + event_responses = await client._endpoint._event_emitter.emit(event_name, request=prepared_request) event_response = first_non_none_response(event_responses) if event_response is None: - http_response = self.client._endpoint.http_session.send(prepared_request) + http_response = await client._endpoint._send(prepared_request) else: http_response = event_response is_last_attempt_for_exceptions = True # don't retry if we have an event response - # json.loads accepts bytes in >= 3.6.0 - if sys.version_info < (3, 6, 0): - data = json.loads(http_response.text) - else: - data = json.loads(http_response.content) + data = json.loads(http_response.content) except (ValueError, botocore.exceptions.HTTPClientError, botocore.exceptions.ConnectionError) as e: if is_last_attempt_for_exceptions: log.debug('Reached the maximum number of retry attempts: %s', attempt_number) @@ -509,17 +534,17 @@ def _handle_binary_attributes(data): return data @property - def session(self) -> botocore.session.Session: + def session(self) -> aiobotocore.session.AioSession: """ - Returns a valid botocore session + Returns a valid async aiobotocore session """ # botocore client creation is not thread safe as of v1.2.5+ (see issue #153) - if getattr(self._local, 'session', None) is None: - self._local.session = get_session() - return self._local.session + if getattr(self._local, 'async_session', None) is None: + self._local.async_session = get_async_session() + return self._local.async_session @property - def client(self): + async def client(self) -> AioBaseClient: """ Returns a botocore dynamodb client """ @@ -528,15 +553,15 @@ def client(self): # if the client does not have credentials, we create a new client # otherwise the client is permanently poisoned in the case of metadata service flakiness when using IAM roles if not self._client or (self._client._request_signer and not self._client._request_signer._credentials): - config = botocore.client.Config( + config = AioConfig( parameter_validation=False, # Disable unnecessary validation for performance connect_timeout=self._connect_timeout_seconds, read_timeout=self._read_timeout_seconds, max_pool_connections=self._max_pool_connections) - self._client = self.session.create_client(SERVICE_NAME, self.region, endpoint_url=self.host, config=config) + self._client = await self.session._create_client(SERVICE_NAME, self.region, endpoint_url=self.host, config=config) return self._client - def get_meta_table(self, table_name: str, refresh: bool = False): + async def get_meta_table_async(self, table_name: str, refresh: bool = False): """ Returns a MetaTable """ @@ -545,7 +570,7 @@ def get_meta_table(self, table_name: str, refresh: bool = False): TABLE_NAME: table_name } try: - data = self.dispatch(DESCRIBE_TABLE, operation_kwargs) + data = await self.dispatch(DESCRIBE_TABLE, operation_kwargs) self._tables[table_name] = MetaTable(data.get(TABLE_KEY)) except BotoCoreError as e: raise TableError("Unable to describe table: {}".format(e), e) @@ -556,7 +581,7 @@ def get_meta_table(self, table_name: str, refresh: bool = False): raise return self._tables[table_name] - def create_table( + async def create_table_async( self, 
table_name: str, attribute_definitions: Optional[Any] = None, @@ -646,12 +671,12 @@ def create_table( ] try: - data = self.dispatch(CREATE_TABLE, operation_kwargs) + data = await self.dispatch(CREATE_TABLE, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: raise TableError("Failed to create table: {}".format(e), e) return data - def update_time_to_live(self, table_name: str, ttl_attribute_name: str) -> Dict: + async def update_time_to_live_async(self, table_name: str, ttl_attribute_name: str) -> Dict: """ Performs the UpdateTimeToLive operation """ @@ -663,11 +688,11 @@ def update_time_to_live(self, table_name: str, ttl_attribute_name: str) -> Dict: } } try: - return self.dispatch(UPDATE_TIME_TO_LIVE, operation_kwargs) + return await self.dispatch(UPDATE_TIME_TO_LIVE, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: raise TableError("Failed to update TTL on table: {}".format(e), e) - def delete_table(self, table_name: str) -> Dict: + async def delete_table_async(self, table_name: str) -> Dict: """ Performs the DeleteTable operation """ @@ -675,12 +700,12 @@ def delete_table(self, table_name: str) -> Dict: TABLE_NAME: table_name } try: - data = self.dispatch(DELETE_TABLE, operation_kwargs) + data = await self.dispatch(DELETE_TABLE, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: raise TableError("Failed to delete table: {}".format(e), e) return data - def update_table( + async def update_table_async( self, table_name: str, read_capacity_units: Optional[int] = None, @@ -714,11 +739,11 @@ def update_table( }) operation_kwargs[GLOBAL_SECONDARY_INDEX_UPDATES] = global_secondary_indexes_list try: - return self.dispatch(UPDATE_TABLE, operation_kwargs) + return await self.dispatch(UPDATE_TABLE, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: raise TableError("Failed to update table: {}".format(e), e) - def list_tables( + async def list_tables_async( self, exclusive_start_table_name: Optional[str] = None, limit: Optional[int] = None, @@ -736,23 +761,23 @@ def list_tables( LIMIT: limit }) try: - return self.dispatch(LIST_TABLES, operation_kwargs) + return await self.dispatch(LIST_TABLES, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: raise TableError("Unable to list tables: {}".format(e), e) - def describe_table(self, table_name: str) -> Dict: + async def describe_table_async(self, table_name: str) -> Dict: """ Performs the DescribeTable operation """ try: - tbl = self.get_meta_table(table_name, refresh=True) + tbl = await self.get_meta_table_async(table_name, refresh=True) if tbl: return tbl.data except ValueError: pass raise TableDoesNotExist(table_name) - def get_item_attribute_map( + async def get_item_attribute_map( self, table_name: str, attributes: Any, @@ -762,7 +787,7 @@ def get_item_attribute_map( """ Builds up a dynamodb compatible AttributeValue map """ - tbl = self.get_meta_table(table_name) + tbl = await self.get_meta_table_async(table_name) if tbl is None: raise TableError("No such table {}".format(table_name)) return tbl.get_item_attribute_map( @@ -792,7 +817,7 @@ def parse_attribute( return None, attribute return attribute - def get_attribute_type( + async def get_attribute_type( self, table_name: str, attribute_name: str, @@ -802,12 +827,12 @@ def get_attribute_type( Returns the proper attribute type for a given attribute name :param value: The attribute value an be supplied just in case the type is already included """ - tbl = self.get_meta_table(table_name) + tbl = await self.get_meta_table_async(table_name) if tbl is None: raise TableError("No such table 
{}".format(table_name)) return tbl.get_attribute_type(attribute_name, value=value) - def get_identifier_map( + async def get_identifier_map( self, table_name: str, hash_key: str, @@ -817,7 +842,7 @@ def get_identifier_map( """ Builds the identifier map that is common to several operations """ - tbl = self.get_meta_table(table_name) + tbl = await self.get_meta_table_async(table_name) if tbl is None: raise TableError("No such table {}".format(table_name)) return tbl.get_identifier_map(hash_key, range_key=range_key, key=key) @@ -868,16 +893,16 @@ def get_item_collection_map(self, return_item_collection_metrics: str) -> Dict: RETURN_ITEM_COLL_METRICS: str(return_item_collection_metrics).upper() } - def get_exclusive_start_key_map(self, table_name: str, exclusive_start_key: str) -> Dict: + async def get_exclusive_start_key_map(self, table_name: str, exclusive_start_key: str) -> Dict: """ Builds the exclusive start key attribute map """ - tbl = self.get_meta_table(table_name) + tbl = await self.get_meta_table_async(table_name) if tbl is None: raise TableError("No such table {}".format(table_name)) return tbl.get_exclusive_start_key_map(exclusive_start_key) - def get_operation_kwargs( + async def get_operation_kwargs( self, table_name: str, hash_key: str, @@ -900,9 +925,9 @@ def get_operation_kwargs( expression_attribute_values: Dict[str, Any] = {} operation_kwargs[TABLE_NAME] = table_name - operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key, key=key)) + operation_kwargs.update(await self.get_identifier_map(table_name, hash_key, range_key, key=key)) if attributes and operation_kwargs.get(ITEM) is not None: - attrs = self.get_item_attribute_map(table_name, attributes) + attrs = await self.get_item_attribute_map(table_name, attributes) operation_kwargs[ITEM].update(attrs[ITEM]) if attributes_to_get is not None: projection_expression = create_projection_expression(attributes_to_get, name_placeholders) @@ -932,7 +957,7 @@ def get_operation_kwargs( operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values return operation_kwargs - def delete_item( + async def delete_item( self, table_name: str, hash_key: str, @@ -946,7 +971,7 @@ def delete_item( """ Performs the DeleteItem operation and returns the result """ - operation_kwargs = self.get_operation_kwargs( + operation_kwargs = await self.get_operation_kwargs( table_name, hash_key, range_key=range_key, @@ -956,11 +981,11 @@ def delete_item( return_item_collection_metrics=return_item_collection_metrics ) try: - return self.dispatch(DELETE_ITEM, operation_kwargs, settings) + return await self.dispatch(DELETE_ITEM, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise DeleteError("Failed to delete item: {}".format(e), e) - def update_item( + async def update_item( self, table_name: str, hash_key: str, @@ -978,7 +1003,7 @@ def update_item( if not actions: raise ValueError("'actions' cannot be empty") - operation_kwargs = self.get_operation_kwargs( + operation_kwargs = await self.get_operation_kwargs( table_name=table_name, hash_key=hash_key, range_key=range_key, @@ -989,11 +1014,11 @@ def update_item( return_item_collection_metrics=return_item_collection_metrics, ) try: - return self.dispatch(UPDATE_ITEM, operation_kwargs, settings) + return await self.dispatch(UPDATE_ITEM, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise UpdateError("Failed to update item: {}".format(e), e) - def put_item( + async def put_item( self, table_name: str, hash_key: str, @@ -1008,7 +1033,7 @@ def 
put_item( """ Performs the PutItem operation and returns the result """ - operation_kwargs = self.get_operation_kwargs( + operation_kwargs = await self.get_operation_kwargs( table_name=table_name, hash_key=hash_key, range_key=range_key, @@ -1020,7 +1045,7 @@ def put_item( return_item_collection_metrics=return_item_collection_metrics ) try: - return self.dispatch(PUT_ITEM, operation_kwargs, settings) + return await self.dispatch(PUT_ITEM, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise PutError("Failed to put item: {}".format(e), e) @@ -1040,7 +1065,7 @@ def _get_transact_operation_kwargs( return operation_kwargs - def transact_write_items( + async def transact_write_items( self, condition_check_items: Sequence[Dict], delete_items: Sequence[Dict], @@ -1076,11 +1101,11 @@ def transact_write_items( operation_kwargs[TRANSACT_ITEMS] = transact_items try: - return self.dispatch(TRANSACT_WRITE_ITEMS, operation_kwargs, settings) + return await self.dispatch(TRANSACT_WRITE_ITEMS, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise TransactWriteError("Failed to write transaction items", e) - def transact_get_items( + async def transact_get_items( self, get_items: Sequence[Dict], return_consumed_capacity: Optional[str] = None, @@ -1095,11 +1120,11 @@ def transact_get_items( ] try: - return self.dispatch(TRANSACT_GET_ITEMS, operation_kwargs, settings) + return await self.dispatch(TRANSACT_GET_ITEMS, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise TransactGetError("Failed to get transaction items", e) - def batch_write_item( + async def batch_write_item( self, table_name: str, put_items: Optional[Any] = None, @@ -1126,21 +1151,21 @@ def batch_write_item( if put_items: for item in put_items: put_items_list.append({ - PUT_REQUEST: self.get_item_attribute_map(table_name, item, pythonic_key=False) + PUT_REQUEST: await self.get_item_attribute_map(table_name, item, pythonic_key=False) }) delete_items_list = [] if delete_items: for item in delete_items: delete_items_list.append({ - DELETE_REQUEST: self.get_item_attribute_map(table_name, item, item_key=KEY, pythonic_key=False) + DELETE_REQUEST: await self.get_item_attribute_map(table_name, item, item_key=KEY, pythonic_key=False) }) operation_kwargs[REQUEST_ITEMS][table_name] = delete_items_list + put_items_list try: - return self.dispatch(BATCH_WRITE_ITEM, operation_kwargs, settings) + return await self.dispatch(BATCH_WRITE_ITEM, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise PutError("Failed to batch write items: {}".format(e), e) - def batch_get_item( + async def batch_get_item( self, table_name: str, keys: Sequence[str], @@ -1174,15 +1199,15 @@ def batch_get_item( keys_map: Dict[str, List] = {KEYS: []} for key in keys: keys_map[KEYS].append( - self.get_item_attribute_map(table_name, key)[ITEM] + (await self.get_item_attribute_map(table_name, key))[ITEM] ) operation_kwargs[REQUEST_ITEMS][table_name].update(keys_map) try: - return self.dispatch(BATCH_GET_ITEM, operation_kwargs, settings) + return await self.dispatch(BATCH_GET_ITEM, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise GetError("Failed to batch get items: {}".format(e), e) - def get_item( + async def get_item( self, table_name: str, hash_key: str, @@ -1194,7 +1219,7 @@ def get_item( """ Performs the GetItem operation and returns the result """ - operation_kwargs = self.get_operation_kwargs( + operation_kwargs = await self.get_operation_kwargs( table_name=table_name, hash_key=hash_key, range_key=range_key, @@ 
-1202,11 +1227,11 @@ def get_item( attributes_to_get=attributes_to_get ) try: - return self.dispatch(GET_ITEM, operation_kwargs, settings) + return await self.dispatch(GET_ITEM, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise GetError("Failed to get item: {}".format(e), e) - def scan( + async def scan( self, table_name: str, filter_condition: Optional[Any] = None, @@ -1242,7 +1267,7 @@ def scan( if return_consumed_capacity: operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) if exclusive_start_key: - operation_kwargs.update(self.get_exclusive_start_key_map(table_name, exclusive_start_key)) + operation_kwargs.update(await self.get_exclusive_start_key_map(table_name, exclusive_start_key)) if segment is not None: operation_kwargs[SEGMENT] = segment if total_segments: @@ -1255,11 +1280,11 @@ def scan( operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values try: - return self.dispatch(SCAN, operation_kwargs, settings) + return await self.dispatch(SCAN, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise ScanError("Failed to scan table: {}".format(e), e) - def query( + async def query( self, table_name: str, hash_key: str, @@ -1285,7 +1310,7 @@ def query( name_placeholders: Dict[str, str] = {} expression_attribute_values: Dict[str, Any] = {} - tbl = self.get_meta_table(table_name) + tbl = await self.get_meta_table_async(table_name) if tbl is None: raise TableError("No such table: {}".format(table_name)) if index_name: @@ -1295,7 +1320,7 @@ def query( else: hash_keyname = tbl.hash_keyname - hash_condition_value = {self.get_attribute_type(table_name, hash_keyname, hash_key): self.parse_attribute(hash_key)} + hash_condition_value = {await self.get_attribute_type(table_name, hash_keyname, hash_key): self.parse_attribute(hash_key)} key_condition = Path([hash_keyname]) == hash_condition_value if range_key_condition is not None: key_condition &= range_key_condition @@ -1311,7 +1336,7 @@ def query( if consistent_read: operation_kwargs[CONSISTENT_READ] = True if exclusive_start_key: - operation_kwargs.update(self.get_exclusive_start_key_map(table_name, exclusive_start_key)) + operation_kwargs.update(await self.get_exclusive_start_key_map(table_name, exclusive_start_key)) if index_name: operation_kwargs[INDEX_NAME] = index_name if limit is not None: @@ -1330,7 +1355,7 @@ def query( operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values try: - return self.dispatch(QUERY, operation_kwargs, settings) + return await self.dispatch(QUERY, operation_kwargs, settings) except BOTOCORE_EXCEPTIONS as e: raise QueryError("Failed to query items: {}".format(e), e) diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 183467a9f..41036d162 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -3,6 +3,8 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ +import asyncio +from pynamodb.async_util import wrap_secretly_sync_async_fn from typing import Any, Dict, Mapping, Optional, Sequence from pynamodb.connection.base import Connection, MetaTable, OperationSettings @@ -10,8 +12,17 @@ from pynamodb.expressions.condition import Condition from pynamodb.expressions.update import Action +class TableMeta(type): + def __init__(self, name, bases, attrs): + super().__init__(name, bases, attrs) -class TableConnection: + for attr_name, attr_value in attrs.items(): + suffix = "_async" + if attr_name.endswith(suffix) and asyncio.iscoroutinefunction(attr_value): + setattr(self, 
attr_name[:-len(suffix)], wrap_secretly_sync_async_fn(attr_value)) + + +class TableConnection(metaclass=TableMeta): """ A higher level abstraction over botocore """ @@ -32,6 +43,7 @@ def __init__( aws_session_token: Optional[str] = None, ) -> None: self.table_name = table_name + self.connection = Connection(region=region, host=host, connect_timeout_seconds=connect_timeout_seconds, @@ -46,13 +58,13 @@ def __init__( aws_secret_access_key, aws_session_token) - def get_meta_table(self, refresh: bool = False) -> MetaTable: + async def get_meta_table_async(self, refresh: bool = False) -> MetaTable: """ Returns a MetaTable """ - return self.connection.get_meta_table(self.table_name, refresh=refresh) + return await self.connection.get_meta_table_async(self.table_name, refresh=refresh) - def get_operation_kwargs( + async def get_operation_kwargs_async( self, hash_key: str, range_key: Optional[str] = None, @@ -67,7 +79,7 @@ def get_operation_kwargs( return_item_collection_metrics: Optional[str] = None, return_values_on_condition_failure: Optional[str] = None, ) -> Dict: - return self.connection.get_operation_kwargs( + return await self.connection.get_operation_kwargs( self.table_name, hash_key, range_key=range_key, @@ -83,7 +95,7 @@ def get_operation_kwargs( return_values_on_condition_failure=return_values_on_condition_failure ) - def delete_item( + async def delete_item_async( self, hash_key: str, range_key: Optional[str] = None, @@ -96,7 +108,7 @@ def delete_item( """ Performs the DeleteItem operation and returns the result """ - return self.connection.delete_item( + return await self.connection.delete_item( self.table_name, hash_key, range_key=range_key, @@ -107,7 +119,7 @@ def delete_item( settings=settings, ) - def update_item( + async def update_item_async( self, hash_key: str, range_key: Optional[str] = None, @@ -121,7 +133,7 @@ def update_item( """ Performs the UpdateItem operation """ - return self.connection.update_item( + return await self.connection.update_item( self.table_name, hash_key, range_key=range_key, @@ -133,7 +145,7 @@ def update_item( settings=settings, ) - def put_item( + async def put_item_async( self, hash_key: str, range_key: Optional[str] = None, @@ -147,7 +159,7 @@ def put_item( """ Performs the PutItem operation and returns the result """ - return self.connection.put_item( + return await self.connection.put_item( self.table_name, hash_key, range_key=range_key, @@ -159,7 +171,7 @@ def put_item( settings=settings, ) - def batch_write_item( + async def batch_write_item_async( self, put_items: Optional[Any] = None, delete_items: Optional[Any] = None, @@ -170,7 +182,7 @@ def batch_write_item( """ Performs the batch_write_item operation """ - return self.connection.batch_write_item( + return await self.connection.batch_write_item( self.table_name, put_items=put_items, delete_items=delete_items, @@ -179,7 +191,7 @@ def batch_write_item( settings=settings, ) - def batch_get_item( + async def batch_get_item_async( self, keys: Sequence[str], consistent_read: Optional[bool] = None, @@ -190,7 +202,7 @@ def batch_get_item( """ Performs the batch get item operation """ - return self.connection.batch_get_item( + return await self.connection.batch_get_item( self.table_name, keys, consistent_read=consistent_read, @@ -199,7 +211,7 @@ def batch_get_item( settings=settings, ) - def get_item( + async def get_item_async( self, hash_key: str, range_key: Optional[str] = None, @@ -210,7 +222,7 @@ def get_item( """ Performs the GetItem operation and returns the result """ - return 
self.connection.get_item( + return await self.connection.get_item( self.table_name, hash_key, range_key=range_key, @@ -219,7 +231,7 @@ def get_item( settings=settings, ) - def scan( + async def scan_async( self, filter_condition: Optional[Any] = None, attributes_to_get: Optional[Any] = None, @@ -235,7 +247,7 @@ def scan( """ Performs the scan operation """ - return self.connection.scan( + return await self.connection.scan( self.table_name, filter_condition=filter_condition, attributes_to_get=attributes_to_get, @@ -249,7 +261,7 @@ def scan( settings=settings, ) - def query( + async def query_async( self, hash_key: str, range_key_condition: Optional[Condition] = None, @@ -267,7 +279,7 @@ def query( """ Performs the Query operation and returns the result """ - return self.connection.query( + return await self.connection.query( self.table_name, hash_key, range_key_condition=range_key_condition, @@ -283,25 +295,25 @@ def query( settings=settings, ) - def describe_table(self) -> Dict: + async def describe_table_async(self) -> Dict: """ Performs the DescribeTable operation and returns the result """ - return self.connection.describe_table(self.table_name) + return await self.connection.describe_table_async(self.table_name) - def delete_table(self) -> Dict: + async def delete_table_async(self) -> Dict: """ Performs the DeleteTable operation and returns the result """ - return self.connection.delete_table(self.table_name) + return await self.connection.delete_table_async(self.table_name) - def update_time_to_live(self, ttl_attr_name: str) -> Dict: + async def update_time_to_live_async(self, ttl_attr_name: str) -> Dict: """ Performs the UpdateTimeToLive operation and returns the result """ - return self.connection.update_time_to_live(self.table_name, ttl_attr_name) + return await self.connection.update_time_to_live_async(self.table_name, ttl_attr_name) - def update_table( + async def update_table_async( self, read_capacity_units: Optional[int] = None, write_capacity_units: Optional[int] = None, @@ -310,13 +322,13 @@ def update_table( """ Performs the UpdateTable operation and returns the result """ - return self.connection.update_table( + return await self.connection.update_table_async( self.table_name, read_capacity_units=read_capacity_units, write_capacity_units=write_capacity_units, global_secondary_index_updates=global_secondary_index_updates) - def create_table( + async def create_table_async( self, attribute_definitions: Optional[Any] = None, key_schema: Optional[Any] = None, @@ -331,7 +343,7 @@ def create_table( """ Performs the CreateTable operation and returns the result """ - return self.connection.create_table( + return await self.connection.create_table_async( self.table_name, attribute_definitions=attribute_definitions, key_schema=key_schema, diff --git a/pynamodb/indexes.py b/pynamodb/indexes.py index 6ee5508b4..16d35d80e 100644 --- a/pynamodb/indexes.py +++ b/pynamodb/indexes.py @@ -30,7 +30,7 @@ class IndexMeta(GenericMeta): that contains the index settings """ def __init__(self, name, bases, attrs, *args, **kwargs): - super().__init__(name, bases, attrs, *args, **kwargs) # type: ignore + super().__init__(name, bases, attrs, *args, **kwargs) if isinstance(attrs, dict): for attr_name, attr_obj in attrs.items(): if attr_name == META_CLASS_NAME: diff --git a/pynamodb/models.py b/pynamodb/models.py index ca768afca..a37454954 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -7,7 +7,7 @@ import warnings import sys from inspect import getmembers -from typing import Any +from 
typing import Any, AsyncIterator from typing import Dict from typing import Generic from typing import Iterable @@ -75,7 +75,7 @@ def __init__(self, model: Type[_T], auto_commit: bool = True, settings: Operatio self.failed_operations: List[Any] = [] self.settings = settings - def save(self, put_item: _T) -> None: + async def save(self, put_item: _T) -> None: """ This adds `put_item` to the list of pending operations to be performed. @@ -92,10 +92,10 @@ def save(self, put_item: _T) -> None: if not self.auto_commit: raise ValueError("DynamoDB allows a maximum of 25 batch operations") else: - self.commit() + await self.commit() self.pending_operations.append({"action": PUT, "item": put_item}) - def delete(self, del_item: _T) -> None: + async def delete(self, del_item: _T) -> None: """ This adds `del_item` to the list of pending operations to be performed. @@ -112,20 +112,20 @@ def delete(self, del_item: _T) -> None: if not self.auto_commit: raise ValueError("DynamoDB allows a maximum of 25 batch operations") else: - self.commit() + await self.commit() self.pending_operations.append({"action": DELETE, "item": del_item}) - def __enter__(self): + async def __aenter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb): """ This ensures that all pending operations are committed when the context is exited """ - return self.commit() + return await self.commit() - def commit(self) -> None: + async def commit(self) -> None: """ Writes all of the changes that are pending """ @@ -140,7 +140,7 @@ def commit(self) -> None: self.pending_operations = [] if not len(put_items) and not len(delete_items): return - data = self.model._get_connection().batch_write_item( + data = await self.model._get_connection().batch_write_item_async( put_items=put_items, delete_items=delete_items, settings=self.settings, @@ -165,7 +165,7 @@ def commit(self) -> None: delete_items.append(item.get(DELETE_REQUEST).get(KEY)) # type: ignore log.info("Resending %d unprocessed keys for batch operation after %d seconds sleep", len(unprocessed_items), sleep_time) - data = self.model._get_connection().batch_write_item( + data = await self.model._get_connection().batch_write_item_async( put_items=put_items, delete_items=delete_items, settings=self.settings, @@ -311,13 +311,13 @@ def __init__( super(Model, self).__init__(_user_instantiated=_user_instantiated, **attributes) @classmethod - def batch_get( + async def batch_get( cls: Type[_T], items: Iterable[Union[_KeyType, Iterable[_KeyType]]], consistent_read: Optional[bool] = None, attributes_to_get: Optional[Sequence[str]] = None, settings: OperationSettings = OperationSettings.default - ) -> Iterator[_T]: + ) -> AsyncIterator[_T]: """ BatchGetItem for this model @@ -331,7 +331,7 @@ def batch_get( while items: if len(keys_to_get) == BATCH_GET_PAGE_LIMIT: while keys_to_get: - page, unprocessed_keys = cls._batch_get_page( + page, unprocessed_keys = await cls._batch_get_page( keys_to_get, consistent_read=consistent_read, attributes_to_get=attributes_to_get, @@ -357,7 +357,7 @@ def batch_get( }) while keys_to_get: - page, unprocessed_keys = cls._batch_get_page( + page, unprocessed_keys = await cls._batch_get_page( keys_to_get, consistent_read=consistent_read, attributes_to_get=attributes_to_get, @@ -391,7 +391,7 @@ def __repr__(self) -> str: msg = "{}<{}>".format(self.Meta.table_name, hash_key) return msg - def delete(self, condition: Optional[Condition] = None, settings: OperationSettings = OperationSettings.default) -> 
Any: + async def delete(self, condition: Optional[Condition] = None, settings: OperationSettings = OperationSettings.default) -> Any: """ Deletes this object from dynamodb @@ -402,9 +402,9 @@ def delete(self, condition: Optional[Condition] = None, settings: OperationSetti if version_condition is not None: condition &= version_condition - return self._get_connection().delete_item(hk_value, range_key=rk_value, condition=condition, settings=settings) + return await self._get_connection().delete_item_async(hk_value, range_key=rk_value, condition=condition, settings=settings) - def update(self, actions: List[Action], condition: Optional[Condition] = None, settings: OperationSettings = OperationSettings.default) -> Any: + async def update(self, actions: List[Action], condition: Optional[Condition] = None, settings: OperationSettings = OperationSettings.default) -> Any: """ Updates an item using the UpdateItem operation. @@ -422,7 +422,7 @@ def update(self, actions: List[Action], condition: Optional[Condition] = None, s if version_condition is not None: condition &= version_condition - data = self._get_connection().update_item(hk_value, range_key=rk_value, return_values=ALL_NEW, condition=condition, actions=actions, settings=settings) + data = await self._get_connection().update_item_async(hk_value, range_key=rk_value, return_values=ALL_NEW, condition=condition, actions=actions, settings=settings) item_data = data[ATTRIBUTES] stored_cls = self._get_discriminator_class(item_data) if stored_cls and stored_cls != type(self): @@ -430,17 +430,17 @@ def update(self, actions: List[Action], condition: Optional[Condition] = None, s self.deserialize(item_data) return data - def save(self, condition: Optional[Condition] = None, settings: OperationSettings = OperationSettings.default) -> Dict[str, Any]: + async def save(self, condition: Optional[Condition] = None, settings: OperationSettings = OperationSettings.default) -> Dict[str, Any]: """ Save this object to dynamodb """ args, kwargs = self._get_save_args(condition=condition) kwargs['settings'] = settings - data = self._get_connection().put_item(*args, **kwargs) + data = await self._get_connection().put_item_async(*args, **kwargs) self.update_local_version_attribute() return data - def refresh(self, consistent_read: bool = False, settings: OperationSettings = OperationSettings.default) -> None: + async def refresh(self, consistent_read: bool = False, settings: OperationSettings = OperationSettings.default) -> None: """ Retrieves this object's data from dynamodb and syncs this local object @@ -449,7 +449,7 @@ def refresh(self, consistent_read: bool = False, settings: OperationSettings = O :raises ModelInstance.DoesNotExist: if the object to be updated does not exist """ hk_value, rk_value = self._get_hash_range_key_serialized_values() - attrs = self._get_connection().get_item(hk_value, range_key=rk_value, consistent_read=consistent_read, settings=settings) + attrs = await self._get_connection().get_item_async(hk_value, range_key=rk_value, consistent_read=consistent_read, settings=settings) item_data = attrs.get(ITEM, None) if item_data is None: raise self.DoesNotExist("This item does not exist in the table.") @@ -458,7 +458,7 @@ def refresh(self, consistent_read: bool = False, settings: OperationSettings = O raise ValueError("Cannot refresh this item from the returned class: {}".format(stored_cls.__name__)) self.deserialize(item_data) - def get_update_kwargs_from_instance( + async def get_update_kwargs_from_instance( self, actions: List[Action], 
condition: Optional[Condition] = None, @@ -470,9 +470,9 @@ def get_update_kwargs_from_instance( if version_condition is not None: condition &= version_condition - return self._get_connection().get_operation_kwargs(hk_value, range_key=rk_value, key=KEY, actions=actions, condition=condition, return_values_on_condition_failure=return_values_on_condition_failure) + return await self._get_connection().get_operation_kwargs_async(hk_value, range_key=rk_value, key=KEY, actions=actions, condition=condition, return_values_on_condition_failure=return_values_on_condition_failure) - def get_delete_kwargs_from_instance( + async def get_delete_kwargs_from_instance( self, condition: Optional[Condition] = None, return_values_on_condition_failure: Optional[str] = None, @@ -483,9 +483,9 @@ def get_delete_kwargs_from_instance( if version_condition is not None: condition &= version_condition - return self._get_connection().get_operation_kwargs(hk_value, range_key=rk_value, key=KEY, condition=condition, return_values_on_condition_failure=return_values_on_condition_failure) + return await self._get_connection().get_operation_kwargs_async(hk_value, range_key=rk_value, key=KEY, condition=condition, return_values_on_condition_failure=return_values_on_condition_failure) - def get_save_kwargs_from_instance( + async def get_save_kwargs_from_instance( self, condition: Optional[Condition] = None, return_values_on_condition_failure: Optional[str] = None, @@ -493,24 +493,24 @@ def get_save_kwargs_from_instance( args, save_kwargs = self._get_save_args(null_check=True, condition=condition) save_kwargs['key'] = ITEM save_kwargs['return_values_on_condition_failure'] = return_values_on_condition_failure - return self._get_connection().get_operation_kwargs(*args, **save_kwargs) + return await self._get_connection().get_operation_kwargs_async(*args, **save_kwargs) @classmethod - def get_operation_kwargs_from_class( + async def get_operation_kwargs_from_class( cls, hash_key: str, range_key: Optional[_KeyType] = None, condition: Optional[Condition] = None, ) -> Dict[str, Any]: hash_key, range_key = cls._serialize_keys(hash_key, range_key) - return cls._get_connection().get_operation_kwargs( + return await cls._get_connection().get_operation_kwargs_async( hash_key=hash_key, range_key=range_key, condition=condition ) @classmethod - def get( + async def get( cls: Type[_T], hash_key: _KeyType, range_key: Optional[_KeyType] = None, @@ -529,7 +529,7 @@ def get( """ hash_key, range_key = cls._serialize_keys(hash_key, range_key) - data = cls._get_connection().get_item( + data = await cls._get_connection().get_item_async( hash_key, range_key=range_key, consistent_read=consistent_read, @@ -556,7 +556,7 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T: return cls._instantiate(data) @classmethod - def count( + async def count( cls: Type[_T], hash_key: Optional[_KeyType] = None, range_key_condition: Optional[Condition] = None, @@ -580,7 +580,7 @@ def count( if hash_key is None: if filter_condition is not None: raise ValueError('A hash_key must be given to use filters') - return cls.describe_table().get(ITEM_COUNT) + return (await cls.describe_table()).get(ITEM_COUNT) cls._get_indexes() if cls._index_classes and index_name: @@ -604,7 +604,7 @@ def count( ) result_iterator: ResultIterator[_T] = ResultIterator( - cls._get_connection().query, + cls._get_connection().query_async, query_args, query_kwargs, limit=limit, @@ -613,7 +613,8 @@ def count( ) # iterate through results - list(result_iterator) + async for i in result_iterator: 
+ pass return result_iterator.total_count @@ -676,7 +677,7 @@ def query( ) return ResultIterator( - cls._get_connection().query, + cls._get_connection().query_async, query_args, query_kwargs, map_fn=cls.from_raw_data, @@ -735,7 +736,7 @@ def scan( ) return ResultIterator( - cls._get_connection().scan, + cls._get_connection().scan_async, scan_args, scan_kwargs, map_fn=cls.from_raw_data, @@ -745,32 +746,32 @@ def scan( ) @classmethod - def exists(cls: Type[_T]) -> bool: + async def exists(cls: Type[_T]) -> bool: """ Returns True if this table exists, False otherwise """ try: - cls._get_connection().describe_table() + await cls._get_connection().describe_table_async() return True except TableDoesNotExist: return False @classmethod - def delete_table(cls) -> Any: + async def delete_table(cls) -> Any: """ Delete the table for this model """ - return cls._get_connection().delete_table() + return await cls._get_connection().delete_table_async() @classmethod - def describe_table(cls) -> Any: + async def describe_table(cls) -> Any: """ Returns the result of a DescribeTable operation on this model's table """ - return cls._get_connection().describe_table() + return await cls._get_connection().describe_table_async() @classmethod - def create_table( + async def create_table( cls, wait: bool = False, read_capacity_units: Optional[int] = None, @@ -817,12 +818,12 @@ def create_table( if attr_name not in attr_keys: schema['attribute_definitions'].append(attr) attr_keys.append(attr_name) - cls._get_connection().create_table( + await cls._get_connection().create_table_async( **schema ) if wait: while True: - status = cls._get_connection().describe_table() + status = await cls._get_connection().describe_table_async() if status: data = status.get(TABLE_STATUS) if data == ACTIVE: @@ -832,10 +833,10 @@ def create_table( else: raise TableError("No TableStatus returned for table") - cls.update_ttl(ignore_update_ttl_errors) + await cls.update_ttl(ignore_update_ttl_errors) @classmethod - def update_ttl(cls, ignore_update_ttl_errors: bool) -> None: + async def update_ttl(cls, ignore_update_ttl_errors: bool) -> None: """ Attempt to update the TTL on the table. Certain implementations (eg: dynalite) do not support updating TTLs and will fail. @@ -845,7 +846,7 @@ def update_ttl(cls, ignore_update_ttl_errors: bool) -> None: # Some dynamoDB implementations (eg: dynalite) do not support updating TTLs so # this will fail. It's fine for this to fail in those cases. 
try: - cls._get_connection().update_time_to_live(ttl_attribute.attr_name) + await cls._get_connection().update_time_to_live_async(ttl_attribute.attr_name) except Exception: if ignore_update_ttl_errors: log.info("Unable to update the TTL for {}".format(cls.Meta.table_name)) @@ -1038,7 +1039,7 @@ def _get_serialized_keys(self) -> Tuple[_KeyType, _KeyType]: return self._serialize_keys(hash_key, range_key) @classmethod - def _batch_get_page(cls, keys_to_get, consistent_read, attributes_to_get, settings: OperationSettings): + async def _batch_get_page(cls, keys_to_get, consistent_read, attributes_to_get, settings: OperationSettings): """ Returns a single page from BatchGetItem Also returns any unprocessed items @@ -1048,7 +1049,7 @@ def _batch_get_page(cls, keys_to_get, consistent_read, attributes_to_get, settin :param attributes_to_get: A list of attributes to return """ log.debug("Fetching a BatchGetItem page") - data = cls._get_connection().batch_get_item( + data = await cls._get_connection().batch_get_item_async( keys_to_get, consistent_read=consistent_read, attributes_to_get=attributes_to_get, settings=settings, ) item_data = data.get(RESPONSES).get(cls.Meta.table_name) # type: ignore diff --git a/pynamodb/pagination.py b/pynamodb/pagination.py index f8682421b..e59d3947d 100644 --- a/pynamodb/pagination.py +++ b/pynamodb/pagination.py @@ -1,5 +1,5 @@ import time -from typing import Any, Callable, Dict, Iterable, Iterator, TypeVar, Optional +from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterable, Iterator, TypeVar, Optional from pynamodb.constants import (CAMEL_COUNT, ITEMS, LAST_EVALUATED_KEY, SCANNED_COUNT, CONSUMED_CAPACITY, TOTAL, CAPACITY_UNITS) @@ -72,7 +72,7 @@ def rate_limit(self, rate_limit: float): self._rate_limit = rate_limit -class PageIterator(Iterator[_T]): +class PageIterator(AsyncIterator[_T]): """ PageIterator handles Query and Scan result pagination. @@ -81,7 +81,7 @@ class PageIterator(Iterator[_T]): """ def __init__( self, - operation: Callable, + operation: Callable[..., Coroutine[Any, Any, Any]], args: Any, kwargs: Dict[str, Any], rate_limit: Optional[float] = None, @@ -98,12 +98,12 @@ def __init__( self._rate_limiter = RateLimiter(rate_limit) self._settings = settings - def __iter__(self) -> Iterator[_T]: + def __aiter__(self) -> AsyncIterator[_T]: return self - def __next__(self) -> _T: + async def __anext__(self) -> _T: if self._last_evaluated_key is None and not self._first_iteration: - raise StopIteration() + raise StopAsyncIteration() self._first_iteration = False @@ -112,7 +112,7 @@ def __next__(self) -> _T: if self._rate_limiter: self._rate_limiter.acquire() self._kwargs['return_consumed_capacity'] = TOTAL - page = self._operation(*self._args, settings=self._settings, **self._kwargs) + page = await self._operation(*self._args, settings=self._settings, **self._kwargs) self._last_evaluated_key = page.get(LAST_EVALUATED_KEY) self._total_scanned_count += page[SCANNED_COUNT] @@ -122,9 +122,6 @@ def __next__(self) -> _T: return page - def next(self) -> _T: - return self.__next__() - @property def key_names(self) -> Iterable[str]: # If the current page has a last_evaluated_key, use it to determine key attributes @@ -152,7 +149,7 @@ def total_scanned_count(self) -> int: return self._total_scanned_count -class ResultIterator(Iterator[_T]): +class ResultIterator(AsyncIterator[_T]): """ ResultIterator handles Query and Scan item pagination. 
@@ -175,26 +172,26 @@ def __init__( self._limit = limit self._total_count = 0 - def _get_next_page(self) -> None: - page = next(self.page_iter) + async def _get_next_page(self) -> None: + page = await self.page_iter.__anext__() self._count = page[CAMEL_COUNT] self._items = page.get(ITEMS) # not returned if 'Select' is set to 'COUNT' self._index = 0 if self._items else self._count self._total_count += self._count - def __iter__(self) -> Iterator[_T]: + def __aiter__(self) -> AsyncIterator[_T]: return self - def __next__(self) -> _T: + async def __anext__(self) -> _T: if self._limit == 0: - raise StopIteration + raise StopAsyncIteration if self._first_iteration: self._first_iteration = False - self._get_next_page() + await self._get_next_page() while self._index == self._count: - self._get_next_page() + await self._get_next_page() item = self._items[self._index] self._index += 1 @@ -204,9 +201,6 @@ def __next__(self) -> _T: item = self._map_fn(item) return item - def next(self) -> _T: - return self.__next__() - @property def last_evaluated_key(self) -> Optional[Dict[str, Dict[str, Any]]]: if self._first_iteration or self._index == self._count: diff --git a/pynamodb/transactions.py b/pynamodb/transactions.py index 28b3409e8..cbc07b7f7 100644 --- a/pynamodb/transactions.py +++ b/pynamodb/transactions.py @@ -39,7 +39,7 @@ def __init__(self, *args, **kwargs): self._futures: List[_ModelFuture] = [] super(TransactGet, self).__init__(*args, **kwargs) - def get(self, model_cls: Type[_M], hash_key: _KeyType, range_key: Optional[_KeyType] = None) -> _ModelFuture[_M]: + async def get(self, model_cls: Type[_M], hash_key: _KeyType, range_key: Optional[_KeyType] = None) -> _ModelFuture[_M]: """ Adds the operation arguments for an item to list of models to get returns a _ModelFuture object as a placeholder @@ -49,7 +49,7 @@ def get(self, model_cls: Type[_M], hash_key: _KeyType, range_key: Optional[_KeyT :param range_key: :return: """ - operation_kwargs = model_cls.get_operation_kwargs_from_class(hash_key, range_key=range_key) + operation_kwargs = await model_cls.get_operation_kwargs_from_class(hash_key, range_key=range_key) model_future = _ModelFuture(model_cls) self._futures.append(model_future) self._get_items.append(operation_kwargs) @@ -60,8 +60,8 @@ def _update_futures(futures: List[_ModelFuture], results: List) -> None: for model, data in zip(futures, results): model.update_with_raw_data(data.get(ITEM)) - def _commit(self) -> Any: - response = self._connection.transact_get_items( + async def _commit(self) -> Any: + response = await self._connection.transact_get_items( get_items=self._get_items, return_consumed_capacity=self._return_consumed_capacity ) @@ -89,30 +89,30 @@ def __init__( self._update_items: List[Dict] = [] self._models_for_version_attribute_update: List[Any] = [] - def condition_check(self, model_cls: Type[_M], hash_key: _KeyType, range_key: Optional[_KeyType] = None, condition: Optional[Condition] = None): + async def condition_check(self, model_cls: Type[_M], hash_key: _KeyType, range_key: Optional[_KeyType] = None, condition: Optional[Condition] = None): if condition is None: raise TypeError('`condition` cannot be None') - operation_kwargs = model_cls.get_operation_kwargs_from_class( + operation_kwargs = await model_cls.get_operation_kwargs_from_class( hash_key, range_key=range_key, condition=condition ) self._condition_check_items.append(operation_kwargs) - def delete(self, model: _M, condition: Optional[Condition] = None) -> None: - operation_kwargs = 
model.get_delete_kwargs_from_instance(condition=condition) + async def delete(self, model: _M, condition: Optional[Condition] = None) -> None: + operation_kwargs = await model.get_delete_kwargs_from_instance(condition=condition) self._delete_items.append(operation_kwargs) - def save(self, model: _M, condition: Optional[Condition] = None, return_values: Optional[str] = None) -> None: - operation_kwargs = model.get_save_kwargs_from_instance( + async def save(self, model: _M, condition: Optional[Condition] = None, return_values: Optional[str] = None) -> None: + operation_kwargs = await model.get_save_kwargs_from_instance( condition=condition, return_values_on_condition_failure=return_values ) self._put_items.append(operation_kwargs) self._models_for_version_attribute_update.append(model) - def update(self, model: _M, actions: List[Action], condition: Optional[Condition] = None, return_values: Optional[str] = None) -> None: - operation_kwargs = model.get_update_kwargs_from_instance( + async def update(self, model: _M, actions: List[Action], condition: Optional[Condition] = None, return_values: Optional[str] = None) -> None: + operation_kwargs = await model.get_update_kwargs_from_instance( actions=actions, condition=condition, return_values_on_condition_failure=return_values @@ -120,8 +120,8 @@ def update(self, model: _M, actions: List[Action], condition: Optional[Condition self._update_items.append(operation_kwargs) self._models_for_version_attribute_update.append(model) - def _commit(self) -> Any: - response = self._connection.transact_write_items( + async def _commit(self) -> Any: + response = await self._connection.transact_write_items( condition_check_items=self._condition_check_items, delete_items=self._delete_items, put_items=self._put_items, diff --git a/setup.py b/setup.py index 868680752..d037c75a5 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ install_requires = [ - 'botocore>=1.12.54', + 'aiobotocore>=1.3.0', 'typing-extensions>=3.7; python_version<"3.8"' ]
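Note: the wrapping pattern introduced in async_util.py and in the ConnectionMeta/TableMeta metaclasses above can be hard to follow from the diff alone. Below is a minimal, self-contained sketch of how that pattern is expected to behave, assuming the `return` added to wrap_secretly_sync_async_fn; the SyncShimMeta, Demo and ping_async names are illustrative only and are not part of this changeset.

import asyncio
import functools


def wrap_secretly_sync_async_fn(async_fn):
    # Run the coroutine to completion on a fresh event loop and hand the
    # result back synchronously (mirrors pynamodb/async_util.py above).
    @functools.wraps(async_fn)
    def wrap(*args, **kwargs):
        return asyncio.run(async_fn(*args, **kwargs))
    return wrap


class SyncShimMeta(type):
    # Same idea as ConnectionMeta/TableMeta: every coroutine method whose
    # name ends in "_async" also gets a blocking twin without the suffix.
    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        suffix = "_async"
        for attr_name, attr_value in attrs.items():
            if attr_name.endswith(suffix) and asyncio.iscoroutinefunction(attr_value):
                setattr(cls, attr_name[:-len(suffix)], wrap_secretly_sync_async_fn(attr_value))


class Demo(metaclass=SyncShimMeta):
    async def ping_async(self) -> str:
        await asyncio.sleep(0)
        return "pong"


print(Demo().ping())                     # blocking twin, prints "pong"
print(asyncio.run(Demo().ping_async()))  # or await the coroutine directly
# Caveat: the blocking twin calls asyncio.run(), so it cannot be invoked from
# inside an already-running event loop.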
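For callers, the change means the Model API is awaited end to end. The following is a hedged usage sketch, assuming a hypothetical Thread model, an existing table and a reachable DynamoDB endpoint with credentials; the model, its attributes and the table name are illustrative and not part of this changeset.

import asyncio

from pynamodb.attributes import UnicodeAttribute
from pynamodb.models import Model


class Thread(Model):
    class Meta:
        table_name = "Thread"  # illustrative table name
    forum = UnicodeAttribute(hash_key=True)
    subject = UnicodeAttribute(range_key=True)


async def main() -> None:
    item = Thread(forum="pynamodb", subject="going async")
    await item.save()                                    # Model.save is now a coroutine
    fetched = await Thread.get("pynamodb", "going async")

    # Model.query still returns a ResultIterator synchronously, but the
    # iterator itself is now an AsyncIterator.
    async for thread in Thread.query("pynamodb"):
        print(thread.subject)

    # BatchWrite exposes __aenter__/__aexit__, and its save/delete are coroutines.
    async with Thread.batch_write() as batch:
        await batch.delete(fetched)


asyncio.run(main())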