1
1
import numpy as np
2
+ from packaging .version import Version
2
3
3
4
from .utils import (
4
5
aggregate_common_doc ,
13
14
check_fill_value ,
14
15
input_validation ,
15
16
iscomplexobj ,
17
+ maxval ,
16
18
minimum_dtype ,
17
19
minimum_dtype_scalar ,
18
20
minval ,
19
- maxval ,
20
21
)
21
22
22
23
24
+ def _full (size , fill_value , * , dtype = None , like = None ):
25
+ """Backcompat for numpy < 1.20.0 which does not support the `like` kwarg"""
26
+ if (
27
+ like is not None # numpy bug?
28
+ and not np .isscalar (like ) # scalars don't work
29
+ and Version (np .__version__ ) >= Version ("1.20.0" )
30
+ ):
31
+ kwargs = {"like" : like }
32
+ else :
33
+ kwargs = {}
34
+
35
+ return np .full (size , fill_value = fill_value , dtype = dtype , ** kwargs )
36
+
37
+
23
38
def _sum (group_idx , a , size , fill_value , dtype = None ):
24
39
dtype = minimum_dtype_scalar (fill_value , dtype , a )
25
40
@@ -44,7 +59,7 @@ def _sum(group_idx, a, size, fill_value, dtype=None):
44
59
45
60
def _prod (group_idx , a , size , fill_value , dtype = None ):
46
61
dtype = minimum_dtype_scalar (fill_value , dtype , a )
47
- ret = np . full (size , fill_value , dtype = dtype , like = a )
62
+ ret = _full (size , fill_value , dtype = dtype , like = a )
48
63
if fill_value != 1 :
49
64
ret [group_idx ] = 1 # product starts from 1
50
65
np .multiply .at (ret , group_idx , a )
@@ -57,7 +72,7 @@ def _len(group_idx, a, size, fill_value, dtype=None):
57
72
58
73
def _last (group_idx , a , size , fill_value , dtype = None ):
59
74
dtype = minimum_dtype (fill_value , dtype or a .dtype )
60
- ret = np . full (size , fill_value , dtype = dtype , like = a )
75
+ ret = _full (size , fill_value , dtype = dtype , like = a )
61
76
# repeated indexing gives last value, see:
62
77
# the phrase "leaving behind the last value" on this page:
63
78
# http://wiki.scipy.org/Tentative_NumPy_Tutorial
@@ -67,14 +82,14 @@ def _last(group_idx, a, size, fill_value, dtype=None):
67
82
68
83
def _first(group_idx, a, size, fill_value, dtype=None):
    """Pick, for every group, the first value of ``a`` that belongs to it.

    The output is created by ``_full`` with ``size`` and ``fill_value``; its
    dtype comes from ``minimum_dtype(fill_value, dtype or a.dtype)``.
    """
    out_dtype = minimum_dtype(fill_value, dtype or a.dtype)
    result = _full(size, fill_value, dtype=out_dtype, like=a)
    # Same trick as _last, but in reverse: both arrays are walked backwards,
    # so the earliest entry of each group is the last one assigned.
    reversed_idx = group_idx[::-1]
    reversed_values = a[::-1]
    result[reversed_idx] = reversed_values
    return result
73
88
74
89
75
90
def _all (group_idx , a , size , fill_value , dtype = None ):
76
91
check_boolean (fill_value )
77
- ret = np . full (size , fill_value , dtype = bool , like = a )
92
+ ret = _full (size , fill_value , dtype = bool , like = a )
78
93
if not fill_value :
79
94
ret [group_idx ] = True
80
95
ret [group_idx .compress (np .logical_not (a ))] = False
@@ -83,7 +98,7 @@ def _all(group_idx, a, size, fill_value, dtype=None):
83
98
84
99
def _any (group_idx , a , size , fill_value , dtype = None ):
85
100
check_boolean (fill_value )
86
- ret = np . full (size , fill_value , dtype = bool , like = a )
101
+ ret = _full (size , fill_value , dtype = bool , like = a )
87
102
if fill_value :
88
103
ret [group_idx ] = False
89
104
ret [group_idx .compress (a )] = True
@@ -93,7 +108,7 @@ def _any(group_idx, a, size, fill_value, dtype=None):
93
108
def _min (group_idx , a , size , fill_value , dtype = None ):
94
109
dtype = minimum_dtype (fill_value , dtype or a .dtype )
95
110
dmax = maxval (fill_value , dtype )
96
- ret = np . full (size , fill_value , dtype = dtype , like = a )
111
+ ret = _full (size , fill_value , dtype = dtype , like = a )
97
112
if fill_value != dmax :
98
113
ret [group_idx ] = dmax # min starts from maximum
99
114
np .minimum .at (ret , group_idx , a )
@@ -103,7 +118,7 @@ def _min(group_idx, a, size, fill_value, dtype=None):
103
118
def _max (group_idx , a , size , fill_value , dtype = None ):
104
119
dtype = minimum_dtype (fill_value , dtype or a .dtype )
105
120
dmin = minval (fill_value , dtype )
106
- ret = np . full (size , fill_value , dtype = dtype , like = a )
121
+ ret = _full (size , fill_value , dtype = dtype , like = a )
107
122
if fill_value != dmin :
108
123
ret [group_idx ] = dmin # max starts from minimum
109
124
np .maximum .at (ret , group_idx , a )
@@ -115,7 +130,7 @@ def _argmax(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
115
130
group_max = _max (group_idx , a_ , size , np .nan )
116
131
# nan should never be maximum, so use a and not a_
117
132
is_max = a == group_max [group_idx ]
118
- ret = np . full (size , fill_value , dtype = dtype , like = a )
133
+ ret = _full (size , fill_value , dtype = dtype , like = a )
119
134
group_idx_max = group_idx [is_max ]
120
135
(argmax ,) = is_max .nonzero ()
121
136
ret [group_idx_max [::- 1 ]] = argmax [
@@ -129,7 +144,7 @@ def _argmin(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
129
144
group_min = _min (group_idx , a_ , size , np .nan )
130
145
# nan should never be minimum, so use a and not a_
131
146
is_min = a == group_min [group_idx ]
132
- ret = np . full (size , fill_value , dtype = dtype , like = a )
147
+ ret = _full (size , fill_value , dtype = dtype , like = a )
133
148
group_idx_min = group_idx [is_min ]
134
149
(argmin ,) = is_min .nonzero ()
135
150
ret [group_idx_min [::- 1 ]] = argmin [
@@ -148,7 +163,9 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
148
163
sums .real = np .bincount (group_idx , weights = a .real , minlength = size )
149
164
sums .imag = np .bincount (group_idx , weights = a .imag , minlength = size )
150
165
else :
151
- sums = np .bincount (group_idx , weights = a , minlength = size ).astype (dtype , copy = False )
166
+ sums = np .bincount (group_idx , weights = a , minlength = size ).astype (
167
+ dtype , copy = False
168
+ )
152
169
153
170
with np .errstate (divide = "ignore" , invalid = "ignore" ):
154
171
ret = sums .astype (dtype , copy = False ) / counts
@@ -223,7 +240,7 @@ def _generic_callable(
223
240
"""groups a by inds, and then applies foo to each group in turn, placing
224
241
the results in an array."""
225
242
groups = _array (group_idx , a , size , ())
226
- ret = np . full (size , fill_value , dtype = dtype or np .float64 )
243
+ ret = _full (size , fill_value , dtype = dtype or np .float64 )
227
244
228
245
for i , grp in enumerate (groups ):
229
246
if np .ndim (grp ) == 1 and len (grp ) > 0 :
0 commit comments