1
- from typing import List , Union , cast
1
+ from typing import Dict , List , Optional , Union , cast
2
2
3
3
from forestadmin .agent_toolkit .utils .context import User
4
4
from forestadmin .datasource_toolkit .decorators .collection_decorator import CollectionDecorator
5
5
from forestadmin .datasource_toolkit .interfaces .fields import ManyToOne , is_many_to_one
6
+ from forestadmin .datasource_toolkit .interfaces .query .aggregation import AggregateResult , Aggregation
6
7
from forestadmin .datasource_toolkit .interfaces .query .condition_tree .nodes .leaf import ConditionTreeLeaf
7
8
from forestadmin .datasource_toolkit .interfaces .query .filter .paginated import PaginatedFilter
8
9
from forestadmin .datasource_toolkit .interfaces .query .filter .unpaginated import Filter
@@ -17,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project
17
18
refined_filter = cast (PaginatedFilter , await self ._refine_filter (caller , filter_ ))
18
19
ret = await self .child_collection .list (caller , refined_filter , simplified_projection )
19
20
20
- return self ._apply_joins_on_records (projection , simplified_projection , ret )
21
+ return self ._apply_joins_on_simplified_records (projection , simplified_projection , ret )
21
22
22
23
async def _refine_filter (
23
24
self , caller : User , _filter : Union [Filter , PaginatedFilter , None ]
@@ -28,18 +29,39 @@ async def _refine_filter(
28
29
_filter .condition_tree = _filter .condition_tree .replace (
29
30
lambda leaf : (
30
31
ConditionTreeLeaf (
31
- self ._get_fk_field_for_projection (leaf .field ),
32
+ self ._get_fk_field_for_many_to_one_projection (leaf .field ),
32
33
leaf .operator ,
33
34
leaf .value ,
34
35
)
35
- if self ._is_useless_join (leaf .field .split (":" )[0 ], _filter .condition_tree .projection )
36
+ if self ._is_useless_join_for_projection (leaf .field .split (":" )[0 ], _filter .condition_tree .projection )
36
37
else leaf
37
38
)
38
39
)
39
40
40
41
return _filter
41
42
42
- def _is_useless_join (self , relation : str , projection : Projection ) -> bool :
43
+ async def aggregate (
44
+ self , caller : User , filter_ : Union [Filter , None ], aggregation : Aggregation , limit : Optional [int ] = None
45
+ ) -> List [AggregateResult ]:
46
+ replaced = {} # new_name -> old_name; for a simpler reconciliation
47
+
48
+ def replacer (field_name : str ) -> str :
49
+ if self ._is_useless_join_for_projection (field_name .split (":" )[0 ], aggregation .projection ):
50
+ new_field_name = self ._get_fk_field_for_many_to_one_projection (field_name )
51
+ replaced [new_field_name ] = field_name
52
+ return new_field_name
53
+ return field_name
54
+
55
+ new_aggregation = aggregation .replace_fields (replacer )
56
+
57
+ aggregate_results = await self .child_collection .aggregate (
58
+ caller , cast (Filter , await self ._refine_filter (caller , filter_ )), new_aggregation , limit
59
+ )
60
+ if aggregation == new_aggregation :
61
+ return aggregate_results
62
+ return self ._replace_fields_in_aggregate_group (aggregate_results , replaced )
63
+
64
+ def _is_useless_join_for_projection (self , relation : str , projection : Projection ) -> bool :
43
65
relation_schema = self .schema ["fields" ][relation ]
44
66
sub_projections = projection .relations [relation ]
45
67
@@ -49,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool:
49
71
and sub_projections [0 ] == relation_schema ["foreign_key_target" ]
50
72
)
51
73
52
- def _get_fk_field_for_projection (self , projection : str ) -> str :
74
+ def _get_fk_field_for_many_to_one_projection (self , projection : str ) -> str :
53
75
relation_name = projection .split (":" )[0 ]
54
76
relation_schema = cast (ManyToOne , self .schema ["fields" ][relation_name ])
55
77
@@ -58,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str:
58
80
def _get_projection_without_useless_joins (self , projection : Projection ) -> Projection :
59
81
returned_projection = Projection (* projection )
60
82
for relation , relation_projections in projection .relations .items ():
61
- if self ._is_useless_join (relation , projection ):
83
+ if self ._is_useless_join_for_projection (relation , projection ):
62
84
# remove foreign key target from projection
63
85
returned_projection .remove (f"{ relation } :{ relation_projections [0 ]} " )
64
86
65
87
# add foreign keys to projection
66
- fk_field = self ._get_fk_field_for_projection ( relation )
88
+ fk_field = self ._get_fk_field_for_many_to_one_projection ( f" { relation } : { relation_projections [ 0 ] } " )
67
89
if fk_field not in returned_projection :
68
90
returned_projection .append (fk_field )
69
91
70
92
return returned_projection
71
93
72
- def _apply_joins_on_records (
94
+ def _apply_joins_on_simplified_records (
73
95
self , initial_projection : Projection , requested_projection : Projection , records : List [RecordsDataAlias ]
74
96
) -> List [RecordsDataAlias ]:
75
97
if requested_projection == initial_projection :
@@ -84,11 +106,27 @@ def _apply_joins_on_records(
84
106
relation_schema = self .schema ["fields" ][relation ]
85
107
86
108
if is_many_to_one (relation_schema ):
87
- fk_value = record [self ._get_fk_field_for_projection (f"{ relation } :{ relation_projections [0 ]} " )]
109
+ fk_value = record [
110
+ self ._get_fk_field_for_many_to_one_projection (f"{ relation } :{ relation_projections [0 ]} " )
111
+ ]
88
112
record [relation ] = {relation_projections [0 ]: fk_value } if fk_value else None
89
113
90
114
# remove foreign keys
91
115
for projection in projections_to_rm :
92
116
del record [projection ]
93
117
94
118
return records
119
+
120
+ def _replace_fields_in_aggregate_group (
121
+ self , aggregate_results : List [AggregateResult ], field_to_replace : Dict [str , str ]
122
+ ) -> List [AggregateResult ]:
123
+ for aggregate_result in aggregate_results :
124
+ group = {}
125
+ for field , value in aggregate_result ["group" ].items ():
126
+ if field in field_to_replace :
127
+ group [field_to_replace [field ]] = value
128
+ else :
129
+ group [field ] = value
130
+ aggregate_result ["group" ] = group
131
+
132
+ return aggregate_results
0 commit comments