diff --git a/.github/actions/coverage/action.yml b/.github/actions/coverage/action.yml index 9869110a..bfcb73a0 100644 --- a/.github/actions/coverage/action.yml +++ b/.github/actions/coverage/action.yml @@ -15,7 +15,6 @@ runs: steps: - name: Install poetry shell: bash - working-directory: ./src/ run: pipx install poetry - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 @@ -30,17 +29,16 @@ runs: uses: actions/download-artifact@v4 with: path: coverage + merge-multiple: true - name: Combine reports shell: bash - working-directory: ./src/ - run: poetry -C datasource_toolkit run coverage combine ../coverage/code-coverage-report.*/.coverage + run: pwd && ls -a ./coverage/code-coverage-report.agent_toolkit.3.10/.coverage && poetry -C ./src/datasource_toolkit run coverage combine ./coverage/code-coverage-report.*/.coverage - name: Send coverage uses: paambaati/codeclimate-action@v2.7.4 env: CC_TEST_REPORTER_ID: ${{ inputs.CC_TEST_REPORTER_ID }} with: - workingDirectory: ./src/ - coverageCommand: poetry -C datasource_toolkit run coverage xml + coverageCommand: poetry -C ./src/datasource_toolkit run coverage xml # debug # - name: Archive code coverage final results diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py index 42612c2b..59d956df 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py @@ -1,8 +1,9 @@ -from typing import List, Union, cast +from typing import Dict, List, Optional, Union, cast from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator from forestadmin.datasource_toolkit.interfaces.fields import ManyToOne, is_many_to_one +from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter @@ -17,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project refined_filter = cast(PaginatedFilter, await self._refine_filter(caller, filter_)) ret = await self.child_collection.list(caller, refined_filter, simplified_projection) - return self._apply_joins_on_records(projection, simplified_projection, ret) + return self._apply_joins_on_simplified_records(projection, simplified_projection, ret) async def _refine_filter( self, caller: User, _filter: Union[Filter, PaginatedFilter, None] @@ -28,18 +29,39 @@ async def _refine_filter( _filter.condition_tree = _filter.condition_tree.replace( lambda leaf: ( ConditionTreeLeaf( - self._get_fk_field_for_projection(leaf.field), + self._get_fk_field_for_many_to_one_projection(leaf.field), leaf.operator, leaf.value, ) - if self._is_useless_join(leaf.field.split(":")[0], _filter.condition_tree.projection) + if self._is_useless_join_for_projection(leaf.field.split(":")[0], _filter.condition_tree.projection) else leaf ) ) return _filter - def _is_useless_join(self, relation: str, projection: Projection) -> bool: + async def aggregate( + self, caller: User, filter_: Union[Filter, None], aggregation: Aggregation, limit: Optional[int] = None + ) -> List[AggregateResult]: + replaced = {} # new_name -> old_name; for a simpler reconciliation + + def replacer(field_name: str) -> str: + if self._is_useless_join_for_projection(field_name.split(":")[0], aggregation.projection): + new_field_name = self._get_fk_field_for_many_to_one_projection(field_name) + replaced[new_field_name] = field_name + return new_field_name + return field_name + + new_aggregation = aggregation.replace_fields(replacer) + + aggregate_results = await self.child_collection.aggregate( + caller, cast(Filter, await self._refine_filter(caller, filter_)), new_aggregation, limit + ) + if aggregation == new_aggregation: + return aggregate_results + return self._replace_fields_in_aggregate_group(aggregate_results, replaced) + + def _is_useless_join_for_projection(self, relation: str, projection: Projection) -> bool: relation_schema = self.schema["fields"][relation] sub_projections = projection.relations[relation] @@ -49,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool: and sub_projections[0] == relation_schema["foreign_key_target"] ) - def _get_fk_field_for_projection(self, projection: str) -> str: + def _get_fk_field_for_many_to_one_projection(self, projection: str) -> str: relation_name = projection.split(":")[0] relation_schema = cast(ManyToOne, self.schema["fields"][relation_name]) @@ -58,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str: def _get_projection_without_useless_joins(self, projection: Projection) -> Projection: returned_projection = Projection(*projection) for relation, relation_projections in projection.relations.items(): - if self._is_useless_join(relation, projection): + if self._is_useless_join_for_projection(relation, projection): # remove foreign key target from projection returned_projection.remove(f"{relation}:{relation_projections[0]}") # add foreign keys to projection - fk_field = self._get_fk_field_for_projection(relation) + fk_field = self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}") if fk_field not in returned_projection: returned_projection.append(fk_field) return returned_projection - def _apply_joins_on_records( + def _apply_joins_on_simplified_records( self, initial_projection: Projection, requested_projection: Projection, records: List[RecordsDataAlias] ) -> List[RecordsDataAlias]: if requested_projection == initial_projection: @@ -84,7 +106,9 @@ def _apply_joins_on_records( relation_schema = self.schema["fields"][relation] if is_many_to_one(relation_schema): - fk_value = record[self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")] + fk_value = record[ + self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}") + ] record[relation] = {relation_projections[0]: fk_value} if fk_value else None # remove foreign keys @@ -92,3 +116,17 @@ def _apply_joins_on_records( del record[projection] return records + + def _replace_fields_in_aggregate_group( + self, aggregate_results: List[AggregateResult], field_to_replace: Dict[str, str] + ) -> List[AggregateResult]: + for aggregate_result in aggregate_results: + group = {} + for field, value in aggregate_result["group"].items(): + if field in field_to_replace: + group[field_to_replace[field]] = value + else: + group[field] = value + aggregate_result["group"] = group + + return aggregate_results diff --git a/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py b/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py index 7e30e1b3..64710104 100644 --- a/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py +++ b/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py @@ -14,8 +14,10 @@ from forestadmin.datasource_toolkit.decorators.datasource_decorator import DatasourceDecorator from forestadmin.datasource_toolkit.decorators.lazy_join.collection import LazyJoinCollectionDecorator from forestadmin.datasource_toolkit.interfaces.fields import Column, FieldType, ManyToOne, OneToMany, PrimitiveType +from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter +from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter from forestadmin.datasource_toolkit.interfaces.query.projections import Projection @@ -53,6 +55,10 @@ def setUpClass(cls) -> None: column_type=PrimitiveType.STRING, type=FieldType.COLUMN, ), + "price": Column( + column_type=PrimitiveType.NUMBER, + type=FieldType.COLUMN, + ), } ) cls.collection_person = Collection("Person", cls.datasource) @@ -226,7 +232,7 @@ def test_should_disable_join_on_projection_but_not_in_condition_tree(self): response, ) - def test_should_correctly_handle_null_relations(self): + def test_should_correctly_handle_null_relations_on_list(self): with patch.object( self.collection_book, "list", @@ -252,3 +258,131 @@ def test_should_correctly_handle_null_relations(self): ], result, ) + + def test_should_not_join_on_aggregate_when_group_by_foreign_pk(self): + with patch.object( + self.collection_book, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 1824.11, "group": {"author_id": 2}}, + {"value": 824.11, "group": {"author_id": 3}}, + ], + ) as mock_aggregate: + result = self.loop.run_until_complete( + self.decorated_book_collection.aggregate( + self.mocked_caller, + Filter({}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}), + None, + ) + ) + self.assertEqual( + result, + [ + {"value": 1824.11, "group": {"author:id": 2}}, + {"value": 824.11, "group": {"author:id": 3}}, + ], + ) + mock_aggregate.assert_awaited_once_with( + self.mocked_caller, + Filter({}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}), + None, + ) + + def test_should_join_on_aggregate_when_group_by_foreign_field(self): + with patch.object( + self.collection_book, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 1824.11, "group": {"author:first_name": "Isaac"}}, + {"value": 824.11, "group": {"author:first_name": "JK"}}, + ], + ) as mock_aggregate: + result = self.loop.run_until_complete( + self.decorated_book_collection.aggregate( + self.mocked_caller, + Filter({}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}), + None, + ) + ) + self.assertEqual( + result, + [ + {"value": 1824.11, "group": {"author:first_name": "Isaac"}}, + {"value": 824.11, "group": {"author:first_name": "JK"}}, + ], + ) + mock_aggregate.assert_awaited_once_with( + self.mocked_caller, + Filter({}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}), + None, + ) + + def test_should_not_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_pk(self): + with patch.object( + self.collection_book, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 1824.11, "group": {"author_id": 2}}, + {"value": 824.11, "group": {"author_id": 3}}, + ], + ) as mock_aggregate: + result = self.loop.run_until_complete( + self.decorated_book_collection.aggregate( + self.mocked_caller, + Filter({"condition_tree": ConditionTreeLeaf("author:id", "not_equal", 50)}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}), + None, + ) + ) + self.assertEqual( + result, + [ + {"value": 1824.11, "group": {"author:id": 2}}, + {"value": 824.11, "group": {"author:id": 3}}, + ], + ) + mock_aggregate.assert_awaited_once_with( + self.mocked_caller, + Filter({"condition_tree": ConditionTreeLeaf("author_id", "not_equal", 50)}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}), + None, + ) + + def test_should_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_field(self): + with patch.object( + self.collection_book, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 1824.11, "group": {"author_id": 2}}, + {"value": 824.11, "group": {"author_id": 3}}, + ], + ) as mock_aggregate: + result = self.loop.run_until_complete( + self.decorated_book_collection.aggregate( + self.mocked_caller, + Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}), + None, + ) + ) + self.assertEqual( + result, + [ + {"value": 1824.11, "group": {"author:id": 2}}, + {"value": 824.11, "group": {"author:id": 3}}, + ], + ) + mock_aggregate.assert_awaited_once_with( + self.mocked_caller, + Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}), + None, + )