Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Model Attribute Transformation #1201

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,36 @@ Here is an example of customizing an attribute name:
# This attribute will be called 'tn' in DynamoDB
thread_name = UnicodeAttribute(null=True, attr_name='tn')

PynamoDB can also transform all the attribute names from Python's "Snake Case"
(for example "forum_name") to another naming convention, such as Camel Case ("forumName"),
or Pascal Case ("ForumName").
Custom attribute names can still be applied to individual attributes, and take precedence
over the attribute transform.
The attribute transformation can be assigned to the model class as part of the definition.

PynamoDB comes with these built in attribute transformations:

* :py:class:`CamelCaseAttributeTransform <pynamodb.attributes.CamelCaseAttributeTransform>`
* :py:class:`PascalCaseAttributeTransform <pynamodb.attributes.PascalCaseAttributeTransform>`

Here is example usage of both the attribute transformation, and custom attribute names:

.. code-block:: python

from pynamodb.models import Model
from pynamodb.attributes import UnicodeAttribute, CamelCaseAttributeTransform

class Thread(Model, attribute_transform=CamelCaseAttributeTransform):
class Meta:
table_name = 'Thread'
# This attribute will be called 'forumName' in DynamoDB
forum_name = UnicodeAttribute(hash_key=True)

# This attribute will be called 'threadName' in DynamoDB
thread_name = UnicodeAttribute(null=True)

# This attribute will be called 'author' in DynamoDB
post_author = UnicodeAttribute(null=True, attr_name='author')

PynamoDB comes with several built in attribute types for convenience, which include the following:

Expand Down
51 changes: 48 additions & 3 deletions pynamodb/attributes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
PynamoDB attributes
"""
import abc
import base64
import calendar
import collections.abc
Expand Down Expand Up @@ -288,22 +289,30 @@ def __new__(cls, name, bases, namespace, discriminator=None):
# Defined so that the discriminator can be set in the class definition.
return super().__new__(cls, name, bases, namespace)

def __init__(self, name, bases, namespace, discriminator=None):
def __init__(self, name, bases, namespace, discriminator=None, attribute_transform: Optional["AttributeTransform"] = None):
super().__init__(name, bases, namespace)
AttributeContainerMeta._initialize_attributes(self, discriminator)
AttributeContainerMeta._initialize_attributes(self, discriminator, attribute_transform)

@staticmethod
def _initialize_attributes(cls, discriminator_value):
def _initialize_attributes(cls, discriminator_value, attribute_transform: Optional[Type["AttributeTransform"]] = None):
"""
Initialize attributes on the class.
"""
cls._attributes = {}
cls._dynamo_to_python_attrs = {}

if attribute_transform is None or issubclass(attribute_transform, AttributeTransform):
cls._attribute_transform = attribute_transform
else:
raise ValueError(f"Attribute Transform {type(attribute_transform)} is not a subclass of AttributeTransform")

for name, attribute in getmembers(cls, lambda o: isinstance(o, Attribute)):
cls._attributes[name] = attribute
if attribute.attr_name != name:
cls._dynamo_to_python_attrs[attribute.attr_name] = name
elif cls._attribute_transform is not None:
attribute.attr_name = cls._attribute_transform.transform(name)
cls._dynamo_to_python_attrs[attribute.attr_name] = name

# Register the class with the discriminator if necessary.
discriminators = [name for name, attr in cls._attributes.items() if isinstance(attr, DiscriminatorAttribute)]
Expand Down Expand Up @@ -601,6 +610,42 @@ def deserialize(self, value):
return self._discriminator_map[value]


class AttributeTransform(abc.ABC):
"""Base case for converting python attributes in to various cases"""

@classmethod
@abc.abstractmethod
def transform(cls, python_attr: str) -> str:
"""Transform python attribute string to desired case"""
...


class CamelCaseAttributeTransform(AttributeTransform):
"""Convert python attributes to camelCase"""

@classmethod
def transform(cls, python_attr: str) -> str:
if isinstance(python_attr, str):
parts = python_attr.split("_")
return parts[0] + "".join([part.title() for part in parts[1:]])

else:
raise ValueError("Provided value is not a string")


class PascalCaseAttributeTransform(AttributeTransform):
"""Convert python attributes to PascalCase"""

@classmethod
def transform(cls, python_attr: str) -> str:
if isinstance(python_attr, str):
parts = python_attr.split("_")
return "".join([part.title() for part in parts])

else:
raise ValueError("Provided value is not a string")


class BinaryAttribute(Attribute[bytes]):
"""
An attribute containing a binary data object (:code:`bytes`).
Expand Down
8 changes: 4 additions & 4 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError, PutError, \
AttributeNullError
from pynamodb.attributes import (
AttributeContainer, AttributeContainerMeta, TTLAttribute, VersionAttribute
AttributeContainer, AttributeContainerMeta, AttributeTransform, TTLAttribute, VersionAttribute
)
from pynamodb.connection.table import TableConnection
from pynamodb.expressions.condition import Condition
Expand Down Expand Up @@ -200,12 +200,12 @@ class MetaModel(AttributeContainerMeta):
"""
Model meta class
"""
def __new__(cls, name, bases, namespace, discriminator=None):
def __new__(cls, name, bases, namespace, discriminator=None, attribute_transform: Optional[AttributeTransform] = None):
# Defined so that the discriminator can be set in the class definition.
return super().__new__(cls, name, bases, namespace)

