Skip to content

Commit

Permalink
Fix psycopg casting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
henadzit committed Feb 12, 2025
1 parent ce79190 commit f197ea4
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 47 deletions.
23 changes: 23 additions & 0 deletions tests/fields/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ async def test_values_list(self):
values = await testmodels.ArrayFields.get(id=obj0.id).values_list("array", flat=True)
self.assertEqual(values, [0])

async def test_eq_filter(self):
o1 = await testmodels.ArrayFields.create(array=[1, 2, 3])
o2 = await testmodels.ArrayFields.create(array=[1, 2])

found = await testmodels.ArrayFields.filter(array=[1, 2, 3]).first()
self.assertEqual(found, o1)

found = await testmodels.ArrayFields.filter(array=[1, 2]).first()
self.assertEqual(found, o2)

async def test_not_filter(self):
await testmodels.ArrayFields.create(array=[1, 2, 3])
o2 = await testmodels.ArrayFields.create(array=[1, 2])

found = await testmodels.ArrayFields.filter(array__not=[1, 2, 3]).first()
self.assertEqual(found, o2)

async def test_contains_ints(self):
await testmodels.ArrayFields.create(array=[1, 2, 3])
await testmodels.ArrayFields.create(array=[2, 3])
Expand All @@ -58,6 +75,12 @@ async def test_contains_ints(self):
)
self.assertEqual(list(found), [])

async def test_contains_smallints(self):
o1 = await testmodels.ArrayFields.create(array=[], array_smallint=[1, 2, 3])

found = await testmodels.ArrayFields.filter(array_smallint__contains=[2]).first()
self.assertEqual(found, o1)

async def test_contains_strs(self):
await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[])

Expand Down
1 change: 1 addition & 0 deletions tests/testmodels_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ class ArrayFields(Model):
array = ArrayField()
array_null = ArrayField(null=True)
array_str = ArrayField(element_type="varchar(1)", null=True)
array_smallint = ArrayField(element_type="smallint", null=True)
8 changes: 4 additions & 4 deletions tortoise/backends/base_postgres/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor
from tortoise.contrib.postgres.array_functions import (
postgres_array_contains,
postgres_array_contained_by,
postgres_array_overlap,
postgres_array_contains,
postgres_array_length,
postgres_array_overlap,
)
from tortoise.contrib.postgres.json_functions import (
postgres_json_contained_by,
Expand All @@ -24,16 +24,16 @@
)
from tortoise.contrib.postgres.search import SearchCriterion
from tortoise.filters import (
array_contains,
array_contained_by,
array_contains,
array_length,
array_overlap,
insensitive_posix_regex,
json_contained_by,
json_contains,
json_filter,
posix_regex,
search,
array_length,
)


Expand Down
22 changes: 11 additions & 11 deletions tortoise/contrib/postgres/array_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from enum import Enum
from typing import Any, Sequence, Union

from pypika_tortoise.terms import Array, BasicCriterion, Criterion, Term
from pypika_tortoise.functions import Function
from pypika_tortoise.terms import BasicCriterion, Criterion, Function, Term


class PostgresArrayOperators(str, Enum):
Expand All @@ -11,19 +9,21 @@ class PostgresArrayOperators(str, Enum):
OVERLAP = "&&"


