Skip to content

Commit ca710bf

Browse files
Merge pull request #228 from phenobarbital/bugfix-optional-list-str
fix issue when Optional[List[str]] is used on type hints for BaseModel
2 parents 8ef3a71 + 27e8ed2 commit ca710bf

File tree

5 files changed

+281
-15
lines changed

5 files changed

+281
-15
lines changed

datamodel/converters.pyx

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import asyncpg.pgproto.pgproto as pgproto
1515
from cpython.ref cimport PyObject
1616
from .functions import is_empty, is_dataclass, is_iterable, is_primitive
1717
from .validation import _validation
18+
from .fields import Field
1819

1920

2021
# Maps a type to a conversion callable
@@ -445,6 +446,22 @@ cdef object _parse_dict_type(
445446
new_dict[k] = parse_typing(field, val_type, v, encoder, False)
446447
return new_dict
447448

449+
cdef object _unwrap_optional(object T):
450+
"""
451+
If T is a Union that includes NoneType (i.e. Optional[T]),
452+
return the non-None type, else T unchanged.
453+
"""
454+
cdef object orig = get_origin(T)
455+
cdef tuple args = None
456+
cdef list non_none = []
457+
if orig is Union:
458+
args = get_args(T)
459+
# If exactly one type is not NoneType, return it.
460+
non_none = [a for a in args if a is not type(None)]
461+
if len(non_none) == 1:
462+
return non_none[0]
463+
return T, args
464+
448465
cdef object _parse_list_type(
449466
object field,
450467
object T,
@@ -722,8 +739,33 @@ cdef object _parse_union_type(
722739
"""
723740
Attempt each type in the Union until one parses successfully
724741
or raise an error if all fail.
742+
If T is Optional[...] (i.e. a Union with NoneType), unwrap it.
725743
"""
726744
cdef object errors = []
745+
cdef object non_none_arg = None
746+
cdef tuple inner_targs = None
747+
cdef bint is_typing = False
748+
# If the union includes NoneType, unwrap it and use only the non-None type.
749+
if origin == Union and type(None) in targs:
750+
for arg in targs:
751+
if arg is not type(None):
752+
non_none_arg = arg
753+
break
754+
is_typing = hasattr(non_none_arg, '__module__') and non_none_arg.__module__ == 'typing'
755+
if non_none_arg is not None and is_typing is True:
756+
# Invoke the parse_typing_type
757+
field.args = get_args(non_none_arg)
758+
field.origin = get_origin(non_none_arg)
759+
if isinstance(data, list):
760+
return parse_typing(
761+
field,
762+
non_none_arg,
763+
data,
764+
encoder,
765+
False
766+
)
767+
else:
768+
pass
727769
for arg_type in targs:
728770
try:
729771
if isinstance(data, list):

datamodel/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def remove_nulls(self, obj: Any) -> dict[str, Any]:
156156
elif isinstance(obj, dict):
157157
return {
158158
key: self.remove_nulls(value) for key, value in obj.items()
159-
if value is not None
159+
if value is not None and value != {}
160160
}
161161
else:
162162
return obj

datamodel/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
'simple library based on python +3.8 to use Dataclass-syntax'
77
'for interacting with Data'
88
)
9-
__version__ = '0.8.11'
9+
__version__ = '0.8.12'
1010
__copyright__ = 'Copyright (c) 2020-2024 Jesus Lara'
1111
__author__ = 'Jesus Lara'
1212
__author_email__ = '[email protected]'

examples/test_qsmodel.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
from typing import Optional
1+
from typing import List, Optional
22
from datetime import datetime
33
from datamodel import BaseModel, Field
4+
from datamodel.exceptions import ValidationError
5+
46

57
data = {
68
'query_slug': 'walmart_mtd_postpaid_to_goal',
79
'description': 'walmart_mtd_postpaid_to_goal',
810
'conditions': {'filterdate': 'POSTPAID_DATE', 'store_tier': 'null', 'launch_group': 'null'},
911
'cond_definition': {'filterdate': 'date', 'store_tier': 'string', 'launch_group': 'string'},
10-
'fields': [], 'ordering': [],
12+
'fields': ["client_id", "client_name"], 'ordering': [],
1113
'h_filtering': False,
1214
'query_raw': 'SELECT {fields}\nFROM walmart.postpaid_metrics({filterdate}, {launch_group}, {store_tier})\n{where_cond}',
1315
'is_raw': False,
@@ -37,14 +39,14 @@ class QueryModel(BaseModel):
3739
conditions: Optional[dict] = Field(required=False, db_type='jsonb', default_factory=dict)
3840
cond_definition: Optional[dict] = Field(required=False, db_type='jsonb', default_factory=dict)
3941
## filter and grouping options
40-
fields: Optional[list] = Field(required=False, db_type='array')
42+
fields: Optional[List[str]] = Field(required=False, db_type='array', default_factory=list)
4143
filtering: Optional[dict] = Field(required=False, db_type='jsonb', default_factory=dict)
42-
ordering: Optional[list] = Field(required=False, db_type='array')
43-
grouping: Optional[list] = Field(required=False, db_type='array')
44+
ordering: Optional[List[str]] = Field(required=False, db_type='array')
45+
grouping: Optional[List[str]] = Field(required=False, db_type='array')
4446
qry_options: Optional[dict] = Field(required=False, db_type='jsonb')
4547
h_filtering: bool = Field(required=False, default=False, comment="filtering based on Hierarchical rules.")
4648
### Query Information:
47-
query_raw: str = Field(required=False)
49+
query_raw: str = Field(required=False)
4850
is_raw: bool = Field(required=False, default=False)
4951
is_cached: bool = Field(required=False, default=True)
5052
provider: str = Field(required=False, default='db')
@@ -63,25 +65,28 @@ class QueryModel(BaseModel):
6365
# Creation Information:
6466
created_at: datetime = Field(
6567
required=False,
66-
default=datetime.now(),
68+
default=datetime.now,
6769
db_default='now()'
6870
)
69-
created_by: int = Field(required=False) # TODO: validation for valid user
71+
created_by: int = Field(required=False) # TODO: validation for valid user
7072
updated_at: datetime = Field(
7173
required=False,
72-
default=datetime.now(),
74+
default=datetime.now,
7375
encoder=rigth_now
7476
)
75-
updated_by: int = Field(required=False) # TODO: validation for valid user
77+
updated_by: int = Field(required=False) # TODO: validation for valid user
7678

7779
class Meta:
7880
driver = 'pg'
7981
name = 'queries'
8082
schema = 'public'
8183
strict = True
8284
frozen = False
83-
remove_nulls = True # Auto-remove nullable (with null value) fields
85+
remove_nulls = True # Auto-remove nullable (with null value) fields
8486

8587

86-
slug = QueryModel(**data)
87-
print(slug)
88+
try:
89+
slug = QueryModel(**data)
90+
print(slug)
91+
except ValidationError as e:
92+
print(e.payload)

tests/test_qsmodel.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# tests/test_query_model.py
2+
from datetime import datetime
3+
from typing import List, Optional
4+
import pytest
5+
from datamodel import BaseModel, Field
6+
from datamodel.exceptions import ValidationError
7+
8+
9+
def rigth_now(obj) -> datetime:
10+
return datetime.now()
11+
12+
class QueryModel(BaseModel):
13+
query_slug: str = Field(required=True, primary_key=True)
14+
description: str = Field(required=False, default=None)
15+
# Source and primary attributes:
16+
source: Optional[str] = Field(required=False)
17+
params: Optional[dict] = Field(required=False, db_type='jsonb', default_factory=dict)
18+
attributes: Optional[dict] = Field(
19+
required=False,
20+
db_type='jsonb',
21+
default_factory=dict,
22+
comment="Optional Attributes for Query"
23+
)
24+
# main conditions
25+
conditions: Optional[dict] = Field(required=False, db_type='jsonb', default_factory=dict)
26+
cond_definition: Optional[dict] = Field(
27+
required=False,
28+
db_type='jsonb',
29+
default_factory=dict
30+
)
31+
## filter and grouping options
32+
fields: Optional[List[str]] = Field(
33+
required=False,
34+
db_type='array',
35+
default_factory=list
36+
)
37+
filtering: Optional[dict] = Field(required=False, db_type='jsonb', default_factory=dict)
38+
ordering: Optional[List[str]] = Field(required=False, db_type='array')
39+
grouping: Optional[List[str]] = Field(required=False, db_type='array')
40+
qry_options: Optional[dict] = Field(required=False, db_type='jsonb')
41+
h_filtering: bool = Field(required=False, default=False, comment="filtering based on Hierarchical rules.")
42+
### Query Information:
43+
query_raw: str = Field(required=False)
44+
is_raw: bool = Field(required=False, default=False)
45+
is_cached: bool = Field(required=False, default=True)
46+
provider: str = Field(required=False, default='db')
47+
parser: str = Field(required=False, default='SQLParser', comment="Parser to be used for parsing Query.")
48+
cache_timeout: int = Field(required=True, default=3600)
49+
cache_refresh: int = Field(required=True, default=0)
50+
cache_options: Optional[dict] = Field(required=False, db_type='jsonb', default_factory=dict)
51+
## Program Information:
52+
program_id: int = Field(required=True, default=1)
53+
program_slug: str = Field(required=True, default='default')
54+
# DWH information
55+
dwh: bool = Field(required=True, default=False)
56+
dwh_driver: str = Field(required=False, default=None)
57+
dwh_info: Optional[dict] = Field(required=False, db_type='jsonb')
58+
dwh_scheduler: Optional[dict] = Field(
59+
required=False,
60+
db_type='jsonb',
61+
default_factory=dict
62+
)
63+
# Creation Information:
64+
created_at: datetime = Field(
65+
required=False,
66+
default=datetime.now,
67+
db_default='now()'
68+
)
69+
created_by: int = Field(required=False) # TODO: validation for valid user
70+
updated_at: datetime = Field(
71+
required=False,
72+
default=datetime.now,
73+
encoder=rigth_now
74+
)
75+
updated_by: int = Field(required=False) # TODO: validation for valid user
76+
77+
class Meta:
78+
driver = 'pg'
79+
name = 'queries'
80+
schema = 'public'
81+
strict = True
82+
frozen = False
83+
remove_nulls = True # Auto-remove nullable (with null value) fields
84+
85+
86+
# Sample payload as provided
87+
@pytest.fixture
88+
def sample_data():
89+
return {
90+
'query_slug': 'walmart_mtd_postpaid_to_goal',
91+
'description': 'walmart_mtd_postpaid_to_goal',
92+
'conditions': {'filterdate': 'POSTPAID_DATE', 'store_tier': 'null', 'launch_group': 'null'},
93+
'cond_definition': {'filterdate': 'date', 'store_tier': 'string', 'launch_group': 'string'},
94+
'fields': ["client_id", "client_name"],
95+
'ordering': [],
96+
'h_filtering': False,
97+
'query_raw': 'SELECT {fields}\nFROM walmart.postpaid_metrics({filterdate}, {launch_group}, {store_tier})\n{where_cond}',
98+
'is_raw': False,
99+
'is_cached': True,
100+
'provider': 'db',
101+
'parser': 'pgSQLParser',
102+
'cache_timeout': 900,
103+
'cache_refresh': None,
104+
'cache_options': {},
105+
'program_id': 3,
106+
'program_slug': 'walmart',
107+
'dwh': False,
108+
'dwh_scheduler': {},
109+
'created_at': datetime(2022, 11, 18, 1, 10, 8, 872163),
110+
'updated_at': datetime(2023, 4, 18, 1, 38, 44, 466221)
111+
}
112+
113+
114+
def test_query_model_success(sample_data):
115+
"""Test that a QueryModel is created successfully and that every field has the expected value and type."""
116+
qm = QueryModel(**sample_data)
117+
118+
# Required primary key field
119+
assert qm.query_slug == sample_data['query_slug']
120+
# Description is optional
121+
assert qm.description == sample_data['description']
122+
# When not provided, source is None.
123+
assert qm.source is None
124+
125+
# Defaulted dict fields (via default_factory):
126+
assert isinstance(qm.params, dict)
127+
assert qm.params == {}
128+
assert isinstance(qm.attributes, dict)
129+
assert qm.attributes == {}
130+
assert isinstance(qm.conditions, dict)
131+
assert qm.conditions == sample_data['conditions']
132+
assert isinstance(qm.cond_definition, dict)
133+
assert qm.cond_definition == sample_data['cond_definition']
134+
135+
# List field "fields" must be a list of strings as provided.
136+
assert isinstance(qm.fields, list)
137+
assert qm.fields == sample_data['fields']
138+
139+
# Filtering and ordering
140+
assert isinstance(qm.filtering, dict)
141+
assert qm.filtering == {}
142+
assert isinstance(qm.ordering, list)
143+
assert qm.ordering == sample_data['ordering']
144+
# grouping is optional and not provided so likely None.
145+
assert qm.grouping is None
146+
147+
# Boolean and string fields
148+
assert isinstance(qm.h_filtering, bool)
149+
assert qm.h_filtering == sample_data['h_filtering']
150+
assert qm.query_raw == sample_data['query_raw']
151+
assert qm.is_raw == sample_data['is_raw']
152+
assert qm.is_cached == sample_data['is_cached']
153+
assert qm.provider == sample_data['provider']
154+
assert qm.parser == sample_data['parser']
155+
156+
# Numeric and dict defaults
157+
assert qm.cache_timeout == sample_data['cache_timeout']
158+
# Here cache_refresh is provided as None in sample but default_factory may have set it to 0.
159+
# (Your model code sets default=0 for cache_refresh.)
160+
assert qm.cache_refresh == 0
161+
assert qm.cache_options == sample_data['cache_options']
162+
163+
# Program info
164+
assert qm.program_id == sample_data['program_id']
165+
assert qm.program_slug == sample_data['program_slug']
166+
167+
# DWH info fields
168+
assert qm.dwh == sample_data['dwh']
169+
# dwh_driver and dwh_info are not provided, so should be None.
170+
assert qm.dwh_driver is None
171+
assert qm.dwh_info is None
172+
assert qm.dwh_scheduler == sample_data['dwh_scheduler']
173+
174+
# Creation information: check that created_at and updated_at are datetime objects and match
175+
assert isinstance(qm.created_at, datetime)
176+
assert qm.created_at == sample_data['created_at']
177+
assert isinstance(qm.updated_at, datetime)
178+
assert qm.updated_at == sample_data['updated_at']
179+
180+
# created_by and updated_by are not provided, so should be None.
181+
assert qm.created_by is None
182+
assert qm.updated_by is None
183+
184+
185+
def test_query_model_missing_required_field(sample_data):
186+
"""Test that missing a required field (e.g. 'query_slug') raises a ValidationError."""
187+
payload = sample_data.copy()
188+
payload.pop('query_slug')
189+
with pytest.raises(ValidationError) as excinfo:
190+
QueryModel(**payload)
191+
# Optionally, check that the error payload mentions the missing 'query_slug'
192+
assert 'query_slug' in str(excinfo.value)
193+
194+
195+
def test_query_model_remove_nulls(sample_data):
196+
"""Test that when remove_nulls is enabled, to_dict does not include keys with None values."""
197+
payload = sample_data.copy()
198+
# Intentionally set a couple of optional fields to None
199+
payload['attributes'] = None
200+
payload['params'] = None
201+
qm = QueryModel(**payload)
202+
d = qm.to_dict(remove_nulls=True)
203+
assert 'attributes' not in d
204+
assert 'params' not in d
205+
206+
207+
def test_query_model_schema_generation():
208+
"""Test that the JSON-Schema generated by QueryModel contains expected keys."""
209+
schema_dict = QueryModel.schema(as_dict=True)
210+
assert schema_dict["$schema"] == "https://json-schema.org/draft/2020-12/schema"
211+
# Title may either be the class name or a transformation of it (based on slugify_camelcase)
212+
assert "Query Model" in schema_dict["title"]
213+
# Check that required fields (like 'query_slug') are in the "required" list.
214+
required = schema_dict.get("required", [])
215+
assert "query_slug" in required
216+
# Check that properties are defined and include 'query_slug'
217+
properties = schema_dict.get("properties", {})
218+
assert isinstance(properties, dict)
219+
assert "query_slug" in properties

0 commit comments

Comments
 (0)