diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 0b70cc1f..79c77b40 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -341,8 +341,11 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict: except ClientError as e: resp_metadata = e.response.get('ResponseMetadata', {}).get('HTTPHeaders', {}) cancellation_reasons = e.response.get('CancellationReasons', []) - + item = e.response.get('Item') botocore_props = {'Error': e.response.get('Error', {})} + if item: + botocore_props['Item'] = item + verbose_props = { 'request_id': resp_metadata.get('x-amzn-requestid', ''), 'table_name': self._get_table_name_for_error_context(operation_kwargs), @@ -889,6 +892,7 @@ def put_item( return_values: Optional[str] = None, return_consumed_capacity: Optional[str] = None, return_item_collection_metrics: Optional[str] = None, + return_values_on_condition_failure: Optional[str] = None, ) -> Dict: """ Performs the PutItem operation and returns the result @@ -902,12 +906,14 @@ def put_item( condition=condition, return_values=return_values, return_consumed_capacity=return_consumed_capacity, - return_item_collection_metrics=return_item_collection_metrics + return_item_collection_metrics=return_item_collection_metrics, + return_values_on_condition_failure=return_values_on_condition_failure ) try: return self.dispatch(PUT_ITEM, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: - raise PutError("Failed to put item: {}".format(e), e) + response = getattr(e, 'response', {}) + raise PutError("Failed to put item: {}".format(e), e, response.get('Item')) def _get_transact_operation_kwargs( self, diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 5e70ba5c..8ebdf710 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -139,6 +139,7 @@ def put_item( return_values: Optional[str] = None, return_consumed_capacity: Optional[str] = None, return_item_collection_metrics: Optional[str] = None, + return_values_on_condition_failure: Optional[str] = None, ) -> Dict: """ Performs the PutItem operation and returns the result @@ -152,6 +153,7 @@ def put_item( return_values=return_values, return_consumed_capacity=return_consumed_capacity, return_item_collection_metrics=return_item_collection_metrics, + return_values_on_condition_failure=return_values_on_condition_failure ) def batch_write_item( diff --git a/pynamodb/exceptions.py b/pynamodb/exceptions.py index 822230e3..96571d4f 100644 --- a/pynamodb/exceptions.py +++ b/pynamodb/exceptions.py @@ -82,6 +82,11 @@ class PutError(PynamoDBConnectionError): Raised when an item fails to be created """ msg = "Error putting item" + raw_values_on_condition_failure: Optional[Dict[str, Any]] = None + + def __init__(self, msg: Optional[str] = None, cause: Optional[Exception] = None, raw_item: Optional[Dict[str, Any]] = None) -> None: + super().__init__(msg, cause) + self.raw_values_on_condition_failure = raw_item class UpdateError(PynamoDBConnectionError): diff --git a/pynamodb/models.py b/pynamodb/models.py index 569a9551..fc8a1103 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -441,11 +441,22 @@ def update(self, actions: List[Action], condition: Optional[Condition] = None, * self.deserialize(item_data) return data - def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Dict[str, Any]: + def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True, return_values_on_condition_failure: Optional[str] = None) -> Dict[str, Any]: """ Save this object to dynamodb + + :param condition: an optional Condition on which to save + :param add_version_condition: For models which have a :class:`~pynamodb.attributes.VersionAttribute`, + specifies whether only to save if the version matches the model that is currently loaded. + Set to `False` for a 'last write wins' strategy. + Regardless, the version will always be incremented to prevent "rollbacks" by concurrent :meth:`update` calls. + :param return_values_on_condition_failure: If set, then this value will be returned in error if the condition is not met. """ - args, kwargs = self._get_save_args(condition=condition, add_version_condition=add_version_condition) + args, kwargs = self._get_save_args( + condition=condition, + add_version_condition=add_version_condition, + return_values_on_condition_failure=return_values_on_condition_failure + ) data = self._get_connection().put_item(*args, **kwargs) self.update_local_version_attribute() return data @@ -888,7 +899,7 @@ def _get_schema(cls) -> ModelSchema: return schema - def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Tuple[Iterable[Any], Dict[str, Any]]: + def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True, return_values_on_condition_failure: Optional[str] = None) -> Tuple[Iterable[Any], Dict[str, Any]]: """ Gets the proper *args, **kwargs for saving and retrieving this object @@ -898,6 +909,7 @@ def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_c :param add_version_condition: For models which have a :class:`~pynamodb.attributes.VersionAttribute`, specifies whether the item should only be saved if its current version matches the expected one. Set to `False` for a 'last-write-wins' strategy. + :param return_values_on_condition_failure: If set, then this will return the values on condition failure """ attribute_values = self.serialize(null_check=True) hash_key_attribute = self._hash_key_attribute() @@ -915,6 +927,8 @@ def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_c condition &= version_condition kwargs['attributes'] = attribute_values kwargs['condition'] = condition + if return_values_on_condition_failure and return_values_on_condition_failure is not None: + kwargs['return_values_on_condition_failure'] = return_values_on_condition_failure return args, kwargs def _get_hash_range_key_serialized_values(self) -> Tuple[Any, Optional[Any]]: diff --git a/tests/integration/model_integration_test.py b/tests/integration/model_integration_test.py index 9a54cccc..12c51843 100644 --- a/tests/integration/model_integration_test.py +++ b/tests/integration/model_integration_test.py @@ -4,11 +4,12 @@ from datetime import datetime +from pynamodb.exceptions import PutError from pynamodb.models import Model from pynamodb.indexes import GlobalSecondaryIndex, AllProjection, LocalSecondaryIndex from pynamodb.attributes import ( UnicodeAttribute, BinaryAttribute, UTCDateTimeAttribute, NumberSetAttribute, NumberAttribute, - VersionAttribute) + VersionAttribute, JSONAttribute) import pytest @@ -110,6 +111,41 @@ class Meta: TestModel.delete_table() +@pytest.mark.ddblocal +def test_model_integration_save_return_values_on_condition_failure(ddb_url): + + class TestModel(Model): + """ + A model for testing + """ + class Meta: + region = 'us-east-1' + table_name = 'pynamodb-ci' + host = ddb_url + user_id = UnicodeAttribute(hash_key=True) + created_at = UnicodeAttribute(range_key=True) + data = JSONAttribute(null=True) + version = VersionAttribute() + + if TestModel.exists(): + TestModel.delete_table() + TestModel.create_table(read_capacity_units=1, write_capacity_units=1, wait=True) + + origin_obj = TestModel('1', '2') + origin_obj.save() + parallel_obj = TestModel.get('1', '2') + parallel_obj.data = {'foo': 'bar'} + parallel_obj.save() + # original object 1 version behind + origin_obj.data = {'foo': 'second_bar'} + with pytest.raises(PutError) as excinfo: + origin_obj.save(return_values_on_condition_failure='ALL_OLD') + + old_parallel_obj = TestModel.from_raw_data(excinfo.value.raw_values_on_condition_failure) + assert old_parallel_obj.data == {'foo': 'bar'} + assert old_parallel_obj.version == 2 + + def test_can_inherit_version_attribute(ddb_url) -> None: class TestModelA(Model): diff --git a/tests/test_model.py b/tests/test_model.py index 9bd12db6..8d5fae12 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3265,6 +3265,36 @@ def test_model_version_attribute_save(add_version_condition: bool) -> None: assert args == params +@pytest.mark.parametrize('return_values_on_condition_failure', [None, 'ALL_OLD', 'all_old']) +def test_model_return_values_on_condition_failure(return_values_on_condition_failure: str | None) -> None: + item = VersionedModel('test_user_name', email='test_user@email.com') + with patch(PATCH_METHOD) as req: + req.return_value = {} + item.save(return_values_on_condition_failure=return_values_on_condition_failure) + args = req.call_args[0][1] + params = { + 'Item': { + 'name': { + 'S': 'test_user_name' + }, + 'email': { + 'S': 'test_user@email.com' + }, + 'version': { + 'N': '1' + }, + }, + 'ReturnConsumedCapacity': 'TOTAL', + 'TableName': 'VersionedModel', + 'ConditionExpression': 'attribute_not_exists (#0)', + 'ExpressionAttributeNames': {'#0': 'version'} + } + if return_values_on_condition_failure is not None: + params.update({ + 'ReturnValuesOnConditionCheckFailure': return_values_on_condition_failure.upper(), + }) + + assert args == params @pytest.mark.parametrize('add_version_condition', [True, False]) def test_version_attribute_increments_on_update(add_version_condition: bool) -> None: