Skip to content

Commit 197bde5

Browse files
committed
Improve fill_value handling
1 parent 0911e9c commit 197bde5

File tree

7 files changed

+71
-42
lines changed

7 files changed

+71
-42
lines changed

numpy_groupies/aggregate_numba.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __call__(self, group_idx, a, size=None, fill_value=0, order='C',
4343

4444
# TODO: The typecheck should be done by the class itself, not by check_dtype
4545
dtype = check_dtype(dtype, self.func, a, len(group_idx))
46-
check_fill_value(fill_value, dtype)
46+
check_fill_value(fill_value, dtype, func=self.func)
4747
input_dtype = type(a) if np.isscalar(a) else a.dtype
4848
ret, counter, mean, outer = self._initialize(flat_size, fill_value, dtype, input_dtype, group_idx.size)
4949
group_idx = np.ascontiguousarray(group_idx)
@@ -85,7 +85,10 @@ def _initialize(cls, flat_size, fill_value, dtype, input_dtype, input_size):
8585
@classmethod
8686
def _finalize(cls, ret, counter, fill_value):
8787
if cls.forced_fill_value is not None and fill_value != cls.forced_fill_value:
88-
ret[counter] = fill_value
88+
if cls.counter_dtype == bool:
89+
ret[counter] = fill_value
90+
else:
91+
ret[~counter.astype(bool)] = fill_value
8992

9093
@classmethod
9194
def callable(cls, nans=False, reverse=False, scalar=False):
@@ -192,7 +195,7 @@ def __call__(self, group_idx, a, size=None, fill_value=0, order='C',
192195

193196
# TODO: The typecheck should be done by the class itself, not by check_dtype
194197
dtype = check_dtype(dtype, self.func, a, len(group_idx))
195-
check_fill_value(fill_value, dtype)
198+
check_fill_value(fill_value, dtype, func=self.func)
196199
input_dtype = type(a) if np.isscalar(a) else a.dtype
197200
ret, _, _, _= self._initialize(flat_size, fill_value, dtype, input_dtype, group_idx.size)
198201
group_idx = np.ascontiguousarray(group_idx)
@@ -354,6 +357,7 @@ def _inner(ri, val, ret, counter, mean):
354357

355358

356359
class Mean(Aggregate2pass):
360+
forced_fill_value = 0
357361
counter_fill_value = 0
358362
counter_dtype = int
359363

numpy_groupies/aggregate_numpy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from .utils import check_boolean, funcs_no_separate_nan, get_func, aggregate_common_doc, isstr
44
from .utils_numpy import (aliasing, minimum_dtype, input_validation,
5-
check_dtype, minimum_dtype_scalar)
5+
check_dtype, check_fill_value, minimum_dtype_scalar)
66

77

88
def _sum(group_idx, a, size, fill_value, dtype=None):
@@ -271,6 +271,7 @@ def _aggregate_base(group_idx, a, func='sum', size=None, fill_value=0,
271271
group_idx = group_idx[good]
272272

273273
dtype = check_dtype(dtype, func, a, flat_size)
274+
check_fill_value(fill_value, dtype, func=func)
274275
func = _impl_dict[func]
275276
ret = func(group_idx, a, flat_size, fill_value=fill_value, dtype=dtype,
276277
**kwargs)

numpy_groupies/aggregate_weave.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
except ImportError:
55
from scipy.weave import inline
66

7-
from .utils import get_func, isstr, funcs_no_separate_nan, aggregate_common_doc
7+
from .utils import get_func, isstr, funcs_no_separate_nan, aggregate_common_doc, check_boolean
88
from .utils_numpy import check_dtype, aliasing, check_fill_value, input_validation
99

1010

@@ -175,7 +175,7 @@ def get_cfuncs():
175175
return c_funcs
176176

177177

178-
c_funcs = get_cfuncs()
178+
c_funcs.update(get_cfuncs())
179179

180180

181181
c_step_count = c_size('group_idx') + r"""
@@ -221,6 +221,9 @@ def step_indices(group_idx):
221221
return indices
222222

223223

224+
_force_fill_0 = frozenset({'sum', 'any', 'len', 'anynan', 'mean', 'std', 'var', 'nansum', 'nanlen', 'nanmean', 'nanstd', 'nanvar'})
225+
_force_fill_1 = frozenset({'prod', 'all', 'allnan', 'nanprod'})
226+
224227
def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C',
225228
dtype=None, axis=None, **kwargs):
226229
func = get_func(func, aliasing, optimized_funcs)
@@ -233,15 +236,15 @@ def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C',
233236
order=order,
234237
axis=axis)
235238
dtype = check_dtype(dtype, func, a, len(group_idx))
236-
check_fill_value(fill_value, dtype)
239+
check_fill_value(fill_value, dtype, func=func)
237240
nans = func.startswith('nan')
238241

239242
if nans:
240243
flat_size += 1
241244

242-
if func in ('sum', 'any', 'len', 'anynan', 'nansum', 'nanlen'):
245+
if func in _force_fill_0:
243246
ret = np.zeros(flat_size, dtype=dtype)
244-
elif func in ('prod', 'all', 'allnan', 'nanprod'):
247+
elif func in _force_fill_1:
245248
ret = np.ones(flat_size, dtype=dtype)
246249
else:
247250
ret = np.full(flat_size, fill_value, dtype=dtype)
@@ -250,14 +253,14 @@ def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C',
250253
inline_vars = dict(group_idx=np.ascontiguousarray(group_idx), a=np.ascontiguousarray(a),
251254
ret=ret, fill_value=fill_value)
252255
# TODO: Have this fixed by proper raveling
253-
if func in ('std', 'var', 'nanstd', 'nanvar'):
256+
if func in {'std', 'var', 'nanstd', 'nanvar'}:
254257
counter = np.zeros_like(ret, dtype=int)
255258
inline_vars['means'] = np.zeros_like(ret)
256259
inline_vars['ddof'] = kwargs.pop('ddof', 0)
257-
elif func in ('mean', 'nanmean'):
260+
elif func in {'mean', 'nanmean'}:
258261
counter = np.zeros_like(ret, dtype=int)
259262
else:
260-
# Using inverse logic, marking anyting touched with zero for later removal
263+
# Using inverse logic, marking anything touched with zero for later removal
261264
counter = np.ones_like(ret, dtype=bool)
262265
inline_vars['counter'] = counter
263266

@@ -267,10 +270,11 @@ def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C',
267270
inline(c_funcs[func], inline_vars.keys(), local_dict=inline_vars, define_macros=c_macros, extra_compile_args=c_args)
268271

269272
# Postprocessing
270-
if func in ('sum', 'any', 'anynan', 'nansum') and fill_value != 0:
271-
ret[counter] = fill_value
272-
elif func in ('prod', 'all', 'allnan', 'nanprod') and fill_value != 1:
273-
ret[counter] = fill_value
273+
if func in _force_fill_0 and fill_value != 0 or func in _force_fill_1 and fill_value != 1:
274+
if counter.dtype == np.bool_:
275+
ret[counter] = fill_value
276+
else:
277+
ret[~counter.astype(bool)] = fill_value
274278

275279
if nans:
276280
# Restore the shifted return array

numpy_groupies/benchmarks/generic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ def arbitrary(iterator):
3636
np.nanmean, np.nanvar, np.nanstd, 'nanfirst', 'nanlast',
3737
'cumsum', 'cumprod', 'cummax', 'cummin', arbitrary, 'sort')
3838

39-
40-
def benchmark(implementations, size=5e5, repeat=5, seed=100):
39+
def benchmark_data(size=5e5, seed=100):
4140
rnd = np.random.RandomState(seed=seed)
4241
group_idx = rnd.randint(0, int(1e3), int(size))
4342
a = rnd.random_sample(group_idx.size)
@@ -46,6 +45,11 @@ def benchmark(implementations, size=5e5, repeat=5, seed=100):
4645
nana[(nana < 0.2) & (nana != 0)] = np.nan
4746
nan_share = np.mean(np.isnan(nana))
4847
assert 0.15 < nan_share < 0.25, "%3f%% nans" % (nan_share * 100)
48+
return a, nana, group_idx
49+
50+
51+
def benchmark(implementations, repeat=5, size=5e5, seed=100):
52+
a, nana, group_idx = benchmark_data(size=size, seed=seed)
4953

5054
print("function" + ''.join(impl.__name__.rsplit('_', 1)[1].rjust(14) for impl in implementations))
5155
print("-" * (9 + 14 * len(implementations)))

numpy_groupies/tests/test_compare.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
may throw NotImplementedError in order to show missing functionality without throwing
55
test errors.
66
"""
7-
import itertools
7+
from itertools import product
88
import numpy as np
99
import pytest
1010

@@ -18,7 +18,8 @@ class AttrDict(dict):
1818

1919
@pytest.fixture(params=['np/py', 'weave/np', 'ufunc/np', 'numba/np', 'pandas/np'], scope='module')
2020
def aggregate_cmp(request, seed=100):
21-
if request.param == 'np/py':
21+
test_pair = request.param
22+
if test_pair == 'np/py':
2223
# Some functions in purepy are not implemented
2324
func_ref = _wrap_notimplemented_xfail(aggregate_purepy.aggregate)
2425
func = aggregate_numpy.aggregate
@@ -72,22 +73,34 @@ def func_preserve_order(iterator):
7273
return tmp
7374

7475

75-
func_list = ('sum', 'prod', 'min', 'max', 'all', 'any', 'mean', 'std', 'len',
76-
'argmin', 'argmax', 'anynan', 'allnan', 'cumsum',
77-
'nansum', 'nanprod', 'nanmin', 'nanmax', 'nanmean', 'nanstd', 'nanlen',
78-
func_arbitrary, func_preserve_order)
79-
80-
@pytest.mark.parametrize("func", func_list, ids=lambda x: getattr(x, '__name__', x))
81-
def test_cmp(aggregate_cmp, func, decimal=10):
82-
a = aggregate_cmp.nana if 'nan' in getattr(func, '__name__', func) else aggregate_cmp.a
83-
res = aggregate_cmp.func(aggregate_cmp.group_idx, a, func=func)
84-
ref = aggregate_cmp.func_ref(aggregate_cmp.group_idx, a, func=func)
85-
if isinstance(ref, np.ndarray):
86-
assert res.dtype == ref.dtype
87-
np.testing.assert_allclose(res, ref, rtol=10**-decimal)
76+
func_list = ('sum', 'prod', 'min', 'max', 'all', 'any', 'mean', 'std', 'var', 'len',
77+
'argmin', 'argmax', 'anynan', 'allnan', 'cumsum', func_arbitrary, func_preserve_order,
78+
'nansum', 'nanprod', 'nanmin', 'nanmax', 'nanmean', 'nanstd', 'nanvar','nanlen')
8879

8980

90-
@pytest.mark.parametrize(["ndim", "order"], itertools.product([2, 3], ["C", "F"]))
81+
@pytest.mark.parametrize(["func", "fill_value"], product(func_list, [0, 1, np.nan]),
82+
ids=lambda x: getattr(x, '__name__', x))
83+
def test_cmp(aggregate_cmp, func, fill_value, decimal=10):
84+
a = aggregate_cmp.nana if 'nan' in getattr(func, '__name__', func) else aggregate_cmp.a
85+
try:
86+
ref = aggregate_cmp.func_ref(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
87+
except ValueError:
88+
with pytest.raises(ValueError):
89+
aggregate_cmp.func(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
90+
else:
91+
try:
92+
res = aggregate_cmp.func(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
93+
except ValueError:
94+
if np.isnan(fill_value) and aggregate_cmp.test_pair.endswith('py'):
95+
pytest.skip("pure python version uses lists and does not raise ValueErrors when inserting nan into integers")
96+
else:
97+
raise
98+
if isinstance(ref, np.ndarray):
99+
assert res.dtype == ref.dtype
100+
np.testing.assert_allclose(res, ref, rtol=10**-decimal)
101+
102+
103+
@pytest.mark.parametrize(["ndim", "order"], product([2, 3], ["C", "F"]))
91104
def test_cmp_ndim(aggregate_cmp, ndim, order, outsize=100, decimal=14):
92105
nindices = int(outsize ** ndim)
93106
outshape = tuple([outsize] * ndim)

numpy_groupies/tests/test_generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def test_nan_input(aggregate_all, func, groups=100):
206206

207207
def test_nan_input_len(aggregate_all, groups=100, group_size=5):
208208
if aggregate_all.__name__.endswith('pandas'):
209-
pytest.skip("pandas automatically skip nan values")
209+
pytest.skip("pandas always skips nan values")
210210
group_idx = np.arange(0, groups, dtype=int).repeat(group_size)
211211
a = np.random.random(len(group_idx))
212212
a[::2] = np.nan

numpy_groupies/utils_numpy.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Common helper functions for typing and general numpy tools."""
22
import numpy as np
33

4-
from .utils import get_aliasing
4+
from .utils import get_aliasing, check_boolean
55

66
_alias_numpy = {
77
np.add: 'sum',
@@ -168,12 +168,15 @@ def check_dtype(dtype, func_str, a, n):
168168
return a_dtype
169169

170170

171-
def check_fill_value(fill_value, dtype):
172-
try:
173-
return dtype.type(fill_value)
174-
except ValueError:
175-
raise ValueError("fill_value must be convertible into %s"
176-
% dtype.type.__name__)
171+
def check_fill_value(fill_value, dtype, func=None):
172+
if func in ('all', 'any', 'allnan', 'anynan'):
173+
check_boolean(fill_value)
174+
else:
175+
try:
176+
return dtype.type(fill_value)
177+
except ValueError:
178+
raise ValueError("fill_value must be convertible into %s"
179+
% dtype.type.__name__)
177180

178181

179182
def check_group_idx(group_idx, a=None, check_min=True):

0 commit comments

Comments
 (0)