Skip to content

Commit c2df420

Browse files
committed
refactor: Move all state into AggExpr, docs
Resolving #2680 (comment)
1 parent 4944e38 commit c2df420

File tree

1 file changed

+140
-68
lines changed

1 file changed

+140
-68
lines changed

narwhals/_pandas_like/group_by.py

Lines changed: 140 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,13 @@
6767
_NamedAgg: TypeAlias = "tuple[str, _AggFunc]"
6868
"""Equivalent to `pd.NamedAgg`."""
6969

70-
SeqStrT = TypeVar("SeqStrT", bound="Sequence[str]", default="list[str]")
70+
IterStrT = TypeVar("IterStrT", bound="Iterable[str]")
7171

7272
NonStrHashable: TypeAlias = Any
7373
"""Because `pandas` allows *"names"* like that 😭"""
7474

75+
T = TypeVar("T")
76+
7577

7678
@lru_cache(maxsize=32)
7779
def _agg_func(
@@ -94,16 +96,15 @@ def _n_unique(self: NativeSeriesGroupBy) -> pd.Series[Any]:
9496
return self.nunique(dropna=False)
9597

9698

97-
# PLAN
98-
# ----
99-
# - Before aggregating, rename every column that isn't already a `str`
100-
# - Proxy all incoming expressions through a rename mapper
101-
# - At the end, rename back to the original
102-
def _remap_non_str(
103-
original: Sequence[Any], exclude: Iterable[Any]
104-
) -> dict[NonStrHashable, str]:
105-
"""An empty result follows a no-op path in `with_columns."""
106-
exclude = set(exclude)
99+
def _remap_non_str(group_by: PandasLikeGroupBy) -> dict[NonStrHashable, str]:
100+
"""Before aggregating, rename every column that isn't already a `str`.
101+
102+
- Proxy all incoming expressions through a rename mapper
103+
- At the end, rename back to the original
104+
- An empty result follows a no-op path in `with_columns`.
105+
"""
106+
original = group_by._original_columns
107+
exclude = set(group_by.exclude)
107108
if remaining := set(original).difference(exclude):
108109
union = exclude.union(original)
109110
return {
@@ -114,6 +115,106 @@ def _remap_non_str(
114115
return {} # pragma: no cover
115116

116117

118+
def collect(iterable: tuple[T, ...] | Iterable[T], /) -> tuple[T, ...]:
    """Return `iterable` as a `tuple`, reusing it as-is when it already is one.

    Any other iterable is consumed into a new `tuple`.

    Borrowed from [`ExprIR` PR].

    [`ExprIR` PR]: https://github.com/narwhals-dev/narwhals/blob/1de65d2f82ace95a9bc72667067ffdfa9d28be6d/narwhals/_plan/common.py#L396-L398
    """
    if isinstance(iterable, tuple):
        # Already a tuple: hand back the same object, no copy.
        return iterable
    return tuple(iterable)
126+
127+
128+
class AggExpr:
    """Per-`PandasLikeExpr` holder for state computed during one aggregation.

    There are a lot of edge cases to handle, so the aim is to evaluate as
    little as possible and to keep anything that is needed more than once
    (expanded names, leaf name, native function) on the instance.

    Warning:
        While a `PandasLikeExpr` can be reused, this wrapper is valid **only**
        in a single `.agg(...)` operation.
    """

    expr: PandasLikeExpr
    output_names: tuple[str, ...]
    aliases: tuple[str, ...]
    _native_func: _AggFunc

    def __init__(self, expr: PandasLikeExpr) -> None:
        self.expr = expr
        # Both remain empty until `with_expand_names` stores the real values.
        self.output_names = self.aliases = ()
        # A falsy value means "not resolved yet" (see `leaf_name`).
        self._leaf_name: NarwhalsAggregation | Any = ""

    def with_expand_names(self, group_by: PandasLikeGroupBy, /) -> AggExpr:
        """**Mutating operation**.

        Calls `evaluate_output_names_and_aliases` for the wrapped expression and
        stores both results on `self` as tuples, then returns `self`.
        """
        expanded = evaluate_output_names_and_aliases(
            self.expr, group_by.compliant, group_by.exclude
        )
        self.output_names, self.aliases = map(collect, expanded)
        return self

    def named_aggs(
        self, group_by: PandasLikeGroupBy, /
    ) -> Iterator[tuple[str, _NamedAgg]]:
        """Yield `(alias, (column, native_func))` pairs for this expression.

        Aliases are first passed through `group_by._aliases_str`.
        """
        str_aliases = collect(group_by._aliases_str(self.aliases))
        func = self.native_func
        if self.is_len() and self.is_anonymous():
            # Anonymous `len`: a single pair, using the column `group_by` picks.
            yield str_aliases[0], (group_by._anonymous_column_name, func)
        else:
            yield from (
                (alias, (output_name, func))
                for output_name, alias in zip(self.output_names, str_aliases)
            )

    def _cast_coerced(self, group_by: PandasLikeGroupBy, /) -> Iterator[PandasLikeExpr]:
        """Yield post-agg casts to correct for weird pandas behavior.

        At most one expression is yielded: a cast of this expression's aliases
        to `Int64`, and only when the leaf is `n_unique` and
        `has_non_int_nullable_dtype` holds for the output names.

        See https://github.com/narwhals-dev/narwhals/pull/2680#discussion_r2157251589
        """
        df = group_by.compliant
        if not self.is_n_unique():
            return
        if not has_non_int_nullable_dtype(df, self.output_names):
            return
        ns = df.__narwhals_namespace__()
        columns = ns.col(*self.aliases)
        yield columns.cast(ns._version.dtypes.Int64())

    def is_len(self) -> bool:
        """Whether the leaf aggregation name is `"len"`."""
        return self.leaf_name == "len"

    def is_n_unique(self) -> bool:
        """Whether the leaf aggregation name is `"n_unique"`."""
        return self.leaf_name == "n_unique"

    def is_anonymous(self) -> bool:
        """Whether the wrapped expression has `_depth == 0`."""
        return self.expr._depth == 0

    @property
    def implementation(self) -> Implementation:
        """The wrapped expression's `_implementation`."""
        return self.expr._implementation

    @property
    def kwargs(self) -> ScalarKwargs:
        """The wrapped expression's `_scalar_kwargs`."""
        return self.expr._scalar_kwargs

    @property
    def leaf_name(self) -> NarwhalsAggregation | Any:
        """Leaf name of the wrapped expression, stored after the first lookup."""
        if not self._leaf_name:
            self._leaf_name = PandasLikeGroupBy._leaf_name(self.expr)
        return self._leaf_name

    @property
    def native_func(self) -> _AggFunc:
        """Native aggregation function, built once via `_agg_func` and then reused."""
        # `_native_func` is only annotated on the class, so it is absent
        # from the instance until the first access assigns it.
        if not hasattr(self, "_native_func"):
            function_name = PandasLikeGroupBy._remap_expr_name(self.leaf_name)
            self._native_func = _agg_func(
                function_name, self.implementation, **self.kwargs
            )
        return self._native_func
216+
217+
117218
class PandasLikeGroupBy(
118219
EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", NativeAggregation]
119220
):
@@ -140,12 +241,6 @@ class PandasLikeGroupBy(
140241

141242
_remap_non_str_columns: dict[NonStrHashable, str]
142243

143-
_casts: list[PandasLikeExpr]
144-
"""Post-aggregation dtype cast expressions.
145-
146-
See https://github.com/narwhals-dev/narwhals/pull/2680#discussion_r2157251589
147-
"""
148-
149244
@property
150245
def exclude(self) -> tuple[str, ...]:
151246
"""Group keys to ignore when expanding multi-output aggregations."""
@@ -164,11 +259,10 @@ def __init__(
164259
ns = df.__narwhals_namespace__()
165260
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
166261
self._exclude: tuple[str, ...] = (*self._keys, *self._output_key_names)
167-
self._remap_non_str_columns = _remap_non_str(self._original_columns, self.exclude)
262+
self._remap_non_str_columns = _remap_non_str(self)
168263
self._compliant_frame = frame.with_columns(
169264
*(ns.col(old).alias(new) for old, new in self._remap_non_str_columns.items())
170265
)
171-
self._casts = []
172266
# Drop index to avoid potential collisions:
173267
# https://github.com/narwhals-dev/narwhals/issues/1907.
174268
native = self.compliant.native
@@ -183,68 +277,47 @@ def __init__(
183277
)
184278

185279
def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:
    """Aggregate the grouped frame with `exprs` and return the selected result.

    Every expression is wrapped in an `AggExpr` with its names expanded.
    If all expressions are simple, they go through a single native
    `.agg(...)` call; otherwise the complex path via `_agg_complex` is used.

    Raises:
        The error built by `empty_results_error()` when a non-simple
        aggregation is requested and the native frame is empty.
    """
    agg_exprs: list[AggExpr] = []
    only_simple = True
    for expr in exprs:
        agg_exprs.append(AggExpr(expr).with_expand_names(self))
        # `_is_simple` is still called for every expression.
        only_simple = self._is_simple(expr) and only_simple

    if not only_simple:
        if self.compliant.native.empty:
            raise empty_results_error()
        return self._select_results(self._agg_complex(exprs), agg_exprs)

    result: pd.DataFrame
    named_aggs = self._named_aggs(agg_exprs)
    if named_aggs:
        result = self._grouped.agg(**named_aggs)  # type: ignore[call-overload]
    else:
        # No aggregations: build a frame from the groups, with the keys as columns.
        native_ns = self.compliant.__native_namespace__()
        result = native_ns.DataFrame(list(self._grouped.groups), columns=self._keys)
    return self._select_results(result, agg_exprs)
300+
301+
def _named_aggs(self, exprs: Iterable[AggExpr], /) -> dict[str, _NamedAgg]:
302+
"""Collect all aggregations into a single mapping."""
303+
return dict(chain.from_iterable(e.named_aggs(self) for e in exprs))
210304

211305
@overload
212-
def _remap_aliases(self, names: list[str], /) -> list[str]: ...
306+
def _aliases_str(self, names: list[str], /) -> list[str]: ...
213307
@overload
214-
def _remap_aliases(self, names: SeqStrT, /) -> list[str] | SeqStrT: ...
215-
def _remap_aliases(self, names: SeqStrT, /) -> list[str] | SeqStrT:
308+
def _aliases_str(self, names: IterStrT, /) -> list[str] | IterStrT: ...
309+
def _aliases_str(self, names: IterStrT, /) -> list[str] | IterStrT:
310+
"""If we started with any non `str` column names, return the proxied `str` aliases for `names`."""
216311
if remap := self._remap_non_str_columns:
217312
return [remap.get(name, name) for name in names]
218313
return names
219314

220-
def _named_aggs(self, *exprs: PandasLikeExpr) -> dict[str, _NamedAgg]:
221-
"""Collect all aggregations into a single mapping."""
222-
return dict(chain.from_iterable(self._iter_named_aggs(e) for e in exprs))
223-
224-
def _iter_named_aggs(self, expr: PandasLikeExpr) -> Iterator[tuple[str, _NamedAgg]]:
225-
ns = self.compliant.__narwhals_namespace__()
226-
output_names, aliases = evaluate_output_names_and_aliases(
227-
expr, self.compliant, self.exclude
315+
@property
316+
def _anonymous_column_name(self) -> str:
317+
# `len` doesn't exist yet, so just pick a column to call size on
318+
return next(
319+
iter(set(self.compliant.columns).difference(self.exclude)), self._keys[0]
228320
)
229-
remap_aliases = self._remap_aliases(aliases)
230-
leaf_name = self._leaf_name(expr)
231-
function_name = self._remap_expr_name(leaf_name)
232-
aggfunc = _agg_func(function_name, expr._implementation, **expr._scalar_kwargs)
233-
if leaf_name == "len" and expr._depth == 0:
234-
# `len` doesn't exist yet, so just pick a column to call size on
235-
first_col = next(
236-
iter(set(self.compliant.columns).difference(self.exclude)), self._keys[0]
237-
)
238-
yield remap_aliases[0], (first_col, aggfunc)
239-
return
240-
241-
if leaf_name == "n_unique" and has_non_int_nullable_dtype(
242-
self.compliant, output_names
243-
):
244-
self._casts.append(ns.col(*aliases).cast(ns._version.dtypes.Int64()))
245-
246-
for output_name, alias in zip(output_names, remap_aliases):
247-
yield alias, (output_name, aggfunc)
248321

249322
@property
250323
def _final_renamer(self) -> dict[str, NonStrHashable]:
@@ -254,7 +327,7 @@ def _final_renamer(self) -> dict[str, NonStrHashable]:
254327
return dict(zip(temps, originals))
255328

256329
def _select_results(
257-
self, df: pd.DataFrame, /, new_names: list[str]
330+
self, df: pd.DataFrame, /, agg_exprs: Sequence[AggExpr]
258331
) -> PandasLikeDataFrame:
259332
"""Responsible for remapping temp column names back to original.
260333
@@ -263,25 +336,24 @@ def _select_results(
263336
# NOTE: Keep `inplace=True` to avoid making a redundant copy.
264337
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
265338
df.reset_index(inplace=True) # noqa: PD002
339+
new_names = self._aliases_str(chain.from_iterable(e.aliases for e in agg_exprs))
266340
return (
267341
self.compliant._with_native(df, validate_column_names=False)
268-
.simple_select(*new_names)
342+
.simple_select(*self._keys, *new_names)
269343
.rename(self._final_renamer)
270-
.with_columns(*self._casts)
344+
.with_columns(*chain.from_iterable(e._cast_coerced(self) for e in agg_exprs))
271345
)
272346

273-
def _agg_complex(self, exprs: Iterable[PandasLikeExpr]) -> pd.DataFrame:
    """Run `exprs` through the native group-by `apply` and return its result.

    `warn_complex_group_by()` is called first. The function passed to `apply`
    is built by `_apply_exprs`.
    """
    warn_complex_group_by()
    impl = self.compliant._implementation
    backend_version = self.compliant._backend_version
    func = self._apply_exprs(exprs)
    apply = self._grouped.apply
    # `include_groups` is passed only for pandas with backend version >= (2, 2).
    extra: dict[str, Any] = (
        {"include_groups": False}
        if impl.is_pandas() and backend_version >= (2, 2)
        else {}
    )
    return apply(func, **extra)
285357

286358
def _apply_exprs(self, exprs: Iterable[PandasLikeExpr]) -> NativeApply:
287359
ns = self.compliant.__narwhals_namespace__()

0 commit comments

Comments
 (0)