Skip to content

Commit

Permalink
perf: add aggregate to lazy join (#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbarreau committed Jan 6, 2025
1 parent 2989e54 commit 04d5010
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 16 deletions.
8 changes: 3 additions & 5 deletions .github/actions/coverage/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ runs:
steps:
- name: Install poetry
shell: bash
working-directory: ./src/
run: pipx install poetry
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v4
Expand All @@ -30,17 +29,16 @@ runs:
uses: actions/download-artifact@v4
with:
path: coverage
merge-multiple: true
- name: Combine reports
shell: bash
working-directory: ./src/
run: poetry -C datasource_toolkit run coverage combine ../coverage/code-coverage-report.*/.coverage
run: pwd && ls -a ./coverage/code-coverage-report.agent_toolkit.3.10/.coverage && poetry -C ./src/datasource_toolkit run coverage combine ./coverage/code-coverage-report.*/.coverage
- name: Send coverage
uses: paambaati/[email protected]
env:
CC_TEST_REPORTER_ID: ${{ inputs.CC_TEST_REPORTER_ID }}
with:
workingDirectory: ./src/
coverageCommand: poetry -C datasource_toolkit run coverage xml
coverageCommand: poetry -C ./src/datasource_toolkit run coverage xml

# debug
# - name: Archive code coverage final results
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List, Union, cast
from typing import Dict, List, Optional, Union, cast

from forestadmin.agent_toolkit.utils.context import User
from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator
from forestadmin.datasource_toolkit.interfaces.fields import ManyToOne, is_many_to_one
from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf
from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter
from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
Expand All @@ -17,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project
refined_filter = cast(PaginatedFilter, await self._refine_filter(caller, filter_))
ret = await self.child_collection.list(caller, refined_filter, simplified_projection)

return self._apply_joins_on_records(projection, simplified_projection, ret)
return self._apply_joins_on_simplified_records(projection, simplified_projection, ret)

async def _refine_filter(
self, caller: User, _filter: Union[Filter, PaginatedFilter, None]
Expand All @@ -28,18 +29,39 @@ async def _refine_filter(
_filter.condition_tree = _filter.condition_tree.replace(
lambda leaf: (
ConditionTreeLeaf(
self._get_fk_field_for_projection(leaf.field),
self._get_fk_field_for_many_to_one_projection(leaf.field),
leaf.operator,
leaf.value,
)
if self._is_useless_join(leaf.field.split(":")[0], _filter.condition_tree.projection)
if self._is_useless_join_for_projection(leaf.field.split(":")[0], _filter.condition_tree.projection)
else leaf
)
)

return _filter

def _is_useless_join(self, relation: str, projection: Projection) -> bool:
async def aggregate(
self, caller: User, filter_: Union[Filter, None], aggregation: Aggregation, limit: Optional[int] = None
) -> List[AggregateResult]:
replaced = {} # new_name -> old_name; for a simpler reconciliation

def replacer(field_name: str) -> str:
if self._is_useless_join_for_projection(field_name.split(":")[0], aggregation.projection):
new_field_name = self._get_fk_field_for_many_to_one_projection(field_name)
replaced[new_field_name] = field_name
return new_field_name
return field_name

new_aggregation = aggregation.replace_fields(replacer)

aggregate_results = await self.child_collection.aggregate(
caller, cast(Filter, await self._refine_filter(caller, filter_)), new_aggregation, limit
)
if aggregation == new_aggregation:
return aggregate_results
return self._replace_fields_in_aggregate_group(aggregate_results, replaced)

def _is_useless_join_for_projection(self, relation: str, projection: Projection) -> bool:
relation_schema = self.schema["fields"][relation]
sub_projections = projection.relations[relation]

Expand All @@ -49,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool:
and sub_projections[0] == relation_schema["foreign_key_target"]
)

def _get_fk_field_for_projection(self, projection: str) -> str:
def _get_fk_field_for_many_to_one_projection(self, projection: str) -> str:
relation_name = projection.split(":")[0]
relation_schema = cast(ManyToOne, self.schema["fields"][relation_name])

Expand All @@ -58,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str:
def _get_projection_without_useless_joins(self, projection: Projection) -> Projection:
returned_projection = Projection(*projection)
for relation, relation_projections in projection.relations.items():
if self._is_useless_join(relation, projection):
if self._is_useless_join_for_projection(relation, projection):
# remove foreign key target from projection
returned_projection.remove(f"{relation}:{relation_projections[0]}")

# add foreign keys to projection
fk_field = self._get_fk_field_for_projection(relation)
fk_field = self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
if fk_field not in returned_projection:
returned_projection.append(fk_field)

return returned_projection

def _apply_joins_on_records(
def _apply_joins_on_simplified_records(
self, initial_projection: Projection, requested_projection: Projection, records: List[RecordsDataAlias]
) -> List[RecordsDataAlias]:
if requested_projection == initial_projection:
Expand All @@ -84,11 +106,27 @@ def _apply_joins_on_records(
relation_schema = self.schema["fields"][relation]

if is_many_to_one(relation_schema):
fk_value = record[self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")]
fk_value = record[
self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
]
record[relation] = {relation_projections[0]: fk_value} if fk_value else None

# remove foreign keys
for projection in projections_to_rm:
del record[projection]

return records

def _replace_fields_in_aggregate_group(
self, aggregate_results: List[AggregateResult], field_to_replace: Dict[str, str]
) -> List[AggregateResult]:
for aggregate_result in aggregate_results:
group = {}
for field, value in aggregate_result["group"].items():
if field in field_to_replace:
group[field_to_replace[field]] = value
else:
group[field] = value
aggregate_result["group"] = group

return aggregate_results
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from forestadmin.datasource_toolkit.decorators.datasource_decorator import DatasourceDecorator
from forestadmin.datasource_toolkit.decorators.lazy_join.collection import LazyJoinCollectionDecorator
from forestadmin.datasource_toolkit.interfaces.fields import Column, FieldType, ManyToOne, OneToMany, PrimitiveType
from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf
from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter
from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
from forestadmin.datasource_toolkit.interfaces.query.projections import Projection


Expand Down Expand Up @@ -53,6 +55,10 @@ def setUpClass(cls) -> None:
column_type=PrimitiveType.STRING,
type=FieldType.COLUMN,
),
"price": Column(
column_type=PrimitiveType.NUMBER,
type=FieldType.COLUMN,
),
}
)
cls.collection_person = Collection("Person", cls.datasource)
Expand Down Expand Up @@ -226,7 +232,7 @@ def test_should_disable_join_on_projection_but_not_in_condition_tree(self):
response,
)

def test_should_correctly_handle_null_relations(self):
def test_should_correctly_handle_null_relations_on_list(self):
with patch.object(
self.collection_book,
"list",
Expand All @@ -252,3 +258,131 @@ def test_should_correctly_handle_null_relations(self):
],
result,
)

