Skip to content

Commit 34e72f9

Browse files
committed
chore: refactor a bit
1 parent 84634e8 commit 34e72f9

File tree

1 file changed

+33
-28
lines changed
  • src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join

1 file changed

+33
-28
lines changed
Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Union, cast
1+
from typing import Dict, List, Optional, Union, cast
22

33
from forestadmin.agent_toolkit.utils.context import User
44
from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator
@@ -18,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project
1818
refined_filter = cast(PaginatedFilter, await self._refine_filter(caller, filter_))
1919
ret = await self.child_collection.list(caller, refined_filter, simplified_projection)
2020

21-
return self._apply_joins_on_records(projection, simplified_projection, ret)
21+
return self._apply_joins_on_simplified_records(projection, simplified_projection, ret)
2222

2323
async def _refine_filter(
2424
self, caller: User, _filter: Union[Filter, PaginatedFilter, None]
@@ -29,11 +29,11 @@ async def _refine_filter(
2929
_filter.condition_tree = _filter.condition_tree.replace(
3030
lambda leaf: (
3131
ConditionTreeLeaf(
32-
self._get_fk_field_for_projection(leaf.field),
32+
self._get_fk_field_for_many_to_one_projection(leaf.field),
3333
leaf.operator,
3434
leaf.value,
3535
)
36-
if self._is_useless_join(leaf.field.split(":")[0], _filter.condition_tree.projection)
36+
if self._is_useless_join_for_projection(leaf.field.split(":")[0], _filter.condition_tree.projection)
3737
else leaf
3838
)
3939
)
@@ -43,36 +43,25 @@ async def _refine_filter(
4343
async def aggregate(
4444
self, caller: User, filter_: Union[Filter, None], aggregation: Aggregation, limit: Optional[int] = None
4545
) -> List[AggregateResult]:
46-
replaced = {}
46+
replaced = {} # new_name -> old_name; for a simpler reconciliation
4747

4848
def replacer(field_name: str) -> str:
49-
if self._is_useless_join(field_name.split(":")[0], aggregation.projection):
50-
new_field_name = self._get_fk_field_for_projection(field_name)
49+
if self._is_useless_join_for_projection(field_name.split(":")[0], aggregation.projection):
50+
new_field_name = self._get_fk_field_for_many_to_one_projection(field_name)
5151
replaced[new_field_name] = field_name
5252
return new_field_name
53-
else:
54-
return field_name
53+
return field_name
5554

5655
new_aggregation = aggregation.replace_fields(replacer)
5756

58-
aggregate_result = await self.child_collection.aggregate(
57+
aggregate_results = await self.child_collection.aggregate(
5958
caller, cast(Filter, await self._refine_filter(caller, filter_)), new_aggregation, limit
6059
)
6160
if aggregation == new_aggregation:
62-
return aggregate_result
61+
return aggregate_results
62+
return self._replace_fields_in_aggregate_group(aggregate_results, replaced)
6363

64-
for result in aggregate_result:
65-
group = {}
66-
for field, value in result["group"].items():
67-
if field in replaced:
68-
group[replaced[field]] = value
69-
else:
70-
group[field] = value
71-
result["group"] = group
72-
73-
return aggregate_result
74-
75-
def _is_useless_join(self, relation: str, projection: Projection) -> bool:
64+
def _is_useless_join_for_projection(self, relation: str, projection: Projection) -> bool:
7665
relation_schema = self.schema["fields"][relation]
7766
sub_projections = projection.relations[relation]
7867

@@ -82,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool:
8271
and sub_projections[0] == relation_schema["foreign_key_target"]
8372
)
8473

85-
def _get_fk_field_for_projection(self, projection: str) -> str:
74+
def _get_fk_field_for_many_to_one_projection(self, projection: str) -> str:
8675
relation_name = projection.split(":")[0]
8776
relation_schema = cast(ManyToOne, self.schema["fields"][relation_name])
8877

@@ -91,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str:
9180
def _get_projection_without_useless_joins(self, projection: Projection) -> Projection:
9281
returned_projection = Projection(*projection)
9382
for relation, relation_projections in projection.relations.items():
94-
if self._is_useless_join(relation, projection):
83+
if self._is_useless_join_for_projection(relation, projection):
9584
# remove foreign key target from projection
9685
returned_projection.remove(f"{relation}:{relation_projections[0]}")
9786

9887
# add foreign keys to projection
99-
fk_field = self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")
88+
fk_field = self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
10089
if fk_field not in returned_projection:
10190
returned_projection.append(fk_field)
10291

10392
return returned_projection
10493

105-
def _apply_joins_on_records(
94+
def _apply_joins_on_simplified_records(
10695
self, initial_projection: Projection, requested_projection: Projection, records: List[RecordsDataAlias]
10796
) -> List[RecordsDataAlias]:
10897
if requested_projection == initial_projection:
@@ -117,11 +106,27 @@ def _apply_joins_on_records(
117106
relation_schema = self.schema["fields"][relation]
118107

119108
if is_many_to_one(relation_schema):
120-
fk_value = record[self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")]
109+
fk_value = record[
110+
self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
111+
]
121112
record[relation] = {relation_projections[0]: fk_value} if fk_value else None
122113

123114
# remove foreign keys
124115
for projection in projections_to_rm:
125116
del record[projection]
126117

127118
return records
119+
120+
def _replace_fields_in_aggregate_group(
121+
self, aggregate_results: List[AggregateResult], field_to_replace: Dict[str, str]
122+
) -> List[AggregateResult]:
123+
for aggregate_result in aggregate_results:
124+
group = {}
125+
for field, value in aggregate_result["group"].items():
126+
if field in field_to_replace:
127+
group[field_to_replace[field]] = value
128+
else:
129+
group[field] = value
130+
aggregate_result["group"] = group
131+
132+
return aggregate_results

0 commit comments

Comments
 (0)