67
67
_NamedAgg : TypeAlias = "tuple[str, _AggFunc]"
68
68
"""Equivalent to `pd.NamedAgg`."""
69
69
70
- SeqStrT = TypeVar ("SeqStrT " , bound = "Sequence[str]" , default = "list [str]" )
70
+ IterStrT = TypeVar ("IterStrT " , bound = "Iterable [str]" )
71
71
72
72
NonStrHashable : TypeAlias = Any
73
73
"""Because `pandas` allows *"names"* like that 😭"""
74
74
75
+ T = TypeVar ("T" )
76
+
75
77
76
78
@lru_cache (maxsize = 32 )
77
79
def _agg_func (
@@ -94,16 +96,15 @@ def _n_unique(self: NativeSeriesGroupBy) -> pd.Series[Any]:
94
96
return self .nunique (dropna = False )
95
97
96
98
97
- # PLAN
98
- # ----
99
- # - Before aggregating, rename every column that isn't already a `str`
100
- # - Proxy all incoming expressions through a rename mapper
101
- # - At the end, rename back to the original
102
- def _remap_non_str (
103
- original : Sequence [Any ], exclude : Iterable [Any ]
104
- ) -> dict [NonStrHashable , str ]:
105
- """An empty result follows a no-op path in `with_columns."""
106
- exclude = set (exclude )
99
+ def _remap_non_str (group_by : PandasLikeGroupBy ) -> dict [NonStrHashable , str ]:
100
+ """Before aggregating, rename every column that isn't already a `str`.
101
+
102
+ - Proxy all incoming expressions through a rename mapper
103
+ - At the end, rename back to the original
104
+ - An empty result follows a no-op path in `with_columns`.
105
+ """
106
+ original = group_by ._original_columns
107
+ exclude = set (group_by .exclude )
107
108
if remaining := set (original ).difference (exclude ):
108
109
union = exclude .union (original )
109
110
return {
@@ -114,6 +115,106 @@ def _remap_non_str(
114
115
return {} # pragma: no cover
115
116
116
117
118
def collect(iterable: tuple[T, ...] | Iterable[T], /) -> tuple[T, ...]:
    """Return `iterable` as a `tuple`, materializing it only when necessary.

    A value that is already a `tuple` is handed back as-is (same object);
    anything else is consumed once and packed into a new `tuple`.

    Borrowed from [`ExprIR` PR].

    [`ExprIR` PR]: https://github.com/narwhals-dev/narwhals/blob/1de65d2f82ace95a9bc72667067ffdfa9d28be6d/narwhals/_plan/common.py#L396-L398
    """
    if isinstance(iterable, tuple):
        return iterable
    return tuple(iterable)
126
+
127
+
128
class AggExpr:
    """Scratchpad holding the intermediate state for one `PandasLikeExpr`.

    Group-by aggregation has many edge cases, so the wrapper evaluates as
    little as it can and keeps hold of anything that is needed more than once
    (expanded names, the leaf function name, the resolved native function).

    Warning:
        The wrapped `PandasLikeExpr` may be reused, but an `AggExpr` is valid
        **only** for a single `.agg(...)` operation.
    """

    expr: PandasLikeExpr
    output_names: tuple[str, ...]
    aliases: tuple[str, ...]
    _native_func: _AggFunc

    def __init__(self, expr: PandasLikeExpr) -> None:
        self.expr = expr
        # Both stay empty until `with_expand_names` fills them in.
        self.output_names = ()
        self.aliases = ()
        # The empty string doubles as the "not resolved yet" marker for `leaf_name`.
        self._leaf_name: NarwhalsAggregation | Any = ""

    def with_expand_names(self, group_by: PandasLikeGroupBy, /) -> AggExpr:
        """**Mutating operation**.

        Runs `evaluate_output_names_and_aliases` against the group-by's frame
        (ignoring its excluded keys), stores both results as tuples on `self`
        and returns `self` so the call can be chained.
        """
        df = group_by.compliant
        exclude = group_by.exclude
        names, renamed = evaluate_output_names_and_aliases(self.expr, df, exclude)
        # Materialize both before assigning, so a failure leaves `self` untouched.
        names_tuple = collect(names)
        renamed_tuple = collect(renamed)
        self.output_names = names_tuple
        self.aliases = renamed_tuple
        return self

    def named_aggs(
        self, group_by: PandasLikeGroupBy, /
    ) -> Iterator[tuple[str, _NamedAgg]]:
        """Yield `(alias, (column, func))` pairs for this expression.

        Aliases are first passed through `group_by._aliases_str`.
        """
        str_aliases = collect(group_by._aliases_str(self.aliases))
        func = self.native_func
        if not (self.is_len() and self.is_anonymous()):
            yield from (
                (alias, (name, func))
                for name, alias in zip(self.output_names, str_aliases)
            )
            return
        # An anonymous `len` has no source column of its own, so borrow one.
        yield str_aliases[0], (group_by._anonymous_column_name, func)

    def _cast_coerced(self, group_by: PandasLikeGroupBy, /) -> Iterator[PandasLikeExpr]:
        """Yield post-agg casts to correct for weird pandas behavior.

        Only `n_unique` over columns flagged by `has_non_int_nullable_dtype`
        produces a cast (to `Int64`); every other case yields nothing.

        See https://github.com/narwhals-dev/narwhals/pull/2680#discussion_r2157251589
        """
        df = group_by.compliant
        if not self.is_n_unique():
            return
        if not has_non_int_nullable_dtype(df, self.output_names):
            return
        ns = df.__narwhals_namespace__()
        selected = ns.col(*self.aliases)
        yield selected.cast(ns._version.dtypes.Int64())

    def is_len(self) -> bool:
        """Whether the leaf aggregation is `len`."""
        return self.leaf_name == "len"

    def is_n_unique(self) -> bool:
        """Whether the leaf aggregation is `n_unique`."""
        return self.leaf_name == "n_unique"

    def is_anonymous(self) -> bool:
        """Whether the wrapped expression has a `_depth` of zero."""
        return self.expr._depth == 0

    @property
    def implementation(self) -> Implementation:
        """The wrapped expression's `_implementation`."""
        return self.expr._implementation

    @property
    def kwargs(self) -> ScalarKwargs:
        """The wrapped expression's `_scalar_kwargs`."""
        return self.expr._scalar_kwargs

    @property
    def leaf_name(self) -> NarwhalsAggregation | Any:
        """Leaf function name, resolved on first access and then reused."""
        if not self._leaf_name:
            self._leaf_name = PandasLikeGroupBy._leaf_name(self.expr)
        return self._leaf_name

    @property
    def native_func(self) -> _AggFunc:
        """Native aggregation function, built on first access and then reused."""
        try:
            return self._native_func
        except AttributeError:
            # Not built yet: `_native_func` is only ever set below.
            pass
        self._native_func = _agg_func(
            PandasLikeGroupBy._remap_expr_name(self.leaf_name),
            self.implementation,
            **self.kwargs,
        )
        return self._native_func
216
+
217
+
117
218
class PandasLikeGroupBy (
118
219
EagerGroupBy ["PandasLikeDataFrame" , "PandasLikeExpr" , NativeAggregation ]
119
220
):
@@ -140,12 +241,6 @@ class PandasLikeGroupBy(
140
241
141
242
_remap_non_str_columns : dict [NonStrHashable , str ]
142
243
143
- _casts : list [PandasLikeExpr ]
144
- """Post-aggregation dtype cast expressions.
145
-
146
- See https://github.com/narwhals-dev/narwhals/pull/2680#discussion_r2157251589
147
- """
148
-
149
244
@property
150
245
def exclude (self ) -> tuple [str , ...]:
151
246
"""Group keys to ignore when expanding multi-output aggregations."""
@@ -164,11 +259,10 @@ def __init__(
164
259
ns = df .__narwhals_namespace__ ()
165
260
frame , self ._keys , self ._output_key_names = self ._parse_keys (df , keys = keys )
166
261
self ._exclude : tuple [str , ...] = (* self ._keys , * self ._output_key_names )
167
- self ._remap_non_str_columns = _remap_non_str (self . _original_columns , self . exclude )
262
+ self ._remap_non_str_columns = _remap_non_str (self )
168
263
self ._compliant_frame = frame .with_columns (
169
264
* (ns .col (old ).alias (new ) for old , new in self ._remap_non_str_columns .items ())
170
265
)
171
- self ._casts = []
172
266
# Drop index to avoid potential collisions:
173
267
# https://github.com/narwhals-dev/narwhals/issues/1907.
174
268
native = self .compliant .native
@@ -183,68 +277,47 @@ def __init__(
183
277
)
184
278
185
279
def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:
    """Aggregate every expression in `exprs` over the groups.

    Each expression is wrapped in an `AggExpr` with its names expanded.
    When all expressions are simple, they are combined into one native
    named-aggregation call (or, if nothing is left to aggregate, a frame of
    just the group keys is built). Otherwise the complex `apply` path is
    used, which raises via `empty_results_error` on an empty frame.
    """
    wrapped: list[AggExpr] = []
    simple_flags: list[Any] = []
    for expr in exprs:
        wrapped.append(AggExpr(expr).with_expand_names(self))
        simple_flags.append(self._is_simple(expr))

    if not all(simple_flags):
        if self.compliant.native.empty:
            raise empty_results_error()
        return self._select_results(self._agg_complex(exprs), wrapped)

    result: pd.DataFrame
    named_aggs = self._named_aggs(wrapped)
    if named_aggs:
        result = self._grouped.agg(**named_aggs)  # type: ignore[call-overload]
    else:
        # No aggregations to run: the result is only the distinct group keys.
        native_ns = self.compliant.__native_namespace__()
        result = native_ns.DataFrame(list(self._grouped.groups), columns=self._keys)
    return self._select_results(result, wrapped)
300
+
301
+ def _named_aggs (self , exprs : Iterable [AggExpr ], / ) -> dict [str , _NamedAgg ]:
302
+ """Collect all aggregations into a single mapping."""
303
+ return dict (chain .from_iterable (e .named_aggs (self ) for e in exprs ))
210
304
211
305
@overload
212
- def _remap_aliases (self , names : list [str ], / ) -> list [str ]: ...
306
+ def _aliases_str (self , names : list [str ], / ) -> list [str ]: ...
213
307
@overload
214
- def _remap_aliases (self , names : SeqStrT , / ) -> list [str ] | SeqStrT : ...
215
- def _remap_aliases (self , names : SeqStrT , / ) -> list [str ] | SeqStrT :
308
+ def _aliases_str (self , names : IterStrT , / ) -> list [str ] | IterStrT : ...
309
+ def _aliases_str (self , names : IterStrT , / ) -> list [str ] | IterStrT :
310
+ """If we started with any non `str` column names, return the proxied `str` aliases for `names`."""
216
311
if remap := self ._remap_non_str_columns :
217
312
return [remap .get (name , name ) for name in names ]
218
313
return names
219
314
220
- def _named_aggs (self , * exprs : PandasLikeExpr ) -> dict [str , _NamedAgg ]:
221
- """Collect all aggregations into a single mapping."""
222
- return dict (chain .from_iterable (self ._iter_named_aggs (e ) for e in exprs ))
223
-
224
- def _iter_named_aggs (self , expr : PandasLikeExpr ) -> Iterator [tuple [str , _NamedAgg ]]:
225
- ns = self .compliant .__narwhals_namespace__ ()
226
- output_names , aliases = evaluate_output_names_and_aliases (
227
- expr , self .compliant , self .exclude
315
+ @property
316
+ def _anonymous_column_name (self ) -> str :
317
+ # `len` doesn't exist yet, so just pick a column to call size on
318
+ return next (
319
+ iter (set (self .compliant .columns ).difference (self .exclude )), self ._keys [0 ]
228
320
)
229
- remap_aliases = self ._remap_aliases (aliases )
230
- leaf_name = self ._leaf_name (expr )
231
- function_name = self ._remap_expr_name (leaf_name )
232
- aggfunc = _agg_func (function_name , expr ._implementation , ** expr ._scalar_kwargs )
233
- if leaf_name == "len" and expr ._depth == 0 :
234
- # `len` doesn't exist yet, so just pick a column to call size on
235
- first_col = next (
236
- iter (set (self .compliant .columns ).difference (self .exclude )), self ._keys [0 ]
237
- )
238
- yield remap_aliases [0 ], (first_col , aggfunc )
239
- return
240
-
241
- if leaf_name == "n_unique" and has_non_int_nullable_dtype (
242
- self .compliant , output_names
243
- ):
244
- self ._casts .append (ns .col (* aliases ).cast (ns ._version .dtypes .Int64 ()))
245
-
246
- for output_name , alias in zip (output_names , remap_aliases ):
247
- yield alias , (output_name , aggfunc )
248
321
249
322
@property
250
323
def _final_renamer (self ) -> dict [str , NonStrHashable ]:
@@ -254,7 +327,7 @@ def _final_renamer(self) -> dict[str, NonStrHashable]:
254
327
return dict (zip (temps , originals ))
255
328
256
329
def _select_results (
257
- self , df : pd .DataFrame , / , new_names : list [ str ]
330
+ self , df : pd .DataFrame , / , agg_exprs : Sequence [ AggExpr ]
258
331
) -> PandasLikeDataFrame :
259
332
"""Responsible for remapping temp column names back to original.
260
333
@@ -263,25 +336,24 @@ def _select_results(
263
336
# NOTE: Keep `inplace=True` to avoid making a redundant copy.
264
337
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
265
338
df .reset_index (inplace = True ) # noqa: PD002
339
+ new_names = self ._aliases_str (chain .from_iterable (e .aliases for e in agg_exprs ))
266
340
return (
267
341
self .compliant ._with_native (df , validate_column_names = False )
268
- .simple_select (* new_names )
342
+ .simple_select (* self . _keys , * new_names )
269
343
.rename (self ._final_renamer )
270
- .with_columns (* self . _casts )
344
+ .with_columns (* chain . from_iterable ( e . _cast_coerced ( self ) for e in agg_exprs ) )
271
345
)
272
346
273
def _agg_complex(self, exprs: Iterable[PandasLikeExpr]) -> pd.DataFrame:
    """Aggregate via the native group-by `apply`, after emitting the complex-group-by warning.

    Returns the native result of `apply`; selecting and renaming columns is
    left to the caller.
    """
    warn_complex_group_by()
    compliant_impl = self.compliant._implementation
    version = self.compliant._backend_version
    grouped_func = self._apply_exprs(exprs)
    apply = self._grouped.apply
    if not (compliant_impl.is_pandas() and version >= (2, 2)):  # pragma: no cover
        return apply(grouped_func)
    # `include_groups` is only passed for pandas >= 2.2.
    return apply(grouped_func, include_groups=False)
285
357
286
358
def _apply_exprs (self , exprs : Iterable [PandasLikeExpr ]) -> NativeApply :
287
359
ns = self .compliant .__narwhals_namespace__ ()
0 commit comments