23
23
)
24
24
from pandas .core .base import DataError
25
25
from typing import Type , Callable
26
+ from collections .abc import Iterable , Container
26
27
import warnings
27
28
28
29
37
38
ReductionFunction ,
38
39
BinaryFunction ,
39
40
GroupbyReduceFunction ,
41
+ groupby_reduce_functions ,
40
42
)
41
43
42
44
@@ -2443,33 +2445,57 @@ def _callable_func(self, func, axis, *args, **kwargs):
2443
2445
# nature. They require certain data to exist on the same partition, and
2444
2446
# after the shuffle, there should be only a local map required.
2445
2447
2446
- groupby_count = GroupbyReduceFunction .register (
2447
- lambda df , ** kwargs : df .count (** kwargs ), lambda df , ** kwargs : df .sum (** kwargs )
2448
- )
2449
- groupby_any = GroupbyReduceFunction .register (
2450
- lambda df , ** kwargs : df .any (** kwargs ), lambda df , ** kwargs : df .any (** kwargs )
2451
- )
2452
- groupby_min = GroupbyReduceFunction .register (
2453
- lambda df , ** kwargs : df .min (** kwargs ), lambda df , ** kwargs : df .min (** kwargs )
2454
- )
2455
- groupby_prod = GroupbyReduceFunction .register (
2456
- lambda df , ** kwargs : df .prod (** kwargs ), lambda df , ** kwargs : df .prod (** kwargs )
2457
- )
2458
- groupby_max = GroupbyReduceFunction .register (
2459
- lambda df , ** kwargs : df .max (** kwargs ), lambda df , ** kwargs : df .max (** kwargs )
2460
- )
2461
- groupby_all = GroupbyReduceFunction .register (
2462
- lambda df , ** kwargs : df .all (** kwargs ), lambda df , ** kwargs : df .all (** kwargs )
2463
- )
2464
- groupby_sum = GroupbyReduceFunction .register (
2465
- lambda df , ** kwargs : df .sum (** kwargs ), lambda df , ** kwargs : df .sum (** kwargs )
2466
- )
2448
+ groupby_count = GroupbyReduceFunction .register (* groupby_reduce_functions ["count" ])
2449
+ groupby_any = GroupbyReduceFunction .register (* groupby_reduce_functions ["any" ])
2450
+ groupby_min = GroupbyReduceFunction .register (* groupby_reduce_functions ["min" ])
2451
+ groupby_prod = GroupbyReduceFunction .register (* groupby_reduce_functions ["prod" ])
2452
+ groupby_max = GroupbyReduceFunction .register (* groupby_reduce_functions ["max" ])
2453
+ groupby_all = GroupbyReduceFunction .register (* groupby_reduce_functions ["all" ])
2454
+ groupby_sum = GroupbyReduceFunction .register (* groupby_reduce_functions ["sum" ])
2467
2455
groupby_size = GroupbyReduceFunction .register (
2468
- lambda df , ** kwargs : pandas .DataFrame (df .size ()),
2469
- lambda df , ** kwargs : df .sum (),
2470
- method = "size" ,
2456
+ * groupby_reduce_functions ["size" ], method = "size"
2471
2457
)
2472
2458
2459
+ def _groupby_dict_reduce (
2460
+ self , by , axis , agg_func , agg_args , agg_kwargs , groupby_kwargs , drop = False
2461
+ ):
2462
+ map_dict = {}
2463
+ reduce_dict = {}
2464
+ rename_columns = any (
2465
+ not isinstance (fn , str ) and isinstance (fn , Iterable )
2466
+ for fn in agg_func .values ()
2467
+ )
2468
+ for col , col_funcs in agg_func .items ():
2469
+ if not rename_columns :
2470
+ map_dict [col ], reduce_dict [col ] = groupby_reduce_functions [col_funcs ]
2471
+ continue
2472
+
2473
+ if isinstance (col_funcs , str ):
2474
+ col_funcs = [col_funcs ]
2475
+
2476
+ map_fns = []
2477
+ for i , fn in enumerate (col_funcs ):
2478
+ if not isinstance (fn , str ) and isinstance (fn , Iterable ):
2479
+ new_col_name , func = fn
2480
+ elif isinstance (fn , str ):
2481
+ new_col_name , func = fn , fn
2482
+ else :
2483
+ raise TypeError
2484
+
2485
+ map_fns .append ((new_col_name , groupby_reduce_functions [func ][0 ]))
2486
+ reduce_dict [(col , new_col_name )] = groupby_reduce_functions [func ][1 ]
2487
+ map_dict [col ] = map_fns
2488
+ return GroupbyReduceFunction .register (map_dict , reduce_dict )(
2489
+ query_compiler = self ,
2490
+ by = by ,
2491
+ axis = axis ,
2492
+ groupby_args = groupby_kwargs ,
2493
+ map_args = agg_kwargs ,
2494
+ reduce_args = agg_kwargs ,
2495
+ numeric_only = False ,
2496
+ drop = drop ,
2497
+ )
2498
+
2473
2499
def groupby_agg (
2474
2500
self ,
2475
2501
by ,
@@ -2481,6 +2507,31 @@ def groupby_agg(
2481
2507
groupby_kwargs ,
2482
2508
drop = False ,
2483
2509
):
2510
+ def is_reduce_fn (fn , deep_level = 0 ):
2511
+ if not isinstance (fn , str ) and isinstance (fn , Container ):
2512
+ # `deep_level` parameter specifies the number of nested containers that was met:
2513
+ # - if it's 0, then we're outside of container, `fn` could be either function name
2514
+ # or container of function names/renamers.
2515
+ # - if it's 1, then we're inside container of function names/renamers. `fn` must be
2516
+ # either function name or renamer (renamer is some container which length == 2,
2517
+ # the first element is the new column name and the second is the function name).
2518
+ assert deep_level == 0 or (
2519
+ deep_level > 0 and len (fn ) == 2
2520
+ ), f"Got the renamer with incorrect length, expected 2 got { len (fn )} ."
2521
+ return (
2522
+ all (is_reduce_fn (f , deep_level + 1 ) for f in fn )
2523
+ if deep_level == 0
2524
+ else is_reduce_fn (fn [1 ], deep_level + 1 )
2525
+ )
2526
+ return isinstance (fn , str ) and fn in groupby_reduce_functions
2527
+
2528
+ if isinstance (agg_func , dict ) and all (
2529
+ is_reduce_fn (x ) for x in agg_func .values ()
2530
+ ):
2531
+ return self ._groupby_dict_reduce (
2532
+ by , axis , agg_func , agg_args , agg_kwargs , groupby_kwargs , drop
2533
+ )
2534
+
2484
2535
if callable (agg_func ):
2485
2536
agg_func = wrap_udf_function (agg_func )
2486
2537
0 commit comments