Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ReturnValuesOnConditionCheckFailure for non-transactional operation #1263

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pynamodb/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,11 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict:
except ClientError as e:
resp_metadata = e.response.get('ResponseMetadata', {}).get('HTTPHeaders', {})
cancellation_reasons = e.response.get('CancellationReasons', [])

item = e.response.get('Item')
botocore_props = {'Error': e.response.get('Error', {})}
if item:
botocore_props['Item'] = item

verbose_props = {
'request_id': resp_metadata.get('x-amzn-requestid', ''),
'table_name': self._get_table_name_for_error_context(operation_kwargs),
Expand Down Expand Up @@ -889,6 +892,7 @@ def put_item(
return_values: Optional[str] = None,
return_consumed_capacity: Optional[str] = None,
return_item_collection_metrics: Optional[str] = None,
return_values_on_condition_failure: Optional[str] = None,
) -> Dict:
"""
Performs the PutItem operation and returns the result
Expand All @@ -902,12 +906,14 @@ def put_item(
condition=condition,
return_values=return_values,
return_consumed_capacity=return_consumed_capacity,
return_item_collection_metrics=return_item_collection_metrics
return_item_collection_metrics=return_item_collection_metrics,
return_values_on_condition_failure=return_values_on_condition_failure
)
try:
return self.dispatch(PUT_ITEM, operation_kwargs)
except BOTOCORE_EXCEPTIONS as e:
raise PutError("Failed to put item: {}".format(e), e)
response = getattr(e, 'response', {})
raise PutError("Failed to put item: {}".format(e), e, response.get('Item'))

def _get_transact_operation_kwargs(
self,
Expand Down
2 changes: 2 additions & 0 deletions pynamodb/connection/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def put_item(
return_values: Optional[str] = None,
return_consumed_capacity: Optional[str] = None,
return_item_collection_metrics: Optional[str] = None,
return_values_on_condition_failure: Optional[str] = None,
) -> Dict:
"""
Performs the PutItem operation and returns the result
Expand All @@ -152,6 +153,7 @@ def put_item(
return_values=return_values,
return_consumed_capacity=return_consumed_capacity,
return_item_collection_metrics=return_item_collection_metrics,
return_values_on_condition_failure=return_values_on_condition_failure
)

def batch_write_item(
Expand Down
5 changes: 5 additions & 0 deletions pynamodb/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class PutError(PynamoDBConnectionError):
Raised when an item fails to be created
"""
msg = "Error putting item"
raw_values_on_condition_failure: Optional[Dict[str, Any]] = None

def __init__(self, msg: Optional[str] = None, cause: Optional[Exception] = None, raw_item: Optional[Dict[str, Any]] = None) -> None:
super().__init__(msg, cause)
self.raw_values_on_condition_failure = raw_item


class UpdateError(PynamoDBConnectionError):
Expand Down
20 changes: 17 additions & 3 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,22 @@ def update(self, actions: List[Action], condition: Optional[Condition] = None, *
self.deserialize(item_data)
return data

def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Dict[str, Any]:
def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True, return_values_on_condition_failure: Optional[str] = None) -> Dict[str, Any]:
"""
Save this object to dynamodb

:param condition: an optional Condition on which to save
:param add_version_condition: For models which have a :class:`~pynamodb.attributes.VersionAttribute`,
specifies whether only to save if the version matches the model that is currently loaded.
Set to `False` for a 'last write wins' strategy.
Regardless, the version will always be incremented to prevent "rollbacks" by concurrent :meth:`update` calls.
:param return_values_on_condition_failure: If set, then this value will be returned in error if the condition is not met.
"""
args, kwargs = self._get_save_args(condition=condition, add_version_condition=add_version_condition)
args, kwargs = self._get_save_args(
condition=condition,
add_version_condition=add_version_condition,
return_values_on_condition_failure=return_values_on_condition_failure
)
data = self._get_connection().put_item(*args, **kwargs)
self.update_local_version_attribute()
return data
Expand Down Expand Up @@ -888,7 +899,7 @@ def _get_schema(cls) -> ModelSchema:

return schema

def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Tuple[Iterable[Any], Dict[str, Any]]:
def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True, return_values_on_condition_failure: Optional[str] = None) -> Tuple[Iterable[Any], Dict[str, Any]]:
"""
Gets the proper *args, **kwargs for saving and retrieving this object

Expand All @@ -898,6 +909,7 @@ def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_c
:param add_version_condition: For models which have a :class:`~pynamodb.attributes.VersionAttribute`,
specifies whether the item should only be saved if its current version matches the expected one.
Set to `False` for a 'last-write-wins' strategy.
:param return_values_on_condition_failure: If set, then this will return the values on condition failure
"""
attribute_values = self.serialize(null_check=True)
hash_key_attribute = self._hash_key_attribute()
Expand All @@ -915,6 +927,8 @@ def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_c
condition &= version_condition
kwargs['attributes'] = attribute_values
kwargs['condition'] = condition
if return_values_on_condition_failure and return_values_on_condition_failure is not None:
kwargs['return_values_on_condition_failure'] = return_values_on_condition_failure
return args, kwargs

def _get_hash_range_key_serialized_values(self) -> Tuple[Any, Optional[Any]]:
Expand Down
38 changes: 37 additions & 1 deletion tests/integration/model_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

from datetime import datetime

from pynamodb.exceptions import PutError
from pynamodb.models import Model
from pynamodb.indexes import GlobalSecondaryIndex, AllProjection, LocalSecondaryIndex
from pynamodb.attributes import (
UnicodeAttribute, BinaryAttribute, UTCDateTimeAttribute, NumberSetAttribute, NumberAttribute,
VersionAttribute)
VersionAttribute, JSONAttribute)

import pytest

Expand Down Expand Up @@ -110,6 +111,41 @@ class Meta:
TestModel.delete_table()


@pytest.mark.ddblocal
def test_model_integration_save_return_values_on_condition_failure(ddb_url):

class TestModel(Model):
"""
A model for testing
"""
class Meta:
region = 'us-east-1'
table_name = 'pynamodb-ci'
host = ddb_url
user_id = UnicodeAttribute(hash_key=True)
created_at = UnicodeAttribute(range_key=True)
data = JSONAttribute(null=True)
version = VersionAttribute()

if TestModel.exists():
TestModel.delete_table()
TestModel.create_table(read_capacity_units=1, write_capacity_units=1, wait=True)

origin_obj = TestModel('1', '2')
origin_obj.save()
parallel_obj = TestModel.get('1', '2')
parallel_obj.data = {'foo': 'bar'}
parallel_obj.save()
# original object 1 version behind
origin_obj.data = {'foo': 'second_bar'}
with pytest.raises(PutError) as excinfo:
origin_obj.save(return_values_on_condition_failure='ALL_OLD')

old_parallel_obj = TestModel.from_raw_data(excinfo.value.raw_values_on_condition_failure)
assert old_parallel_obj.data == {'foo': 'bar'}
assert old_parallel_obj.version == 2


def test_can_inherit_version_attribute(ddb_url) -> None:

class TestModelA(Model):
Expand Down
30 changes: 30 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3265,6 +3265,36 @@ def test_model_version_attribute_save(add_version_condition: bool) -> None:

assert args == params

@pytest.mark.parametrize('return_values_on_condition_failure', [None, 'ALL_OLD', 'all_old'])
def test_model_return_values_on_condition_failure(return_values_on_condition_failure: str | None) -> None:
item = VersionedModel('test_user_name', email='[email protected]')
with patch(PATCH_METHOD) as req:
req.return_value = {}
item.save(return_values_on_condition_failure=return_values_on_condition_failure)
args = req.call_args[0][1]
params = {
'Item': {
'name': {
'S': 'test_user_name'
},
'email': {
'S': '[email protected]'
},
'version': {
'N': '1'
},
},
'ReturnConsumedCapacity': 'TOTAL',
'TableName': 'VersionedModel',
'ConditionExpression': 'attribute_not_exists (#0)',
'ExpressionAttributeNames': {'#0': 'version'}
}
if return_values_on_condition_failure is not None:
params.update({
'ReturnValuesOnConditionCheckFailure': return_values_on_condition_failure.upper(),
})

assert args == params

@pytest.mark.parametrize('add_version_condition', [True, False])
def test_version_attribute_increments_on_update(add_version_condition: bool) -> None:
Expand Down