Skip to content

Commit cbdf35c

Browse files
committed
Applied all hook fixes
1 parent 5913a35 commit cbdf35c

File tree

10 files changed

+221
-66
lines changed

10 files changed

+221
-66
lines changed

conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66

77
def pytest_configure(config):
8-
config.addinivalue_line("markers", "deselect_if(func): function to deselect tests from parametrization")
8+
config.addinivalue_line(
9+
"markers", "deselect_if(func): function to deselect tests from parametrization"
10+
)
911

1012

1113
def pytest_collection_modifyitems(config, items):

numpy_groupies/aggregate_numba.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def __call__(
6868
dtype = check_dtype(dtype, self.func, a, len(group_idx))
6969
check_fill_value(fill_value, dtype, func=self.func)
7070
input_dtype = type(a) if np.isscalar(a) else a.dtype
71-
ret, counter, mean, outer = self._initialize(flat_size, fill_value, dtype, input_dtype, group_idx.size)
71+
ret, counter, mean, outer = self._initialize(
72+
flat_size, fill_value, dtype, input_dtype, group_idx.size
73+
)
7274
group_idx = np.ascontiguousarray(group_idx)
7375

7476
if not np.isscalar(a):
@@ -141,7 +143,9 @@ def inner(ri, val, ret, counter, mean, fill_value):
141143
def loop(group_idx, a, ret, counter, mean, outer, fill_value, ddof):
142144
# ddof needs to be present for being exchangeable with loop_2pass
143145
size = len(ret)
144-
rng = range(len(group_idx) - 1, -1, -1) if reverse else range(len(group_idx))
146+
rng = (
147+
range(len(group_idx) - 1, -1, -1) if reverse else range(len(group_idx))
148+
)
145149
for i in rng:
146150
ri = group_idx[i]
147151
if ri < 0:
@@ -242,14 +246,18 @@ def __call__(
242246
axis=None,
243247
ddof=0,
244248
):
245-
iv = input_validation(group_idx, a, size=size, order=order, axis=axis, check_bounds=False)
249+
iv = input_validation(
250+
group_idx, a, size=size, order=order, axis=axis, check_bounds=False
251+
)
246252
group_idx, a, flat_size, ndim_idx, size, _ = iv
247253

248254
# TODO: The typecheck should be done by the class itself, not by check_dtype
249255
dtype = check_dtype(dtype, self.func, a, len(group_idx))
250256
check_fill_value(fill_value, dtype, func=self.func)
251257
input_dtype = type(a) if np.isscalar(a) else a.dtype
252-
ret, _, _, _ = self._initialize(flat_size, fill_value, dtype, input_dtype, group_idx.size)
258+
ret, _, _, _ = self._initialize(
259+
flat_size, fill_value, dtype, input_dtype, group_idx.size
260+
)
253261
group_idx = np.ascontiguousarray(group_idx)
254262

255263
sortidx = np.argsort(group_idx, kind="mergesort")
@@ -493,7 +501,7 @@ class CumMin(AggregateNtoN, Min):
493501

494502

495503
def get_funcs():
496-
funcs = dict()
504+
funcs = {}
497505
for op in (
498506
Sum,
499507
Prod,
@@ -530,7 +538,16 @@ def get_funcs():
530538

531539

532540
def aggregate(
533-
group_idx, a, func="sum", size=None, fill_value=0, order="C", dtype=None, axis=None, cache=True, **kwargs
541+
group_idx,
542+
a,
543+
func="sum",
544+
size=None,
545+
fill_value=0,
546+
order="C",
547+
dtype=None,
548+
axis=None,
549+
cache=True,
550+
**kwargs,
534551
):
535552
func = get_func(func, aliasing, _impl_dict)
536553
if not isinstance(func, str):
@@ -541,7 +558,9 @@ def aggregate(
541558
if cache is True:
542559
cache = _default_cache
543560
aggregate_op = cache.setdefault(func, AggregateGeneric(func))
544-
return aggregate_op(group_idx, a, size, fill_value, order, dtype, axis, **kwargs)
561+
return aggregate_op(
562+
group_idx, a, size, fill_value, order, dtype, axis, **kwargs
563+
)
545564
else:
546565
func = _impl_dict[func]
547566
return func(group_idx, a, size, fill_value, order, dtype, axis, **kwargs)

numpy_groupies/aggregate_numpy.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def _sum(group_idx, a, size, fill_value, dtype=None):
3030
ret.real = np.bincount(group_idx, weights=a.real, minlength=size)
3131
ret.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
3232
else:
33-
ret = np.bincount(group_idx, weights=a, minlength=size).astype(dtype, copy=False)
33+
ret = np.bincount(group_idx, weights=a, minlength=size).astype(
34+
dtype, copy=False
35+
)
3436

3537
if fill_value != 0:
3638
_fill_untouched(group_idx, ret, fill_value)
@@ -117,7 +119,9 @@ def _argmax(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
117119
ret = np.full(size, fill_value, dtype=dtype)
118120
group_idx_max = group_idx[is_max]
119121
(argmax,) = is_max.nonzero()
120-
ret[group_idx_max[::-1]] = argmax[::-1] # reverse to ensure first value for each group wins
122+
ret[group_idx_max[::-1]] = argmax[
123+
::-1
124+
] # reverse to ensure first value for each group wins
121125
return ret
122126

123127

@@ -129,7 +133,9 @@ def _argmin(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
129133
ret = np.full(size, fill_value, dtype=dtype)
130134
group_idx_min = group_idx[is_min]
131135
(argmin,) = is_min.nonzero()
132-
ret[group_idx_min[::-1]] = argmin[::-1] # reverse to ensure first value for each group wins
136+
ret[group_idx_min[::-1]] = argmin[
137+
::-1
138+
] # reverse to ensure first value for each group wins
133139
return ret
134140

135141

@@ -143,7 +149,9 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
143149
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
144150
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
145151
else:
146-
sums = np.bincount(group_idx, weights=a, minlength=size).astype(dtype, copy=False)
152+
sums = np.bincount(group_idx, weights=a, minlength=size).astype(
153+
dtype, copy=False
154+
)
147155

148156
with np.errstate(divide="ignore", invalid="ignore"):
149157
ret = sums.astype(dtype, copy=False) / counts
@@ -160,15 +168,19 @@ def _sum_of_squres(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
160168
return ret
161169

162170

163-
def _var(group_idx, a, size, fill_value, dtype=np.dtype(np.float64), sqrt=False, ddof=0):
171+
def _var(
172+
group_idx, a, size, fill_value, dtype=np.dtype(np.float64), sqrt=False, ddof=0
173+
):
164174
if np.ndim(a) == 0:
165175
raise ValueError("cannot take variance with scalar a")
166176
counts = np.bincount(group_idx, minlength=size)
167177
sums = np.bincount(group_idx, weights=a, minlength=size)
168178
with np.errstate(divide="ignore", invalid="ignore"):
169179
means = sums.astype(dtype, copy=False) / counts
170180
counts = np.where(counts > ddof, counts - ddof, 0)
171-
ret = np.bincount(group_idx, (a - means[group_idx]) ** 2, minlength=size) / counts
181+
ret = (
182+
np.bincount(group_idx, (a - means[group_idx]) ** 2, minlength=size) / counts
183+
)
172184
if sqrt:
173185
ret = np.sqrt(ret) # this is now std not var
174186
if not np.isnan(fill_value):
@@ -208,7 +220,9 @@ def _array(group_idx, a, size, fill_value, dtype=None):
208220
return ret
209221

210222

211-
def _generic_callable(group_idx, a, size, fill_value, dtype=None, func=lambda g: g, **kwargs):
223+
def _generic_callable(
224+
group_idx, a, size, fill_value, dtype=None, func=lambda g: g, **kwargs
225+
):
212226
"""groups a by inds, and then applies foo to each group in turn, placing
213227
the results in an array."""
214228
groups = _array(group_idx, a, size, ())
@@ -244,7 +258,9 @@ def _cumsum(group_idx, a, size, fill_value=None, dtype=None):
244258

245259
def _nancumsum(group_idx, a, size, fill_value=None, dtype=None):
246260
a_nonans = np.where(np.isnan(a), 0, a)
247-
group_idx_nonans = np.where(np.isnan(group_idx), np.nanmax(group_idx) + 1, group_idx)
261+
group_idx_nonans = np.where(
262+
np.isnan(group_idx), np.nanmax(group_idx) + 1, group_idx
263+
)
248264
return _cumsum(group_idx_nonans, a_nonans, size, fill_value=fill_value, dtype=dtype)
249265

250266

@@ -271,7 +287,11 @@ def _nancumsum(group_idx, a, size, fill_value=None, dtype=None):
271287
sumofsquares=_sum_of_squres,
272288
generic=_generic_callable,
273289
)
274-
_impl_dict.update(("nan" + k, v) for k, v in list(_impl_dict.items()) if k not in funcs_no_separate_nan)
290+
_impl_dict.update(
291+
("nan" + k, v)
292+
for k, v in list(_impl_dict.items())
293+
if k not in funcs_no_separate_nan
294+
)
275295
_impl_dict["nancumsum"] = _nancumsum
276296

277297

@@ -321,7 +341,9 @@ def _aggregate_base(
321341
dtype = check_dtype(dtype, func, a, flat_size)
322342
check_fill_value(fill_value, dtype, func=func)
323343
func = _impl_dict[func]
324-
ret = func(group_idx, a, flat_size, fill_value=fill_value, dtype=dtype, **kwargs)
344+
ret = func(
345+
group_idx, a, flat_size, fill_value=fill_value, dtype=dtype, **kwargs
346+
)
325347

326348
# deal with ndimensional indexing
327349
if ndim_idx > 1:
@@ -335,7 +357,17 @@ def _aggregate_base(
335357
return ret
336358

337359

338-
def aggregate(group_idx, a, func="sum", size=None, fill_value=0, order="C", dtype=None, axis=None, **kwargs):
360+
def aggregate(
361+
group_idx,
362+
a,
363+
func="sum",
364+
size=None,
365+
fill_value=0,
366+
order="C",
367+
dtype=None,
368+
axis=None,
369+
**kwargs,
370+
):
339371
return _aggregate_base(
340372
group_idx,
341373
a,

numpy_groupies/aggregate_numpy_ufunc.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,17 @@ def _max(group_idx, a, size, fill_value, dtype=None):
9797
)
9898

9999

100-
def aggregate(group_idx, a, func="sum", size=None, fill_value=0, order="C", dtype=None, axis=None, **kwargs):
100+
def aggregate(
101+
group_idx,
102+
a,
103+
func="sum",
104+
size=None,
105+
fill_value=0,
106+
order="C",
107+
dtype=None,
108+
axis=None,
109+
**kwargs,
110+
):
101111
func = get_func(func, aliasing, _impl_dict)
102112
if not isinstance(func, str):
103113
raise NotImplementedError("No such ufunc available")

numpy_groupies/aggregate_pandas.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
def _wrapper(group_idx, a, size, fill_value, func="sum", dtype=None, ddof=0, **kwargs):
1717
funcname = func.__name__ if callable(func) else func
18-
kwargs = dict()
18+
kwargs = {}
1919
if funcname in ("var", "std"):
2020
kwargs["ddof"] = ddof
2121
df = pd.DataFrame({"group_idx": group_idx, "a": a})
@@ -37,7 +37,9 @@ def _wrapper(group_idx, a, size, fill_value, func="sum", dtype=None, ddof=0, **k
3737
_supported_funcs = "sum prod all any min max mean var std first last cumsum cumprod cummax cummin".split()
3838
_impl_dict = {fn: partial(_wrapper, func=fn) for fn in _supported_funcs}
3939
_impl_dict.update(
40-
("nan" + fn, partial(_wrapper, func=fn)) for fn in _supported_funcs if fn not in funcs_no_separate_nan
40+
("nan" + fn, partial(_wrapper, func=fn))
41+
for fn in _supported_funcs
42+
if fn not in funcs_no_separate_nan
4143
)
4244
_impl_dict.update(
4345
allnan=partial(_wrapper, func=allnan),
@@ -52,7 +54,17 @@ def _wrapper(group_idx, a, size, fill_value, func="sum", dtype=None, ddof=0, **k
5254
)
5355

5456

55-
def aggregate(group_idx, a, func="sum", size=None, fill_value=0, order="C", dtype=None, axis=None, **kwargs):
57+
def aggregate(
58+
group_idx,
59+
a,
60+
func="sum",
61+
size=None,
62+
fill_value=0,
63+
order="C",
64+
dtype=None,
65+
axis=None,
66+
**kwargs,
67+
):
5668
return _aggregate_base(
5769
group_idx,
5870
a,

numpy_groupies/aggregate_purepy.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def _sort(group_idx, a, reverse=False):
6767
def _argsort(unordered):
6868
return sorted(range(len(unordered)), key=lambda k: unordered[k])
6969

70-
sortidx = _argsort(list((gi, aj) for gi, aj in zip(group_idx, -a if reverse else a)))
70+
sortidx = _argsort(
71+
list((gi, aj) for gi, aj in zip(group_idx, -a if reverse else a))
72+
)
7173
revidx = _argsort(_argsort(group_idx))
7274
a_srt = [a[si] for si in sortidx]
7375
return [a_srt[ri] for ri in revidx]
@@ -93,10 +95,24 @@ def _argsort(unordered):
9395
argmin=_argmin,
9496
len=len,
9597
)
96-
_impl_dict.update(("nan" + k, v) for k, v in list(_impl_dict.items()) if k not in funcs_no_separate_nan)
98+
_impl_dict.update(
99+
("nan" + k, v)
100+
for k, v in list(_impl_dict.items())
101+
if k not in funcs_no_separate_nan
102+
)
97103

98104

99-
def aggregate(group_idx, a, func="sum", size=None, fill_value=0, order=None, dtype=None, axis=None, **kwargs):
105+
def aggregate(
106+
group_idx,
107+
a,
108+
func="sum",
109+
size=None,
110+
fill_value=0,
111+
order=None,
112+
dtype=None,
113+
axis=None,
114+
**kwargs,
115+
):
100116
if axis is not None:
101117
raise NotImplementedError("axis arg not supported in purepy implementation.")
102118

@@ -105,29 +121,37 @@ def aggregate(group_idx, a, func="sum", size=None, fill_value=0, order=None, dty
105121
try:
106122
size = 1 + int(max(group_idx))
107123
except (TypeError, ValueError):
108-
raise NotImplementedError("pure python implementation doesn't accept ndim idx input.")
124+
raise NotImplementedError(
125+
"pure python implementation doesn't accept ndim idx input."
126+
)
109127

110128
for i in group_idx:
111129
try:
112130
i = int(i)
113131
except (TypeError, ValueError):
114132
if isinstance(i, (list, tuple)):
115-
raise NotImplementedError("pure python implementation doesn't accept ndim idx input.")
133+
raise NotImplementedError(
134+
"pure python implementation doesn't accept ndim idx input."
135+
)
116136
else:
117137
try:
118138
len(i)
119139
except TypeError:
120140
raise ValueError(f"invalid value found in group_idx: {i}")
121141
else:
122-
raise NotImplementedError("pure python implementation doesn't accept ndim indexed input.")
142+
raise NotImplementedError(
143+
"pure python implementation doesn't accept ndim indexed input."
144+
)
123145
else:
124146
if i < 0:
125147
raise ValueError("group_idx contains negative value")
126148

127149
func = get_func(func, aliasing, _impl_dict)
128150
if isinstance(a, (int, float)):
129151
if func not in ("sum", "prod", "len"):
130-
raise ValueError("scalar inputs are supported only for 'sum', 'prod' and 'len'")
152+
raise ValueError(
153+
"scalar inputs are supported only for 'sum', 'prod' and 'len'"
154+
)
131155
a = [a] * len(group_idx)
132156
elif len(group_idx) != len(a):
133157
raise ValueError("group_idx and a must be of the same length")
@@ -136,7 +160,9 @@ def aggregate(group_idx, a, func="sum", size=None, fill_value=0, order=None, dty
136160
if func.startswith("nan"):
137161
func = func[3:]
138162
# remove nans
139-
group_idx, a = zip(*((ix, val) for ix, val in zip(group_idx, a) if not math.isnan(val)))
163+
group_idx, a = zip(
164+
*((ix, val) for ix, val in zip(group_idx, a) if not math.isnan(val))
165+
)
140166

141167
func = _impl_dict[func]
142168
if func is _sort:

0 commit comments

Comments
 (0)