diff --git a/docs/tutorial.rst b/docs/tutorial.rst index a8d1f45d..2f3ad2cf 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -201,6 +201,36 @@ Here is an example of customizing an attribute name: # This attribute will be called 'tn' in DynamoDB thread_name = UnicodeAttribute(null=True, attr_name='tn') +PynamoDB can also transform all the attribute names from Python's "Snake Case" +(for example "forum_name") to another naming convention, such as Camel Case ("forumName"), +or Pascal Case ("ForumName"). +Custom attribute names can still be applied to individual attributes, and take precedence +over the attribute transform. +The attribute transformation can be assigned to the model class as part of the definition. + +PynamoDB comes with these built in attribute transformations: + +* :py:class:`CamelCaseAttributeTransform ` +* :py:class:`PascalCaseAttributeTransform ` + +Here is example usage of both the attribute transformation, and custom attribute names: + +.. code-block:: python + + from pynamodb.models import Model + from pynamodb.attributes import UnicodeAttribute, CamelCaseAttributeTransform + + class Thread(Model, attribute_transform=CamelCaseAttributeTransform): + class Meta: + table_name = 'Thread' + # This attribute will be called 'forumName' in DynamoDB + forum_name = UnicodeAttribute(hash_key=True) + + # This attribute will be called 'threadName' in DynamoDB + thread_name = UnicodeAttribute(null=True) + + # This attribute will be called 'author' in DynamoDB + post_author = UnicodeAttribute(null=True, attr_name='author') PynamoDB comes with several built in attribute types for convenience, which include the following: diff --git a/pynamodb/attributes.py b/pynamodb/attributes.py index ce750bf4..6cb4cc3a 100644 --- a/pynamodb/attributes.py +++ b/pynamodb/attributes.py @@ -1,6 +1,7 @@ """ PynamoDB attributes """ +import abc import base64 import calendar import collections.abc @@ -288,22 +289,30 @@ def __new__(cls, name, bases, namespace, discriminator=None): # Defined so that the discriminator can be set in the class definition. return super().__new__(cls, name, bases, namespace) - def __init__(self, name, bases, namespace, discriminator=None): + def __init__(self, name, bases, namespace, discriminator=None, attribute_transform: Optional["AttributeTransform"] = None): super().__init__(name, bases, namespace) - AttributeContainerMeta._initialize_attributes(self, discriminator) + AttributeContainerMeta._initialize_attributes(self, discriminator, attribute_transform) @staticmethod - def _initialize_attributes(cls, discriminator_value): + def _initialize_attributes(cls, discriminator_value, attribute_transform: Optional[Type["AttributeTransform"]] = None): """ Initialize attributes on the class. """ cls._attributes = {} cls._dynamo_to_python_attrs = {} + if attribute_transform is None or issubclass(attribute_transform, AttributeTransform): + cls._attribute_transform = attribute_transform + else: + raise ValueError(f"Attribute Transform {type(attribute_transform)} is not a subclass of AttributeTransform") + for name, attribute in getmembers(cls, lambda o: isinstance(o, Attribute)): cls._attributes[name] = attribute if attribute.attr_name != name: cls._dynamo_to_python_attrs[attribute.attr_name] = name + elif cls._attribute_transform is not None: + attribute.attr_name = cls._attribute_transform.transform(name) + cls._dynamo_to_python_attrs[attribute.attr_name] = name # Register the class with the discriminator if necessary. discriminators = [name for name, attr in cls._attributes.items() if isinstance(attr, DiscriminatorAttribute)] @@ -601,6 +610,42 @@ def deserialize(self, value): return self._discriminator_map[value] +class AttributeTransform(abc.ABC): + """Base case for converting python attributes in to various cases""" + + @classmethod + @abc.abstractmethod + def transform(cls, python_attr: str) -> str: + """Transform python attribute string to desired case""" + ... + + +class CamelCaseAttributeTransform(AttributeTransform): + """Convert python attributes to camelCase""" + + @classmethod + def transform(cls, python_attr: str) -> str: + if isinstance(python_attr, str): + parts = python_attr.split("_") + return parts[0] + "".join([part.title() for part in parts[1:]]) + + else: + raise ValueError("Provided value is not a string") + + +class PascalCaseAttributeTransform(AttributeTransform): + """Convert python attributes to PascalCase""" + + @classmethod + def transform(cls, python_attr: str) -> str: + if isinstance(python_attr, str): + parts = python_attr.split("_") + return "".join([part.title() for part in parts]) + + else: + raise ValueError("Provided value is not a string") + + class BinaryAttribute(Attribute[bytes]): """ An attribute containing a binary data object (:code:`bytes`). diff --git a/pynamodb/models.py b/pynamodb/models.py index f05ed485..bc1b1351 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -36,7 +36,7 @@ from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError, PutError, \ AttributeNullError from pynamodb.attributes import ( - AttributeContainer, AttributeContainerMeta, TTLAttribute, VersionAttribute + AttributeContainer, AttributeContainerMeta, AttributeTransform, TTLAttribute, VersionAttribute ) from pynamodb.connection.table import TableConnection from pynamodb.expressions.condition import Condition @@ -200,12 +200,12 @@ class MetaModel(AttributeContainerMeta): """ Model meta class """ - def __new__(cls, name, bases, namespace, discriminator=None): + def __new__(cls, name, bases, namespace, discriminator=None, attribute_transform: Optional[AttributeTransform] = None): # Defined so that the discriminator can be set in the class definition. return super().__new__(cls, name, bases, namespace) - def __init__(self, name, bases, namespace, discriminator=None) -> None: - super().__init__(name, bases, namespace, discriminator) + def __init__(self, name, bases, namespace, discriminator=None, attribute_transform: Optional[AttributeTransform] = None) -> None: + super().__init__(name, bases, namespace, discriminator, attribute_transform) MetaModel._initialize_indexes(self) cls = cast(Type['Model'], self) for attr_name, attribute in cls.get_attributes().items(): diff --git a/tests/test_model.py b/tests/test_model.py index cd236c9a..86ab4423 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -26,7 +26,7 @@ IncludeProjection, KeysOnlyProjection, Index ) from pynamodb.attributes import ( - DiscriminatorAttribute, UnicodeAttribute, NumberAttribute, BinaryAttribute, UTCDateTimeAttribute, + CamelCaseAttributeTransform, DiscriminatorAttribute, PascalCaseAttributeTransform, UnicodeAttribute, NumberAttribute, BinaryAttribute, UTCDateTimeAttribute, UnicodeSetAttribute, NumberSetAttribute, BinarySetAttribute, MapAttribute, BooleanAttribute, ListAttribute, TTLAttribute, VersionAttribute) from .data import ( @@ -211,6 +211,34 @@ class Meta: uid_index = CustomAttrIndex() +class CamelCaseTransformAttrNameModel(Model, attribute_transform=CamelCaseAttributeTransform): + """ + Attribute names transformed in to Camel Case + """ + + class Meta: + table_name = 'CustomAttrModel' + + user_name = UnicodeAttribute(hash_key=True) + user_id = UnicodeAttribute(range_key=True) + enabled = UnicodeAttribute(null=True) + overidden_attr = UnicodeAttribute(attr_name='foo_attr', null=True) + + +class PascalCaseTransformAttrNameModel(Model, attribute_transform=PascalCaseAttributeTransform): + """ + Attribute names transformed in to Camel Case + """ + + class Meta: + table_name = 'CustomAttrModel' + + user_name = UnicodeAttribute(hash_key=True) + user_id = UnicodeAttribute(range_key=True) + enabled = UnicodeAttribute(null=True) + overidden_attr = UnicodeAttribute(attr_name='foo_attr', null=True) + + class UserModel(Model): """ A testing model @@ -687,6 +715,62 @@ def test_overridden_defaults(self): ] ) + schema = CamelCaseTransformAttrNameModel._get_schema() + self.assertListEqual( + schema['key_schema'], + [ + { + 'KeyType': 'RANGE', + 'AttributeName': 'userId' + }, + { + 'KeyType': 'HASH', + 'AttributeName': 'userName' + }, + ], + ) + self.assertListEqual( + schema['attribute_definitions'], + [ + { + 'AttributeType': 'S', + 'AttributeName': 'userId' + }, + { + 'AttributeType': 'S', + 'AttributeName': 'userName' + }, + ] + ) + + schema = PascalCaseTransformAttrNameModel._get_schema() + self.assertListEqual( + schema['key_schema'], + [ + { + 'KeyType': 'RANGE', + 'AttributeName': 'UserId' + }, + { + 'KeyType': 'HASH', + 'AttributeName': 'UserName' + }, + ], + ) + self.assertListEqual( + schema['attribute_definitions'], + [ + { + 'AttributeType': 'S', + 'AttributeName': 'UserId' + }, + { + 'AttributeType': 'S', + 'AttributeName': 'UserName' + }, + ] + ) + def test_overridden_attr_name(self): user = UserModel(custom_user_name="bob") self.assertEqual(user.custom_user_name, "bob") @@ -694,6 +778,24 @@ def test_overridden_attr_name(self): self.assertRaises(ValueError, UserModel, user_name="bob") + def test_transformed_attr_name(self): + """ + Test transformed attributes names + """ + item = CamelCaseTransformAttrNameModel('foo', 'bar', overidden_attr='test', enabled="test") + self.assertEqual(item.overidden_attr, 'test') + attrs = item.get_attributes() + self.assertEqual(attrs["user_name"].attr_name, "userName") + self.assertEqual(attrs["user_id"].attr_name, "userId") + self.assertEqual(attrs["enabled"].attr_name, "enabled") + + item = PascalCaseTransformAttrNameModel('foo', 'bar', overidden_attr='test', enabled="test") + self.assertEqual(item.overidden_attr, 'test') + attrs = item.get_attributes() + self.assertEqual(attrs["user_name"].attr_name, "UserName") + self.assertEqual(attrs["user_id"].attr_name, "UserId") + self.assertEqual(attrs["enabled"].attr_name, "Enabled") + def test_refresh(self): """ Model.refresh