def __init__(self, name, bases, namespace, discriminator=None) -> None:
super().__init__(name, bases, namespace, discriminator)
def __init__(self, name, bases, namespace, discriminator=None, attribute_transform: Optional[AttributeTransform] = None) -> None:
super().__init__(name, bases, namespace, discriminator, attribute_transform)
MetaModel._initialize_indexes(self)
cls = cast(Type['Model'], self)
for attr_name, attribute in cls.get_attributes().items():
Expand Down
104 changes: 103 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
IncludeProjection, KeysOnlyProjection, Index
)
from pynamodb.attributes import (
DiscriminatorAttribute, UnicodeAttribute, NumberAttribute, BinaryAttribute, UTCDateTimeAttribute,
CamelCaseAttributeTransform, DiscriminatorAttribute, PascalCaseAttributeTransform, UnicodeAttribute, NumberAttribute, BinaryAttribute, UTCDateTimeAttribute,
UnicodeSetAttribute, NumberSetAttribute, BinarySetAttribute, MapAttribute,
BooleanAttribute, ListAttribute, TTLAttribute, VersionAttribute)
from .data import (
Expand Down Expand Up @@ -211,6 +211,34 @@ class Meta:
uid_index = CustomAttrIndex()


class CamelCaseTransformAttrNameModel(Model, attribute_transform=CamelCaseAttributeTransform):
"""
Attribute names transformed in to Camel Case
"""

class Meta:
table_name = 'CustomAttrModel'

user_name = UnicodeAttribute(hash_key=True)
user_id = UnicodeAttribute(range_key=True)
enabled = UnicodeAttribute(null=True)
overidden_attr = UnicodeAttribute(attr_name='foo_attr', null=True)


class PascalCaseTransformAttrNameModel(Model, attribute_transform=PascalCaseAttributeTransform):
"""
Attribute names transformed in to Camel Case
"""

class Meta:
table_name = 'CustomAttrModel'

user_name = UnicodeAttribute(hash_key=True)
user_id = UnicodeAttribute(range_key=True)
enabled = UnicodeAttribute(null=True)
overidden_attr = UnicodeAttribute(attr_name='foo_attr', null=True)


class UserModel(Model):
"""
A testing model
Expand Down Expand Up @@ -687,13 +715,87 @@ def test_overridden_defaults(self):
]
)

schema = CamelCaseTransformAttrNameModel._get_schema()
self.assertListEqual(
schema['key_schema'],
[
{
'KeyType': 'RANGE',
'AttributeName': 'userId'
},
{
'KeyType': 'HASH',
'AttributeName': 'userName'
},
],
)
self.assertListEqual(
schema['attribute_definitions'],
[
{
'AttributeType': 'S',
'AttributeName': 'userId'
},
{
'AttributeType': 'S',
'AttributeName': 'userName'
},
]
)

schema = PascalCaseTransformAttrNameModel._get_schema()
self.assertListEqual(
schema['key_schema'],
[
{
'KeyType': 'RANGE',
'AttributeName': 'UserId'
},
{
'KeyType': 'HASH',
'AttributeName': 'UserName'
},
],
)
self.assertListEqual(
schema['attribute_definitions'],
[
{
'AttributeType': 'S',
'AttributeName': 'UserId'
},
{
'AttributeType': 'S',
'AttributeName': 'UserName'
},
]
)

def test_overridden_attr_name(self):
user = UserModel(custom_user_name="bob")
self.assertEqual(user.custom_user_name, "bob")
self.assertRaises(AttributeError, getattr, user, "user_name")

self.assertRaises(ValueError, UserModel, user_name="bob")

def test_transformed_attr_name(self):
"""
Test transformed attributes names
"""
item = CamelCaseTransformAttrNameModel('foo', 'bar', overidden_attr='test', enabled="test")
self.assertEqual(item.overidden_attr, 'test')
attrs = item.get_attributes()
self.assertEqual(attrs["user_name"].attr_name, "userName")
self.assertEqual(attrs["user_id"].attr_name, "userId")
self.assertEqual(attrs["enabled"].attr_name, "enabled")

item = PascalCaseTransformAttrNameModel('foo', 'bar', overidden_attr='test', enabled="test")
self.assertEqual(item.overidden_attr, 'test')
attrs = item.get_attributes()
self.assertEqual(attrs["user_name"].attr_name, "UserName")
self.assertEqual(attrs["user_id"].attr_name, "UserId")
self.assertEqual(attrs["enabled"].attr_name, "Enabled")

def test_refresh(self):
"""
Model.refresh
Expand Down