Skip to content

Commit 8e85030

Browse files
Merge pull request #249 from phenobarbital/new-exports
added validation for typing.Literal
2 parents 137c414 + 65833ec commit 8e85030

File tree

5 files changed

+394
-5
lines changed

5 files changed

+394
-5
lines changed

datamodel/converters.pyx

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (C) 2018-present Jesus Lara
33
#
44
import re
5-
from typing import get_args, get_origin, Union, Optional, List, NewType
5+
from typing import get_args, get_origin, Union, Optional, List, NewType, Literal
66
from collections.abc import Sequence, Mapping, Callable, Awaitable
77
import types
88
from dataclasses import _MISSING_TYPE, _FIELDS, fields
@@ -80,13 +80,21 @@ cpdef str to_string(object obj):
8080
return obj.decode()
8181
except UnicodeDecodeError as e:
8282
raise ValueError(f"Cannot decode bytes: {e}") from e
83+
if isinstance(obj, (int, float, Decimal)):
84+
# its a number
85+
return str(obj)
8386
if callable(obj):
8487
# its a function callable returning a value
8588
try:
86-
return str(obj())
89+
val = obj()
90+
# Recursively call to_string on that result:
91+
return to_string(val)
8792
except Exception:
8893
pass
89-
return str(obj)
94+
# For any other arbitrary type, explicitly fail:
95+
raise ValueError(
96+
f"Cannot convert object of type {type(obj).__name__} to string."
97+
)
9098

