Skip to content

Commit a37c559

Browse files
committed
_full
1 parent d148074 commit a37c559

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

numpy_groupies/aggregate_numpy.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from packaging.version import Version
23

34
from .utils import (
45
aggregate_common_doc,
@@ -13,13 +14,27 @@
1314
check_fill_value,
1415
input_validation,
1516
iscomplexobj,
17+
maxval,
1618
minimum_dtype,
1719
minimum_dtype_scalar,
1820
minval,
19-
maxval,
2021
)
2122

2223

24+
def _full(size, fill_value, *, dtype=None, like=None):
25+
"""Backcompat for numpy < 1.20.0 which does not support the `like` kwarg"""
26+
if (
27+
like is not None # numpy bug?
28+
and not np.isscalar(like) # scalars don't work
29+
and Version(np.__version__) >= Version("1.20.0")
30+
):
31+
kwargs = {"like": like}
32+
else:
33+
kwargs = {}
34+
35+
return np.full(size, fill_value=fill_value, dtype=dtype, **kwargs)
36+
37+
2338
def _sum(group_idx, a, size, fill_value, dtype=None):
2439
dtype = minimum_dtype_scalar(fill_value, dtype, a)
2540

@@ -44,7 +59,7 @@ def _sum(group_idx, a, size, fill_value, dtype=None):
4459

4560
def _prod(group_idx, a, size, fill_value, dtype=None):
4661
dtype = minimum_dtype_scalar(fill_value, dtype, a)
47-
ret = np.full(size, fill_value, dtype=dtype, like=a)
62+
ret = _full(size, fill_value, dtype=dtype, like=a)
4863
if fill_value != 1:
4964
ret[group_idx] = 1 # product starts from 1
5065
np.multiply.at(ret, group_idx, a)
@@ -57,7 +72,7 @@ def _len(group_idx, a, size, fill_value, dtype=None):
5772

5873
def _last(group_idx, a, size, fill_value, dtype=None):
5974
dtype = minimum_dtype(fill_value, dtype or a.dtype)
60-
ret = np.full(size, fill_value, dtype=dtype, like=a)
75+
ret = _full(size, fill_value, dtype=dtype, like=a)
6176
# repeated indexing gives last value, see:
6277
# the phrase "leaving behind the last value" on this page:
6378
# http://wiki.scipy.org/Tentative_NumPy_Tutorial
@@ -67,14 +82,14 @@ def _last(group_idx, a, size, fill_value, dtype=None):
6782

6883
def _first(group_idx, a, size, fill_value, dtype=None):
6984
dtype = minimum_dtype(fill_value, dtype or a.dtype)
70-
ret = np.full(size, fill_value, dtype=dtype, like=a)
85+
ret = _full(size, fill_value, dtype=dtype, like=a)
7186
ret[group_idx[::-1]] = a[::-1] # same trick as _last, but in reverse
7287
return ret
7388

7489

7590
def _all(group_idx, a, size, fill_value, dtype=None):
7691
check_boolean(fill_value)
77-
ret = np.full(size, fill_value, dtype=bool, like=a)
92+
ret = _full(size, fill_value, dtype=bool, like=a)
7893
if not fill_value:
7994
ret[group_idx] = True
8095
ret[group_idx.compress(np.logical_not(a))] = False
@@ -83,7 +98,7 @@ def _all(group_idx, a, size, fill_value, dtype=None):
8398

8499
def _any(group_idx, a, size, fill_value, dtype=None):
85100
check_boolean(fill_value)
86-
ret = np.full(size, fill_value, dtype=bool, like=a)
101+
ret = _full(size, fill_value, dtype=bool, like=a)
87102
if fill_value:
88103
ret[group_idx] = False
89104
ret[group_idx.compress(a)] = True
@@ -93,7 +108,7 @@ def _any(group_idx, a, size, fill_value, dtype=None):
93108
def _min(group_idx, a, size, fill_value, dtype=None):
94109
dtype = minimum_dtype(fill_value, dtype or a.dtype)
95110
dmax = maxval(fill_value, dtype)
96-
ret = np.full(size, fill_value, dtype=dtype, like=a)
111+
ret = _full(size, fill_value, dtype=dtype, like=a)
97112
if fill_value != dmax:
98113
ret[group_idx] = dmax # min starts from maximum
99114
np.minimum.at(ret, group_idx, a)
@@ -103,7 +118,7 @@ def _min(group_idx, a, size, fill_value, dtype=None):
103118
def _max(group_idx, a, size, fill_value, dtype=None):
104119
dtype = minimum_dtype(fill_value, dtype or a.dtype)
105120
dmin = minval(fill_value, dtype)
106-
ret = np.full(size, fill_value, dtype=dtype, like=a)
121+
ret = _full(size, fill_value, dtype=dtype, like=a)
107122
if fill_value != dmin:
108123
ret[group_idx] = dmin # max starts from minimum
109124
np.maximum.at(ret, group_idx, a)
@@ -115,7 +130,7 @@ def _argmax(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
115130
group_max = _max(group_idx, a_, size, np.nan)
116131
# nan should never be maximum, so use a and not a_
117132
is_max = a == group_max[group_idx]
118-
ret = np.full(size, fill_value, dtype=dtype, like=a)
133+
ret = _full(size, fill_value, dtype=dtype, like=a)
119134
group_idx_max = group_idx[is_max]
120135
(argmax,) = is_max.nonzero()
121136
ret[group_idx_max[::-1]] = argmax[
@@ -129,7 +144,7 @@ def _argmin(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
129144
group_min = _min(group_idx, a_, size, np.nan)
130145
# nan should never be minimum, so use a and not a_
131146
is_min = a == group_min[group_idx]
132-
ret = np.full(size, fill_value, dtype=dtype, like=a)
147+
ret = _full(size, fill_value, dtype=dtype, like=a)
133148
group_idx_min = group_idx[is_min]
134149
(argmin,) = is_min.nonzero()
135150
ret[group_idx_min[::-1]] = argmin[
@@ -148,7 +163,9 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
148163
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
149164
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
150165
else:
151-
sums = np.bincount(group_idx, weights=a, minlength=size).astype(dtype, copy=False)
166+
sums = np.bincount(group_idx, weights=a, minlength=size).astype(
167+
dtype, copy=False
168+
)
152169

153170
with np.errstate(divide="ignore", invalid="ignore"):
154171
ret = sums.astype(dtype, copy=False) / counts
@@ -223,7 +240,7 @@ def _generic_callable(
223240
"""groups a by inds, and then applies foo to each group in turn, placing
224241
the results in an array."""
225242
groups = _array(group_idx, a, size, ())
226-
ret = np.full(size, fill_value, dtype=dtype or np.float64)
243+
ret = _full(size, fill_value, dtype=dtype or np.float64)
227244

228245
for i, grp in enumerate(groups):
229246
if np.ndim(grp) == 1 and len(grp) > 0:

0 commit comments

Comments
 (0)