def test_should_not_join_on_aggregate_when_group_by_foreign_pk(self):
with patch.object(
self.collection_book,
"aggregate",
new_callable=AsyncMock,
return_value=[
{"value": 1824.11, "group": {"author_id": 2}},
{"value": 824.11, "group": {"author_id": 3}},
],
) as mock_aggregate:
result = self.loop.run_until_complete(
self.decorated_book_collection.aggregate(
self.mocked_caller,
Filter({}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
None,
)
)
self.assertEqual(
result,
[
{"value": 1824.11, "group": {"author:id": 2}},
{"value": 824.11, "group": {"author:id": 3}},
],
)
mock_aggregate.assert_awaited_once_with(
self.mocked_caller,
Filter({}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
None,
)

def test_should_join_on_aggregate_when_group_by_foreign_field(self):
with patch.object(
self.collection_book,
"aggregate",
new_callable=AsyncMock,
return_value=[
{"value": 1824.11, "group": {"author:first_name": "Isaac"}},
{"value": 824.11, "group": {"author:first_name": "JK"}},
],
) as mock_aggregate:
result = self.loop.run_until_complete(
self.decorated_book_collection.aggregate(
self.mocked_caller,
Filter({}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}),
None,
)
)
self.assertEqual(
result,
[
{"value": 1824.11, "group": {"author:first_name": "Isaac"}},
{"value": 824.11, "group": {"author:first_name": "JK"}},
],
)
mock_aggregate.assert_awaited_once_with(
self.mocked_caller,
Filter({}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}),
None,
)

def test_should_not_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_pk(self):
with patch.object(
self.collection_book,
"aggregate",
new_callable=AsyncMock,
return_value=[
{"value": 1824.11, "group": {"author_id": 2}},
{"value": 824.11, "group": {"author_id": 3}},
],
) as mock_aggregate:
result = self.loop.run_until_complete(
self.decorated_book_collection.aggregate(
self.mocked_caller,
Filter({"condition_tree": ConditionTreeLeaf("author:id", "not_equal", 50)}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
None,
)
)
self.assertEqual(
result,
[
{"value": 1824.11, "group": {"author:id": 2}},
{"value": 824.11, "group": {"author:id": 3}},
],
)
mock_aggregate.assert_awaited_once_with(
self.mocked_caller,
Filter({"condition_tree": ConditionTreeLeaf("author_id", "not_equal", 50)}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
None,
)

def test_should_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_field(self):
with patch.object(
self.collection_book,
"aggregate",
new_callable=AsyncMock,
return_value=[
{"value": 1824.11, "group": {"author_id": 2}},
{"value": 824.11, "group": {"author_id": 3}},
],
) as mock_aggregate:
result = self.loop.run_until_complete(
self.decorated_book_collection.aggregate(
self.mocked_caller,
Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
None,
)
)
self.assertEqual(
result,
[
{"value": 1824.11, "group": {"author:id": 2}},
{"value": 824.11, "group": {"author:id": 3}},
],
)
mock_aggregate.assert_awaited_once_with(
self.mocked_caller,
Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
None,
)

0 comments on commit 04d5010

Please sign in to comment.