Skip to content

Commit 7d75dbf

Browse files
author
Eric Muller
committed
Support for versioned optimistic locking a la DynamoDBMapper (pynamodb#228)
1 parent a877dd1 commit 7d75dbf

File tree

6 files changed

+294
-7
lines changed

6 files changed

+294
-7
lines changed

Diff for: pynamodb/attributes.py

+7
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,13 @@ def deserialize(self, value):
494494
return json.loads(value)
495495

496496

497+
class VersionAttribute(NumberAttribute):
498+
"""
499+
A version attribute
500+
"""
501+
null = True # should I be doing this?
502+
503+
497504
class TTLAttribute(Attribute):
498505
"""
499506
A time-to-live attribute that signifies when the item expires and can be automatically deleted.

Diff for: pynamodb/attributes.pyi

+5
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ class NumberAttribute(Attribute[float]):
119119
@overload
120120
def __get__(self, instance: Any, owner: Any) -> float: ...
121121

122+
class VersionAttribute(Attribute[float]):
123+
@overload
124+
def __get__(self: _A, instance: None, owner: Any) -> _A: ...
125+
@overload
126+
def __get__(self, instance: Any, owner: Any) -> float: ...
122127

123128
class TTLAttribute(Attribute[datetime]):
124129
@overload

Diff for: pynamodb/models.py

+60-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
from six import add_metaclass
1212
from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError
13-
from pynamodb.attributes import Attribute, AttributeContainer, AttributeContainerMeta, MapAttribute, TTLAttribute
13+
from pynamodb.attributes import (
14+
Attribute, AttributeContainer, AttributeContainerMeta, MapAttribute, TTLAttribute, VersionAttribute
15+
)
1416
from pynamodb.connection.table import TableConnection
1517
from pynamodb.connection.util import pythonic
1618
from pynamodb.types import HASH, RANGE
@@ -209,6 +211,14 @@ def __init__(cls, name, bases, attrs):
209211
if len(ttl_attr_names) > 1:
210212
raise ValueError("The model has more than one TTL attribute: {}".format(", ".join(ttl_attr_names)))
211213

214+
version_attribute_names = [
215+
name for name, attr_obj in attrs.items() if isinstance(attr_obj, VersionAttribute)
216+
]
217+
if len(version_attribute_names) > 1:
218+
raise ValueError(
219+
"The model has more than one Version attribute: {}".format(", ".join(version_attribute_names))
220+
)
221+
212222
if META_CLASS_NAME not in attrs:
213223
setattr(cls, META_CLASS_NAME, DefaultMeta)
214224

@@ -334,6 +344,10 @@ def delete(self, condition=None):
334344
Deletes this object from dynamodb
335345
"""
336346
args, kwargs = self._get_save_args(attributes=False, null_check=False)
347+
version_condition = self._handle_version_attribute(kwargs)
348+
if version_condition is not None:
349+
condition &= version_condition
350+
337351
kwargs.update(condition=condition)
338352
return self._get_connection().delete_item(*args, **kwargs)
339353

@@ -348,6 +362,9 @@ def update(self, actions, condition=None):
348362
raise TypeError("the value of `actions` is expected to be a non-empty list")
349363

350364
args, save_kwargs = self._get_save_args(null_check=False)
365+
version_condition = self._handle_version_attribute(save_kwargs, actions=actions)
366+
if version_condition is not None:
367+
condition &= version_condition
351368
kwargs = {
352369
pythonic(RETURN_VALUES): ALL_NEW,
353370
}
@@ -371,6 +388,9 @@ def save(self, condition=None):
371388
Save this object to dynamodb
372389
"""
373390
args, kwargs = self._get_save_args()
391+
version_condition = self._handle_version_attribute(serialized_attributes=kwargs)
392+
if version_condition is not None:
393+
condition &= version_condition
374394
kwargs.update(condition=condition)
375395
return self._get_connection().put_item(*args, **kwargs)
376396

@@ -395,6 +415,10 @@ def get_operation_kwargs_from_instance(self,
395415
return_values_on_condition_failure=None):
396416
is_update = actions is not None
397417
args, save_kwargs = self._get_save_args(null_check=not is_update)
418+
version_condition = self._handle_version_attribute(serialized_attributes=save_kwargs,
419+
actions=actions)
420+
if version_condition is not None:
421+
condition &= version_condition
398422
kwargs = dict(
399423
key=key,
400424
actions=actions,
@@ -872,6 +896,7 @@ def _get_save_args(self, attributes=True, null_check=True):
872896
"""
873897
kwargs = {}
874898
serialized = self._serialize(null_check=null_check)
899+
875900
hash_key = serialized.get(HASH)
876901
range_key = serialized.get(RANGE, None)
877902
args = (hash_key, )
@@ -881,6 +906,40 @@ def _get_save_args(self, attributes=True, null_check=True):
881906
kwargs[pythonic(ATTRIBUTES)] = serialized[pythonic(ATTRIBUTES)]
882907
return args, kwargs
883908

909+
def _handle_version_attribute(self, serialized_attributes, actions=None):
910+
"""
911+
Handles modifying the request to set or increment the version attribute.
912+
913+
:param serialized_attributes: A dictionary mapping attribute names to serialized values.
914+
:param actions: A non-empty list when performing an update, otherwise None.
915+
"""
916+
version_condition = None
917+
918+
for name, attr in self.get_attributes().items():
919+
value = getattr(self, name)
920+
if isinstance(attr, VersionAttribute):
921+
# We don't modify the attribute except in the serialized payload so that
922+
# the local object is not modified on failure.
923+
if not value:
924+
version_condition = attr.does_not_exist()
925+
if actions:
926+
actions.append(attr.set(1))
927+
elif pythonic(ATTRIBUTES) in serialized_attributes:
928+
serialized_attributes[pythonic(ATTRIBUTES)][attr.attr_name] = self._serialize_value(
929+
attr, 1, null_check=True
930+
)
931+
else:
932+
version_condition = attr == value
933+
if actions:
934+
actions.append(attr.add(1))
935+
elif pythonic(ATTRIBUTES) in serialized_attributes:
936+
serialized_attributes[pythonic(ATTRIBUTES)][attr.attr_name] = self._serialize_value(
937+
attr, value + 1, null_check=True
938+
)
939+
940+
break
941+
return version_condition
942+
884943
@classmethod
885944
def _hash_key_attribute(cls):
886945
"""
@@ -1002,7 +1061,6 @@ def _serialize(self, attr_map=False, null_check=True):
10021061
if isinstance(value, MapAttribute):
10031062
if not value.validate():
10041063
raise ValueError("Attribute '{}' is not correctly typed".format(attr.attr_name))
1005-
10061064
serialized = self._serialize_value(attr, value, null_check)
10071065
if NULL in serialized:
10081066
continue

Diff for: tests/data.py

+35
Original file line numberDiff line numberDiff line change
@@ -1412,3 +1412,38 @@
14121412
"TableStatus": "ACTIVE"
14131413
}
14141414
}
1415+
1416+
VERSIONED_TABLE_DATA = {
1417+
"Table": {
1418+
"AttributeDefinitions": [
1419+
{
1420+
"AttributeName": "name",
1421+
"AttributeType": "S"
1422+
},
1423+
{
1424+
"AttributeName": "email",
1425+
"AttributeType": "S"
1426+
},
1427+
{
1428+
"AttributeName": "version",
1429+
"AttributeType": "N"
1430+
}
1431+
],
1432+
"CreationDateTime": 1.363729002358E9,
1433+
"ItemCount": 42,
1434+
"KeySchema": [
1435+
{
1436+
"AttributeName": "name",
1437+
"KeyType": "HASH"
1438+
},
1439+
],
1440+
"ProvisionedThroughput": {
1441+
"NumberOfDecreasesToday": 0,
1442+
"ReadCapacityUnits": 5,
1443+
"WriteCapacityUnits": 5
1444+
},
1445+
"TableName": "VersionedModel",
1446+
"TableSizeBytes": 0,
1447+
"TableStatus": "ACTIVE"
1448+
}
1449+
}

Diff for: tests/test_attributes.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
BinarySetAttribute, BinaryAttribute, NumberSetAttribute, NumberAttribute,
1919
UnicodeAttribute, UnicodeSetAttribute, UTCDateTimeAttribute, BooleanAttribute, MapAttribute,
2020
ListAttribute, JSONAttribute, TTLAttribute, _get_value_for_deserialize, _fast_parse_utc_datestring,
21-
)
21+
VersionAttribute)
2222
from pynamodb.constants import (
2323
DATETIME_FORMAT, DEFAULT_ENCODING, NUMBER, STRING, STRING_SET, NUMBER_SET, BINARY_SET,
2424
BINARY, BOOLEAN,
@@ -999,3 +999,18 @@ def __eq__(self, other):
999999
assert deserialized == inp
10001000
assert serialize_mock.call_args_list == [call(1), call(2)]
10011001
assert deserialize_mock.call_args_list == [call('1'), call('2')]
1002+
1003+
1004+
class TestVersionAttribute:
1005+
def test_serialize(self):
1006+
attr = VersionAttribute()
1007+
assert attr.attr_type == NUMBER
1008+
assert attr.serialize(3.141) == '3.141'
1009+
assert attr.serialize(1) == '1'
1010+
assert attr.serialize(12345678909876543211234234324234) == '12345678909876543211234234324234'
1011+
1012+
def test_deserialize(self):
1013+
attr = VersionAttribute()
1014+
assert attr.deserialize('1') == 1
1015+
assert attr.deserialize('3.141') == 3.141
1016+
assert attr.deserialize('12345678909876543211234234324234') == 12345678909876543211234234324234

0 commit comments

Comments
 (0)