@@ -115,60 +115,6 @@ def generic_aggregate(
115
115
return result
116
116
117
117
118
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
    """Resolve a requested dtype into a concrete ``np.dtype``.

    Parameters
    ----------
    dtype : DTypeLike
        Requested dtype. ``None`` falls back to ``array_dtype``. The abstract
        class ``np.floating`` requests a floating-like result (see below).
        Anything else that is not already an ``np.dtype`` is passed through
        ``np.dtype(...)``.
    array_dtype : np.dtype
        dtype of the array being reduced.
    fill_value : optional
        When this is neither ``None`` nor one of the sentinels ``dtypes.INF``,
        ``dtypes.NINF``, ``dtypes.NA``, the result is widened with
        ``np.result_type`` so that it can also represent ``fill_value``.

    Returns
    -------
    np.dtype
    """
    resolved = array_dtype if dtype is None else dtype

    if resolved is np.floating:
        # mean, std, var always yield a floating result, but an array whose
        # kind is already float, complex, timedelta64 or datetime64
        # ("f", "c", "m", "M") keeps its own dtype.
        keeps_own_dtype = array_dtype.kind in "fcmM"
        resolved = array_dtype if keeps_own_dtype else np.dtype("float64")
    elif not isinstance(resolved, np.dtype):
        resolved = np.dtype(resolved)

    # Sentinels are replaced by dtype-appropriate values elsewhere, so only a
    # concrete fill value participates in promotion.
    sentinels = [None, dtypes.INF, dtypes.NINF, dtypes.NA]
    if fill_value not in sentinels:
        resolved = np.result_type(resolved, fill_value)
    return resolved
133
-
134
-
135
def _maybe_promote_int(dtype) -> np.dtype:
    """Widen small integer dtypes to at least the platform pointer-sized integer.

    Follows the rule described for ``numpy.prod``
    (https://numpy.org/doc/stable/reference/generated/numpy.prod.html): the
    dtype of the input is used unless it is an integer dtype of less precision
    than the default platform integer.

    Signed integers are promoted against ``np.intp`` and unsigned integers
    against ``np.uintp``; every other kind is returned unchanged (after being
    converted to an ``np.dtype`` instance if it was not one already).
    """
    resolved = dtype if isinstance(dtype, np.dtype) else np.dtype(dtype)

    kind = resolved.kind
    if kind == "i":
        return np.result_type(resolved, np.intp)
    if kind == "u":
        return np.result_type(resolved, np.uintp)
    return resolved
146
-
147
-
148
def _get_fill_value(dtype, fill_value):
    """Replace a sentinel fill value with a concrete value suited to ``dtype``.

    Resolution rules, checked in this order:

    * ``None`` or ``dtypes.NA`` with a string dtype (kind ``"U"`` or ``"S"``)
      gives the empty string.
    * ``dtypes.INF`` or ``None`` gives
      ``dtypes.get_pos_infinity(dtype, max_for_int=True)``.
    * ``dtypes.NINF`` gives
      ``dtypes.get_neg_infinity(dtype, min_for_int=True)``.
    * ``dtypes.NA`` gives ``np.nan`` for floating and complex dtypes, the
      ``get_neg_infinity`` value for integer, timedelta64 and datetime64
      dtypes, and ``None`` for anything else.
    * Any other value is returned unchanged.
    """
    # Checked first so that string dtypes never reach the infinity helpers.
    if fill_value in [None, dtypes.NA] and dtype.kind in "US":
        return ""

    if fill_value is None or fill_value == dtypes.INF:
        return dtypes.get_pos_infinity(dtype, max_for_int=True)

    if fill_value == dtypes.NINF:
        return dtypes.get_neg_infinity(dtype, min_for_int=True)

    if fill_value == dtypes.NA:
        nan_capable = (np.floating, np.complexfloating)
        if any(np.issubdtype(dtype, kind) for kind in nan_capable):
            return np.nan

        # This is madness, but npg checks that fill_value is compatible
        # with array dtype even if the fill_value is never used.
        needs_in_range_value = (np.integer, np.timedelta64, np.datetime64)
        if any(np.issubdtype(dtype, kind) for kind in needs_in_range_value):
            return dtypes.get_neg_infinity(dtype, min_for_int=True)

        return None

    # A concrete, user-supplied value: pass it through untouched.
    return fill_value
170
-
171
-
172
118
def _atleast_1d (inp , min_length : int = 1 ):
173
119
if xrutils .is_scalar (inp ):
174
120
inp = (inp ,) * min_length
@@ -435,9 +381,9 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
435
381
436
382
437
383
min_ = Aggregation ("min" , chunk = "min" , combine = "min" , fill_value = dtypes .INF )
438
- nanmin = Aggregation ("nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = np . nan )
384
+ nanmin = Aggregation ("nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = dtypes . NA )
439
385
max_ = Aggregation ("max" , chunk = "max" , combine = "max" , fill_value = dtypes .NINF )
440
- nanmax = Aggregation ("nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = np . nan )
386
+ nanmax = Aggregation ("nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = dtypes . NA )
441
387
442
388
443
389
def argreduce_preprocess (array , axis ):
@@ -634,7 +580,7 @@ def last(self) -> AlignedArrays:
634
580
# TODO: automate?
635
581
engine = "flox" ,
636
582
dtype = self .array .dtype ,
637
- fill_value = _get_fill_value (self .array .dtype , dtypes .NA ),
583
+ fill_value = dtypes . _get_fill_value (self .array .dtype , dtypes .NA ),
638
584
expected_groups = None ,
639
585
)
640
586
return AlignedArrays (array = reduced ["intermediates" ][0 ], group_idx = reduced ["groups" ])
@@ -729,15 +675,15 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
729
675
binary_op = None ,
730
676
reduction = "nanlast" ,
731
677
scan = "ffill" ,
732
- identity = np . nan ,
678
+ identity = dtypes . NA ,
733
679
mode = "concat_then_scan" ,
734
680
)
735
681
bfill = Scan (
736
682
"bfill" ,
737
683
binary_op = None ,
738
684
reduction = "nanlast" ,
739
685
scan = "ffill" ,
740
- identity = np . nan ,
686
+ identity = dtypes . NA ,
741
687
mode = "concat_then_scan" ,
742
688
preprocess = reverse ,
743
689
finalize = reverse ,
@@ -816,16 +762,27 @@ def _initialize_aggregation(
816
762
np .dtype (dtype ) if dtype is not None and not isinstance (dtype , np .dtype ) else dtype
817
763
)
818
764
819
- final_dtype = _normalize_dtype (dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value )
820
- if agg .name not in ["first" , "last" , "nanfirst" , "nanlast" , "min" , "max" , "nanmin" , "nanmax" ]:
821
- final_dtype = _maybe_promote_int (final_dtype )
765
+ final_dtype = dtypes ._normalize_dtype (
766
+ dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value
767
+ )
768
+ if agg .name not in [
769
+ "first" ,
770
+ "last" ,
771
+ "nanfirst" ,
772
+ "nanlast" ,
773
+ "min" ,
774
+ "max" ,
775
+ "nanmin" ,
776
+ "nanmax" ,
777
+ ]:
778
+ final_dtype = dtypes ._maybe_promote_int (final_dtype )
822
779
agg .dtype = {
823
780
"user" : dtype , # Save to automatically choose an engine
824
781
"final" : final_dtype ,
825
782
"numpy" : (final_dtype ,),
826
783
"intermediate" : tuple (
827
784
(
828
- _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
785
+ dtypes . _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
829
786
if int_dtype is None
830
787
else np .dtype (int_dtype )
831
788
)
@@ -838,10 +795,10 @@ def _initialize_aggregation(
838
795
# Replace sentinel fill values according to dtype
839
796
agg .fill_value ["user" ] = fill_value
840
797
agg .fill_value ["intermediate" ] = tuple (
841
- _get_fill_value (dt , fv )
798
+ dtypes . _get_fill_value (dt , fv )
842
799
for dt , fv in zip (agg .dtype ["intermediate" ], agg .fill_value ["intermediate" ])
843
800
)
844
- agg .fill_value [func ] = _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
801
+ agg .fill_value [func ] = dtypes . _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
845
802
846
803
fv = fill_value if fill_value is not None else agg .fill_value [agg .name ]
847
804
if _is_arg_reduction (agg ):
0 commit comments