Skip to content

Commit

Permalink
Allow to filter on an aggregate function (#362)
Browse files Browse the repository at this point in the history
  • Loading branch information
hqsz authored Apr 22, 2020
1 parent fc21bed commit 803115d
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Changelog

0.16.8
------
- Allow `Q` expression to function with `_filter` parameter
- Add ``group by`` support
- Fixed regression where ``GROUP BY`` class is missing for an aggregate with a specified order.

Expand Down
6 changes: 6 additions & 0 deletions examples/functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from tortoise import Tortoise, fields, run_async
from tortoise.functions import Coalesce, Count, Length, Lower, Min, Sum, Trim, Upper
from tortoise.models import Model
from tortoise.query_utils import Q


class Tournament(Model):
Expand Down Expand Up @@ -57,6 +58,11 @@ async def run():
await event.participants.add(participants[0], participants[1])

print(await Tournament.all().annotate(events_count=Count("events")).filter(events_count__gte=1))
print(
await Tournament.all()
.annotate(events_count_with_filter=Count("events", _filter=Q(name="New Tournament")))
.filter(events_count_with_filter__gte=1)
)

print(await Event.filter(id=event.id).first().annotate(lowest_team_id=Min("participants__id")))

Expand Down
21 changes: 21 additions & 0 deletions tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError
from tortoise.functions import Avg, Count, Min, Sum
from tortoise.query_utils import Q


class TestAggregation(test.TestCase):
Expand Down Expand Up @@ -91,6 +92,26 @@ async def test_aggregation_with_distinct(self):
self.assertEqual(school_with_distinct_count.events_count, 3)
self.assertEqual(school_with_distinct_count.minrelations_count, 2)

async def test_aggregation_with_filter(self):
tournament = await Tournament.create(name="New Tournament")
await Event.create(name="Event 1", tournament=tournament)
await Event.create(name="Event 2", tournament=tournament)
await Event.create(name="Event 3", tournament=tournament)

tournament_with_filter = (
await Tournament.all()
.annotate(
all=Count("events", _filter=Q(name="New Tournament")),
one=Count("events", _filter=Q(events__name="Event 1")),
two=Count("events", _filter=Q(events__name__not="Event 1")),
)
.first()
)

self.assertEqual(tournament_with_filter.all, 3)
self.assertEqual(tournament_with_filter.one, 1)
self.assertEqual(tournament_with_filter.two, 2)

async def test_group_aggregation(self):
author = await Author.create(name="Some One")
await Book.create(name="First!", author=author, rating=4)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_source_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tortoise.contrib import test
from tortoise.expressions import F
from tortoise.functions import Coalesce, Count, Length, Lower, Trim, Upper
from tortoise.query_utils import Q


class StraightFieldTests(test.TestCase):
Expand Down Expand Up @@ -181,6 +182,25 @@ async def test_function(self):
obj2 = await self.model.get(eyedee=obj1.eyedee)
self.assertEqual(obj2.chars, "aaa")

async def test_aggregation_with_filter(self):
obj1 = await self.model.create(chars="aaa")
await self.model.create(chars="bbb", fk=obj1)
await self.model.create(chars="ccc", fk=obj1)

obj = (
await self.model.filter(chars="aaa")
.annotate(
all=Count("fkrev", _filter=Q(chars="aaa")),
one=Count("fkrev", _filter=Q(fkrev__chars="bbb")),
no=Count("fkrev", _filter=Q(fkrev__chars="aaa")),
)
.first()
)

self.assertEqual(obj.all, 2)
self.assertEqual(obj.one, 1)
self.assertEqual(obj.no, 0)

async def test_filter_by_aggregation_field_coalesce(self):
await self.model.create(chars="aaa", nullable="null")
await self.model.create(chars="bbb")
Expand Down
25 changes: 22 additions & 3 deletions tortoise/functions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import TYPE_CHECKING, Any, Optional, Type, Union, cast

from pypika import Table, functions
from pypika import Case, Table, functions
from pypika.functions import DistinctOptionFunction
from pypika.terms import ArithmeticExpression
from pypika.terms import Function as BaseFunction

from tortoise.exceptions import ConfigurationError
from tortoise.expressions import F
from tortoise.fields.relational import BackwardFKRelation, ForeignKeyFieldInstance, RelationalField
from tortoise.query_utils import Q, QueryModifier

if TYPE_CHECKING: # pragma: nocoverage
from tortoise.models import Model
Expand Down Expand Up @@ -100,7 +101,7 @@ def _resolve_field_for_model(
if func:
field = func(self.field_object, field)

return {"joins": joins, "field": self._get_function_field(field, *default_values)}
return {"joins": joins, "field": field}

def resolve(self, model: "Type[Model]", table: Table) -> dict:
"""
Expand All @@ -114,6 +115,7 @@ def resolve(self, model: "Type[Model]", table: Table) -> dict:

if isinstance(self.field, str):
function = self._resolve_field_for_model(model, table, self.field, *self.default_values)
function["field"] = self._get_function_field(function["field"], *self.default_values)
return function
else:
field, field_object = F.resolver_arithmetic_expression(model, self.field)
Expand All @@ -134,10 +136,15 @@ class Aggregate(Function):
database_func = DistinctOptionFunction

def __init__(
self, field: Union[str, F, ArithmeticExpression], *default_values: Any, distinct=False
self,
field: Union[str, F, ArithmeticExpression],
*default_values: Any,
distinct=False,
_filter: Optional[Q] = None,
) -> None:
super().__init__(field, *default_values)
self.distinct = distinct
self.filter = _filter

def _get_function_field(
self, field: "Union[ArithmeticExpression, Field, str]", *default_values
Expand All @@ -147,6 +154,18 @@ def _get_function_field(
else:
return self.database_func(field, *default_values)

def _resolve_field_for_model(
self, model: "Type[Model]", table: Table, field: str, *default_values: Any
) -> dict:
ret = super()._resolve_field_for_model(model, table, field, default_values)
if self.filter:
modifier = QueryModifier()
modifier &= self.filter.resolve(model, {}, {}, model._meta.basetable)
where_criterion, joins, having_criterion = modifier.get_query_modifiers()
ret["field"] = Case().when(where_criterion, ret["field"]).else_(None)

return ret


##############################################################################
# Standard functions
Expand Down

0 comments on commit 803115d

Please sign in to comment.