Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 6 additions & 30 deletions enumfields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,32 +66,20 @@ def get_prep_value(self, value):
return None
if isinstance(value, self.enum): # Already the correct type -- fast path
return value.value
return self.enum(value).value
return self.to_python(value).value

def from_db_value(self, value, expression, connection, *args):
return self.to_python(value)

def value_to_string(self, obj):
"""
This method is needed to support proper serialization. While its name is value_to_string()
the real meaning of the method is to convert the value to some serializable format.
Since most of the enum values are strings or integers we WILL NOT convert it to string
to enable integers to be serialized natively.
"""
value = self.value_from_object(obj)
return value.value if value else None
return str(value.value) if value is not None else ''

def get_default(self):
if self.has_default():
if self.default is None:
return None

if isinstance(self.default, Enum):
return self.default

return self.enum(self.default)

return super().get_default()
default = super().get_default()
if default is not None and self.has_default():
default = self.enum(default)
return default

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
Expand Down Expand Up @@ -166,15 +154,3 @@ def validators(self):
# connection.ops.integer_field_range method.
next = super(models.IntegerField, self)
return next.validators

def get_prep_value(self, value):
if value is None:
return None

if isinstance(value, Enum):
return value.value

try:
return int(value)
except ValueError:
return self.to_python(value).value
2 changes: 2 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class MyModel(models.Model):
taste = EnumField(Taste, default=Taste.SWEET)
taste_not_editable = EnumField(Taste, default=Taste.SWEET, editable=False)
taste_null_default = EnumField(Taste, null=True, blank=True, default=None)
taste_callable_enum_default = EnumField(Taste, default=lambda: Taste.SWEET, editable=False)
taste_callable_value_default = EnumField(Taste, default=lambda: Taste.SWEET.value, editable=False)
taste_int = EnumIntegerField(Taste, default=Taste.SWEET)

default_none = EnumIntegerField(Taste, default=None, null=True, blank=True)
Expand Down
7 changes: 4 additions & 3 deletions tests/test_django_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.core.exceptions import ValidationError
from django.db import connection

import pytest
Expand All @@ -20,7 +21,7 @@ def test_field_value():
m = MyModel.objects.filter(color='r')[0]
assert m.color == Color.RED

with pytest.raises(ValueError):
with pytest.raises(ValidationError):
MyModel.objects.filter(color='xx')[0]


Expand Down Expand Up @@ -91,8 +92,8 @@ def test_serialization():
ser = PythonSerializer()
ser.serialize([m])
fields = ser.getvalue()[0]["fields"]
assert fields["color"] == m.color.value
assert fields["taste"] == m.taste.value
assert fields["color"] == "r"
assert fields["taste"] == "4"


@pytest.mark.django_db
Expand Down