9199
cpdef object to_uuid(object obj):
92100
"""Returns a UUID version of a str column.
@@ -1123,6 +1131,36 @@ cdef object _parse_typing(
11231131
# fallback to builtin parse
11241132
return _parse_builtin_type(field, T, data, encoder)
11251133

1134+
cdef object _parse_literal_type(
1135+
object field,
1136+
object T,
1137+
object data,
1138+
object encoder
1139+
):
1140+
"""
1141+
_parse_literal_type parses a typing.Literal[...] annotation.
1142+
1143+
:param field: A Field object (or similar) containing metadata
1144+
:param T: The full annotated type (e.g. typing.Literal['text/plain', 'text/html']).
1145+
:param data: The input value to check.
1146+
:param encoder: Optional encoder (not usually used for literal).
1147+
:return: Returns 'data' if it matches one of the literal choices, otherwise raises ValueError.
1148+
"""
1149+
1150+
# Each element in `targs` is a valid literal value, e.g. a string, int, etc.
1151+
# If data is exactly in that set, it's valid.
1152+
cdef tuple targs = field.args
1153+
cdef tuple i
1154+
for arg in targs:
1155+
if data == arg:
1156+
return data
1157+
1158+
# If we get here, data didn't match any literal value
1159+
raise ValueError(
1160+
f"Literal parse error for field '{field.name}': "
1161+
f"value={data!r} is not one of {targs}"
1162+
)
1163+
11261164
cdef object _handle_dataclass_type(
11271165
object field,
11281166
str name,
@@ -1386,6 +1424,10 @@ cpdef dict processing_fields(object obj, list columns):
13861424
# means that is_dataclass(T)
13871425
newval = _handle_dataclass_type(None, name, value, _type, as_objects, None)
13881426
obj.__dict__[name] = newval
1427+
elif f.origin is Literal:
1428+
# e.g. Literal[...]
1429+
newval = _parse_literal_type(f, _type, value, _encoder)
1430+
obj.__dict__[name] = newval
13891431
elif f.origin is list:
13901432
# Other typical case is when is a List of primitives.
13911433
if f._inner_priv:

datamodel/validation.pyx

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# cython: language_level=3, embedsignature=True, initializedcheck=False
22
# Copyright (C) 2018-present Jesus Lara
33
#
4-
from typing import get_args, get_origin, Union, Optional
4+
from typing import get_args, get_origin, Union, Optional, Literal
55
from collections.abc import Callable, Awaitable
66
import typing
77
import asyncio
@@ -263,6 +263,7 @@ cpdef dict _validation(
263263
cdef bint _valid = False
264264
cdef object field_meta = F.metadata
265265
cdef dict error = {}
266+
cdef list allowed_values = []
266267

267268
if not annotated_type:
268269
annotated_type = F.type
@@ -307,6 +308,16 @@ cpdef dict _validation(
307308
return {}
308309
elif field_type == 'type':
309310
return validate_type(F, name, value, annotated_type, val_type)
311+
elif F.origin is Literal:
312+
allowed_values = list(F.args)
313+
if value not in allowed_values:
314+
return _create_error(
315+
name,
316+
value,
317+
f"Invalid value for {annotated_type}.{name}, expected one of {allowed_values}",
318+
val_type,
319+
annotated_type
320+
)
310321
elif field_type == 'typing' or hasattr(annotated_type, '__module__') and annotated_type.__module__ == 'typing':
311322
if F.origin is tuple:
312323
# Check if we are in the homogeneous case: Tuple[T, ...]

datamodel/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
'simple library based on python +3.8 to use Dataclass-syntax'
77
'for interacting with Data'
88
)
9-
__version__ = '0.10.8'
9+
__version__ = '0.10.9'
1010
__copyright__ = 'Copyright (c) 2020-2024 Jesus Lara'
1111
__author__ = 'Jesus Lara'
1212
__author_email__ = '[email protected]'

examples/test_notify.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from typing import Union, Optional, Any, Literal
2+
import uuid
3+
from datetime import datetime
4+
from pathlib import Path
5+
from datamodel import BaseModel, Field
6+
7+
8+
CONTENT_TYPES = [
9+
"text/plain",
10+
"text/html",
11+
"multipart/alternative",
12+
"application/json"
13+
]
14+
15+
def auto_uuid(*args, **kwargs): # pylint: disable=W0613
16+
return uuid.uuid4()
17+
18+
19+
def now():
20+
return datetime.now()
21+
22+
class Account(BaseModel):
23+
"""
24+
Attributes for using a Provider by an User (Actor)
25+
"""
26+
27+
provider: str = Field(required=True, default="dummy")
28+
enabled: bool = Field(required=True, default=True)
29+
address: Union[str, list[str]] = Field(required=False, default_factory=list)
30+
number: Union[str, list[str]] = Field(required=False, default_factory=list)
31+
userid: str = Field(required=False, default="")
32+
attributes: dict = Field(required=False, default_factory=dict)
33+
34+
def set_address(self, address: Union[str, list[str]]):
35+
self.address = [address] if isinstance(address, str) else address
36+
37+
38+
class Actor(BaseModel):
39+
"""
40+
Basic Actor (meta-definition), can be an Sender or a Recipient
41+
"""
42+
43+
userid: uuid.UUID = Field(required=False, primary_key=True, default=auto_uuid)
44+
name: str
45+
account: Optional[Account]
46+
accounts: Optional[list[Account]]
47+
48+
def __str__(self) -> str:
49+
return f"<{self.name}: {self.userid}>"
50+
51+
class Message(BaseModel):
52+
"""
53+
Message.
54+
Base-class for Message blocks for Notify
55+
TODO:
56+
* template needs a factory function to find a jinja processor
57+
*
58+
"""
59+
60+
name: str = Field(required=True, default=auto_uuid)
61+
body: Union[str, dict] = Field(default=None)
62+
content: str = Field(required=False, default="")
63+
sent: datetime = Field(required=False, default=now)
64+
template: Path
65+
66+
class Attachment(BaseModel):
67+
"""Attachement.
68+
69+
an Attachment is any document attached to a message.
70+
"""
71+
72+
name: str = Field(required=True)
73+
content: Any = None
74+
content_type: str
75+
type: str
76+
77+
class BlockMessage(Message):
78+
"""
79+
BlockMessage.
80+
Class for Message Notifications
81+
TODO:
82+
* template needs a factory function to find a jinja processor
83+
*
84+
"""
85+
86+
sender: Union[Actor, list[Actor]] = Field(required=False)
87+
recipient: Union[Actor, list[Actor]] = Field(required=False)
88+
content_type: Literal[
89+
"text/plain",
90+
"text/html",
91+
"multipart/alternative",
92+
"application/json"
93+
] = Field(default="text/plain")
94+
attachments: list[Attachment] = Field(default_factory=list)
95+
flags: list[str]
96+
97+
def test_actor_valid():
98+
a = Actor(
99+
name="Alice",
100+
account=Account(provider="prov", enabled=True)
101+
)
102+
assert a.name == "Alice"
103+
assert a.account.provider == "prov"
104+
# accounts can be None or list[Account]
105+
a2 = Actor(
106+
name="Bob",
107+
accounts=[Account(provider="prov2")]
108+
)
109+
assert a2.accounts[0].provider == "prov2"
110+
111+
def test_actor_invalid():
112+
# Provide a wrong type for 'name'
113+
actor = Actor(name={"user": 123})
114+
print(actor.name, type(actor.name))
115+
116+
if __name__ == "__main__":
117+
test_actor_invalid()
118+
print("test_notify.py: all tests passed")

0 commit comments

Comments
 (0)