Skip to content

Commit 88bd56f

Browse files
committed
perf: add aggregate to lazy join (#298)
1 parent 2989e54 commit 88bd56f

File tree

2 files changed

+183
-11
lines changed

2 files changed

+183
-11
lines changed

src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py

+48-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import List, Union, cast
1+
from typing import Dict, List, Optional, Union, cast
22

33
from forestadmin.agent_toolkit.utils.context import User
44
from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator
55
from forestadmin.datasource_toolkit.interfaces.fields import ManyToOne, is_many_to_one
6+
from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation
67
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf
78
from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter
89
from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
@@ -17,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project
1718
refined_filter = cast(PaginatedFilter, await self._refine_filter(caller, filter_))
1819
ret = await self.child_collection.list(caller, refined_filter, simplified_projection)
1920

20-
return self._apply_joins_on_records(projection, simplified_projection, ret)
21+
return self._apply_joins_on_simplified_records(projection, simplified_projection, ret)
2122

2223
async def _refine_filter(
2324
self, caller: User, _filter: Union[Filter, PaginatedFilter, None]
@@ -28,18 +29,39 @@ async def _refine_filter(
2829
_filter.condition_tree = _filter.condition_tree.replace(
2930
lambda leaf: (
3031
ConditionTreeLeaf(
31-
self._get_fk_field_for_projection(leaf.field),
32+
self._get_fk_field_for_many_to_one_projection(leaf.field),
3233
leaf.operator,
3334
leaf.value,
3435
)
35-
if self._is_useless_join(leaf.field.split(":")[0], _filter.condition_tree.projection)
36+
if self._is_useless_join_for_projection(leaf.field.split(":")[0], _filter.condition_tree.projection)
3637
else leaf
3738
)
3839
)
3940

4041
return _filter
4142

42-
def _is_useless_join(self, relation: str, projection: Projection) -> bool:
43+
async def aggregate(
44+
self, caller: User, filter_: Union[Filter, None], aggregation: Aggregation, limit: Optional[int] = None
45+
) -> List[AggregateResult]:
46+
replaced = {} # new_name -> old_name; for a simpler reconciliation
47+
48+
def replacer(field_name: str) -> str:
49+
if self._is_useless_join_for_projection(field_name.split(":")[0], aggregation.projection):
50+
new_field_name = self._get_fk_field_for_many_to_one_projection(field_name)
51+
replaced[new_field_name] = field_name
52+
return new_field_name
53+
return field_name
54+
55+
new_aggregation = aggregation.replace_fields(replacer)
56+
57+
aggregate_results = await self.child_collection.aggregate(
58+
caller, cast(Filter, await self._refine_filter(caller, filter_)), new_aggregation, limit
59+
)
60+
if aggregation == new_aggregation:
61+
return aggregate_results
62+
return self._replace_fields_in_aggregate_group(aggregate_results, replaced)
63+
64+
def _is_useless_join_for_projection(self, relation: str, projection: Projection) -> bool:
4365
relation_schema = self.schema["fields"][relation]
4466
sub_projections = projection.relations[relation]
4567

@@ -49,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool:
4971
and sub_projections[0] == relation_schema["foreign_key_target"]
5072
)
5173

52-
def _get_fk_field_for_projection(self, projection: str) -> str:
74+
def _get_fk_field_for_many_to_one_projection(self, projection: str) -> str:
5375
relation_name = projection.split(":")[0]
5476
relation_schema = cast(ManyToOne, self.schema["fields"][relation_name])
5577

@@ -58,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str:
5880
def _get_projection_without_useless_joins(self, projection: Projection) -> Projection:
5981
returned_projection = Projection(*projection)
6082
for relation, relation_projections in projection.relations.items():
61-
if self._is_useless_join(relation, projection):
83+
if self._is_useless_join_for_projection(relation, projection):
6284
# remove foreign key target from projection
6385
returned_projection.remove(f"{relation}:{relation_projections[0]}")
6486

6587
# add foreign keys to projection
66-
fk_field = self._get_fk_field_for_projection(relation)
88+
fk_field = self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
6789
if fk_field not in returned_projection:
6890
returned_projection.append(fk_field)
6991

7092
return returned_projection
7193

72-
def _apply_joins_on_records(
94+
def _apply_joins_on_simplified_records(
7395
self, initial_projection: Projection, requested_projection: Projection, records: List[RecordsDataAlias]
7496
) -> List[RecordsDataAlias]:
7597
if requested_projection == initial_projection:
@@ -84,11 +106,27 @@ def _apply_joins_on_records(
84106
relation_schema = self.schema["fields"][relation]
85107

86108
if is_many_to_one(relation_schema):
87-
fk_value = record[self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")]
109+
fk_value = record[
110+
self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
111+
]
88112
record[relation] = {relation_projections[0]: fk_value} if fk_value else None
89113

90114
# remove foreign keys
91115
for projection in projections_to_rm:
92116
del record[projection]
93117

94118
return records
119+
120+
def _replace_fields_in_aggregate_group(
121+
self, aggregate_results: List[AggregateResult], field_to_replace: Dict[str, str]
122+
) -> List[AggregateResult]:
123+
for aggregate_result in aggregate_results:
124+
group = {}
125+
for field, value in aggregate_result["group"].items():
126+
if field in field_to_replace:
127+
group[field_to_replace[field]] = value
128+
else:
129+
group[field] = value
130+
aggregate_result["group"] = group
131+
132+
return aggregate_results

src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py

+135-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
from forestadmin.datasource_toolkit.decorators.datasource_decorator import DatasourceDecorator
1515
from forestadmin.datasource_toolkit.decorators.lazy_join.collection import LazyJoinCollectionDecorator
1616
from forestadmin.datasource_toolkit.interfaces.fields import Column, FieldType, ManyToOne, OneToMany, PrimitiveType
17+
from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation
1718
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf
1819
from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter
20+
from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
1921
from forestadmin.datasource_toolkit.interfaces.query.projections import Projection
2022

2123

@@ -53,6 +55,10 @@ def setUpClass(cls) -> None:
5355
column_type=PrimitiveType.STRING,
5456
type=FieldType.COLUMN,
5557
),
58+
"price": Column(
59+
column_type=PrimitiveType.NUMBER,
60+
type=FieldType.COLUMN,
61+
),
5662
}
5763
)
5864
cls.collection_person = Collection("Person", cls.datasource)
@@ -226,7 +232,7 @@ def test_should_disable_join_on_projection_but_not_in_condition_tree(self):
226232
response,
227233
)
228234

229-
def test_should_correctly_handle_null_relations(self):
235+
def test_should_correctly_handle_null_relations_on_list(self):
230236
with patch.object(
231237
self.collection_book,
232238
"list",
@@ -252,3 +258,131 @@ def test_should_correctly_handle_null_relations(self):
252258
],
253259
result,
254260
)
261+
262+
def test_should_not_join_on_aggregate_when_group_by_foreign_pk(self):
263+
with patch.object(
264+
self.collection_book,
265+
"aggregate",
266+
new_callable=AsyncMock,
267+
return_value=[
268+
{"value": 1824.11, "group": {"author_id": 2}},
269+
{"value": 824.11, "group": {"author_id": 3}},
270+
],
271+
) as mock_aggregate:
272+
result = self.loop.run_until_complete(
273+
self.decorated_book_collection.aggregate(
274+
self.mocked_caller,
275+
Filter({}),
276+
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
277+
None,
278+
)
279+
)
280+
self.assertEqual(
281+
result,
282+
[
283+
{"value": 1824.11, "group": {"author:id": 2}},
284+
{"value": 824.11, "group": {"author:id": 3}},
285+
],
286+
)
287+
mock_aggregate.assert_awaited_once_with(
288+
self.mocked_caller,
289+
Filter({}),
290+
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
291+
None,
292+
)
293+
294+
def test_should_join_on_aggregate_when_group_by_foreign_field(self):
295+
with patch.object(
296+
self.collection_book,
297+
"aggregate",
298+
new_callable=AsyncMock,
299+
return_value=[
300+
{"value": 1824.11, "group": {"author:first_name": "Isaac"}},
301+
{"value": 824.11, "group": {"author:first_name": "JK"}},
302+
],
303+
) as mock_aggregate:
304+
result = self.loop.run_until_complete(
305+
self.decorated_book_collection.aggregate(
306+
self.mocked_caller,
307+
Filter({}),
308+
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}),
309+
None,
310+
)
311+
)
312+
self.assertEqual(
313+
result,
314+
[
315+
{"value": 1824.11, "group": {"author:first_name": "Isaac"}},
316+
{"value": 824.11, "group": {"author:first_name": "JK"}},
317+
],
318+
)
319+
mock_aggregate.assert_awaited_once_with(
320+
self.mocked_caller,
321+
Filter({}),
322+
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}),
323+
None,
324+
)
325+
326+
def test_should_not_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_pk(self):
327+
with patch.object(
328+
self.collection_book,
329+
"aggregate",
330+
new_callable=AsyncMock,
331+
return_value=[
332+
{"value": 1824.11, "group": {"author_id": 2}},
333+
{"value": 824.11, "group": {"author_id": 3}},
334+
],
335+
) as mock_aggregate:
336+
result = self.loop.run_until_complete(
337+
self.decorated_book_collection.aggregate(
338+
self.mocked_caller,
339+
Filter({"condition_tree": ConditionTreeLeaf("author:id", "not_equal", 50)}),
340+
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
341+
None,
342+
)
343+
)
344+
self.assertEqual(
345+
result,
346+
[
347+
{"value": 1824.11, "group": {"author:id": 2}},
348+
{"value": 824.11, "group": {"author:id": 3}},
349+
],
350+
)
351+
mock_aggregate.assert_awaited_once_with(
352+
self.mocked_caller,
353+
Filter({"condition_tree": ConditionTreeLeaf("author_id", "not_equal", 50)}),
354+
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
355+
None,
356+
)
357+
358+
def test_should_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_field(self):
359+
with patch.object(
360+
self.collection_book,
361+
"aggregate",
362+
new_callable=AsyncMock,
363+
return_value=[
364+
{"value": 1824.11, "group": {"author_id": 2}},
365+
{"value": 824.11, "group": {"author_id": 3}},
366+
],
367+
) as mock_aggregate:
368+
result = self.loop.run_until_complete(
369+
self.decorated_book_collection.aggregate(
370+
self.mocked_caller,
371+
Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}),
372+
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
373+
None,
374+
)
375+
)
376+
self.assertEqual(
377+
result,
378+
[
379+
{"value": 1824.11, "group": {"author:id": 2}},
380+
{"value": 824.11, "group": {"author:id": 3}},
381+
],
382+
)
383+
mock_aggregate.assert_awaited_once_with(
384+
self.mocked_caller,
385+
Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}),
386+
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
387+
None,
388+
)

0 commit comments

Comments
 (0)