diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4216e15a7..051983d9f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,6 +12,7 @@ Changelog 0.16.8 ------ +- Allow `Q` expression to function with `_filter` parameter - Add ``group by`` support - Fixed regression where ``GROUP BY`` class is missing for an aggregate with a specified order. diff --git a/examples/functions.py b/examples/functions.py index 110ad4da0..2d888c528 100644 --- a/examples/functions.py +++ b/examples/functions.py @@ -1,6 +1,7 @@ from tortoise import Tortoise, fields, run_async from tortoise.functions import Coalesce, Count, Length, Lower, Min, Sum, Trim, Upper from tortoise.models import Model +from tortoise.query_utils import Q class Tournament(Model): @@ -57,6 +58,11 @@ async def run(): await event.participants.add(participants[0], participants[1]) print(await Tournament.all().annotate(events_count=Count("events")).filter(events_count__gte=1)) + print( + await Tournament.all() + .annotate(events_count_with_filter=Count("events", _filter=Q(name="New Tournament"))) + .filter(events_count_with_filter__gte=1) + ) print(await Event.filter(id=event.id).first().annotate(lowest_team_id=Min("participants__id"))) diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index 3976e9c2e..2ca96fee0 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -2,6 +2,7 @@ from tortoise.contrib import test from tortoise.exceptions import ConfigurationError from tortoise.functions import Avg, Count, Min, Sum +from tortoise.query_utils import Q class TestAggregation(test.TestCase): @@ -91,6 +92,26 @@ async def test_aggregation_with_distinct(self): self.assertEqual(school_with_distinct_count.events_count, 3) self.assertEqual(school_with_distinct_count.minrelations_count, 2) + async def test_aggregation_with_filter(self): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Event 1", tournament=tournament) + await Event.create(name="Event 2", tournament=tournament) + await Event.create(name="Event 3", tournament=tournament) + + tournament_with_filter = ( + await Tournament.all() + .annotate( + all=Count("events", _filter=Q(name="New Tournament")), + one=Count("events", _filter=Q(events__name="Event 1")), + two=Count("events", _filter=Q(events__name__not="Event 1")), + ) + .first() + ) + + self.assertEqual(tournament_with_filter.all, 3) + self.assertEqual(tournament_with_filter.one, 1) + self.assertEqual(tournament_with_filter.two, 2) + async def test_group_aggregation(self): author = await Author.create(name="Some One") await Book.create(name="First!", author=author, rating=4) diff --git a/tests/test_source_field.py b/tests/test_source_field.py index e9e4e88fc..761606a47 100644 --- a/tests/test_source_field.py +++ b/tests/test_source_field.py @@ -8,6 +8,7 @@ from tortoise.contrib import test from tortoise.expressions import F from tortoise.functions import Coalesce, Count, Length, Lower, Trim, Upper +from tortoise.query_utils import Q class StraightFieldTests(test.TestCase): @@ -181,6 +182,25 @@ async def test_function(self): obj2 = await self.model.get(eyedee=obj1.eyedee) self.assertEqual(obj2.chars, "aaa") + async def test_aggregation_with_filter(self): + obj1 = await self.model.create(chars="aaa") + await self.model.create(chars="bbb", fk=obj1) + await self.model.create(chars="ccc", fk=obj1) + + obj = ( + await self.model.filter(chars="aaa") + .annotate( + all=Count("fkrev", _filter=Q(chars="aaa")), + one=Count("fkrev", _filter=Q(fkrev__chars="bbb")), + no=Count("fkrev", _filter=Q(fkrev__chars="aaa")), + ) + .first() + ) + + self.assertEqual(obj.all, 2) + self.assertEqual(obj.one, 1) + self.assertEqual(obj.no, 0) + async def test_filter_by_aggregation_field_coalesce(self): await self.model.create(chars="aaa", nullable="null") await self.model.create(chars="bbb") diff --git a/tortoise/functions.py b/tortoise/functions.py index d425f1c25..4b8add928 100644 --- a/tortoise/functions.py +++ b/tortoise/functions.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Optional, Type, Union, cast -from pypika import Table, functions +from pypika import Case, Table, functions from pypika.functions import DistinctOptionFunction from pypika.terms import ArithmeticExpression from pypika.terms import Function as BaseFunction @@ -8,6 +8,7 @@ from tortoise.exceptions import ConfigurationError from tortoise.expressions import F from tortoise.fields.relational import BackwardFKRelation, ForeignKeyFieldInstance, RelationalField +from tortoise.query_utils import Q, QueryModifier if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model @@ -100,7 +101,7 @@ def _resolve_field_for_model( if func: field = func(self.field_object, field) - return {"joins": joins, "field": self._get_function_field(field, *default_values)} + return {"joins": joins, "field": field} def resolve(self, model: "Type[Model]", table: Table) -> dict: """ @@ -114,6 +115,7 @@ def resolve(self, model: "Type[Model]", table: Table) -> dict: if isinstance(self.field, str): function = self._resolve_field_for_model(model, table, self.field, *self.default_values) + function["field"] = self._get_function_field(function["field"], *self.default_values) return function else: field, field_object = F.resolver_arithmetic_expression(model, self.field) @@ -134,10 +136,15 @@ class Aggregate(Function): database_func = DistinctOptionFunction def __init__( - self, field: Union[str, F, ArithmeticExpression], *default_values: Any, distinct=False + self, + field: Union[str, F, ArithmeticExpression], + *default_values: Any, + distinct=False, + _filter: Optional[Q] = None, ) -> None: super().__init__(field, *default_values) self.distinct = distinct + self.filter = _filter def _get_function_field( self, field: "Union[ArithmeticExpression, Field, str]", *default_values @@ -147,6 +154,18 @@ def _get_function_field( else: return self.database_func(field, *default_values) + def _resolve_field_for_model( + self, model: "Type[Model]", table: Table, field: str, *default_values: Any + ) -> dict: + ret = super()._resolve_field_for_model(model, table, field, default_values) + if self.filter: + modifier = QueryModifier() + modifier &= self.filter.resolve(model, {}, {}, model._meta.basetable) + where_criterion, joins, having_criterion = modifier.get_query_modifiers() + ret["field"] = Case().when(where_criterion, ret["field"]).else_(None) + + return ret + ############################################################################## # Standard functions