def postgres_array_contains(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion:
if not isinstance(value, Sequence) or isinstance(value, (str, bytes)):
value = (value,)
# The value in the functions below is casted to the exact type of the field with value_encoder
# to avoid issues with psycopg that tries to use the smallest possible type which can lead to errors,
# e.g. {1,2} will be casted to smallint[] instead of integer[].

return BasicCriterion(PostgresArrayOperators.CONTAINS, field, Array(*value))

def postgres_array_contains(field: Term, value: Term) -> Criterion:
return BasicCriterion(PostgresArrayOperators.CONTAINS, field, value)

def postgres_array_contained_by(field: Term, value: Sequence[Any]) -> Criterion:
return BasicCriterion(PostgresArrayOperators.CONTAINED_BY, field, Array(*value))

def postgres_array_contained_by(field: Term, value: Term) -> Criterion:
return BasicCriterion(PostgresArrayOperators.CONTAINED_BY, field, value)

def postgres_array_overlap(field: Term, value: Sequence[Any]) -> Criterion:
return BasicCriterion(PostgresArrayOperators.OVERLAP, field, Array(*value))

def postgres_array_overlap(field: Term, value: Term) -> Criterion:
return BasicCriterion(PostgresArrayOperators.OVERLAP, field, value)


def postgres_array_length(field: Term, value: int) -> Criterion:
Expand Down
40 changes: 18 additions & 22 deletions tortoise/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,33 +355,29 @@ def _process_filter_kwarg(
join = None

if value is None and f"{key}__isnull" in model._meta.filters:
param = model._meta.get_filter(f"{key}__isnull")
filter = model._meta.get_filter(f"{key}__isnull")
value = True
else:
param = model._meta.get_filter(key)
filter = model._meta.get_filter(key)

pk_db_field = model._meta.db_pk_column
if param.get("table"):
if "table" in filter:
# join the table
join = (
param["table"],
table[pk_db_field] == param["table"][param["backward_key"]],
filter["table"],
table[model._meta.db_pk_column] == filter["table"][filter["backward_key"]],
)
if param.get("value_encoder"):
value = param["value_encoder"](value, model)
op = param["operator"]
criterion = op(param["table"][param["field"]], value)
else:
if isinstance(value, Term):
encoded_value = value
else:
field_object = model._meta.fields_map[param["field"]]
encoded_value = (
param["value_encoder"](value, model, field_object)
if param.get("value_encoder")
else field_object.to_db_value(value, model)
)
op = param["operator"]
criterion = op(table[param["source_field"]], encoded_value)
if "value_encoder" in filter:
value = filter["value_encoder"](value, model)
table = filter["table"]
elif not isinstance(value, Term):
field_object = model._meta.fields_map[filter["field"]]
value = (
filter["value_encoder"](value, model, field_object)
if "value_encoder" in filter
else field_object.to_db_value(value, model)
)
op = filter["operator"]
criterion = op(table[filter.get("source_field", filter["field"])], value)
return criterion, join

def _resolve_regular_kwarg(
Expand Down
7 changes: 7 additions & 0 deletions tortoise/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,13 @@ def _get_dialects(self) -> dict[str, dict]:

return ret

def get_db_field_type(self) -> str:
"""
Returns the DB field type for this field for the current dialect.
"""
dialect = self.model._meta.db.capabilities.dialect
return self.get_for_dialect(dialect, "SQL_TYPE")

def get_db_field_types(self) -> Optional[dict[str, str]]:
"""
Returns the DB types for this field.
Expand Down
34 changes: 24 additions & 10 deletions tortoise/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pypika_tortoise.enums import DatePart, Matching, SqlTypes
from pypika_tortoise.functions import Cast, Extract, Upper
from pypika_tortoise.terms import (
Array,
BasicCriterion,
Criterion,
Equality,
Expand Down Expand Up @@ -81,6 +82,13 @@ def json_encoder(value: Any, instance: "Model", field: Field) -> dict:
return value


def array_encoder(value: Union[Any, Sequence[Any]], instance: "Model", field: Field) -> Any:
# Casting to the exact type of the field to avoid issues with psycopg that tries
# to use the smallest possible type which can lead to errors,
# e.g. {1,2} will be casted to smallint[] instead of integer[].
return Cast(Array(*value), field.get_db_field_type())


##############################################################################
# Operators
# Should be type: (field: Term, value: Any) -> Criterion:
Expand Down Expand Up @@ -343,42 +351,41 @@ def get_backward_fk_filters(


def get_json_filter(field_name: str, source_field: str) -> dict[str, FilterInfoDict]:
actual_field_name = field_name
return {
field_name: {
"field": actual_field_name,
"field": field_name,
"source_field": source_field,
"operator": operator.eq,
},
f"{field_name}__not": {
"field": actual_field_name,
"field": field_name,
"source_field": source_field,
"operator": not_equal,
},
f"{field_name}__isnull": {
"field": actual_field_name,
"field": field_name,
"source_field": source_field,
"operator": is_null,
"value_encoder": bool_encoder,
},
f"{field_name}__not_isnull": {
"field": actual_field_name,
"field": field_name,
"source_field": source_field,
"operator": not_null,
"value_encoder": bool_encoder,
},
f"{field_name}__contains": {
"field": actual_field_name,
"field": field_name,
"source_field": source_field,
"operator": json_contains,
},
f"{field_name}__contained_by": {
"field": actual_field_name,
"field": field_name,
"source_field": source_field,
"operator": json_contained_by,
},
f"{field_name}__filter": {
"field": actual_field_name,
"field": field_name,
"source_field": source_field,
"operator": json_filter,
"value_encoder": json_encoder,
Expand All @@ -399,17 +406,21 @@ def get_json_filter_operator(
return key_parts, filter_value, operator_


def get_array_filter(field_name: str, source_field: str) -> dict[str, FilterInfoDict]:
def get_array_filter(
field_name: str, source_field: str, field: ArrayField
) -> dict[str, FilterInfoDict]:
return {
field_name: {
"field": field_name,
"source_field": source_field,
"operator": operator.eq,
"value_encoder": array_encoder,
},
f"{field_name}__not": {
"field": field_name,
"source_field": source_field,
"operator": not_equal,
"value_encoder": array_encoder,
},
f"{field_name}__isnull": {
"field": field_name,
Expand All @@ -427,16 +438,19 @@ def get_array_filter(field_name: str, source_field: str) -> dict[str, FilterInfo
"field": field_name,
"source_field": source_field,
"operator": array_contains,
"value_encoder": array_encoder,
},
f"{field_name}__contained_by": {
"field": field_name,
"source_field": source_field,
"operator": array_contained_by,
"value_encoder": array_encoder,
},
f"{field_name}__overlap": {
"field": field_name,
"source_field": source_field,
"operator": array_overlap,
"value_encoder": array_encoder,
},
f"{field_name}__len": {
"field": field_name,
Expand All @@ -458,7 +472,7 @@ def get_filters_for_field(
if isinstance(field, JSONField):
return get_json_filter(field_name, source_field)
if isinstance(field, ArrayField):
return get_array_filter(field_name, source_field)
return get_array_filter(field_name, source_field, field)

actual_field_name = field_name
if field_name == "pk" and field:
Expand Down

0 comments on commit f197ea4

Please sign in to comment.