Skip to content

Commit 04503c3

Browse files
committed
Further tests cleanup; fix xfail wrapper
1 parent 11c3d88 commit 04503c3

File tree

4 files changed

+70
-40
lines changed

4 files changed

+70
-40
lines changed

numpy_groupies/aggregate_pandas.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def _wrapper(group_idx, a, size, fill_value, func='sum', dtype=None, ddof=0, **k
3939
nanlen=partial(_wrapper, func='count'),
4040
argmax=partial(_wrapper, func='idxmax'),
4141
argmin=partial(_wrapper, func='idxmin'),
42+
nanargmax=partial(_wrapper, func='idxmax'),
43+
nanargmin=partial(_wrapper, func='idxmin'),
4244
generic=_wrapper)
4345

4446

numpy_groupies/tests/__init__.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,33 +26,37 @@ def _impl_name(impl):
2626

2727

2828
_not_implemented_by_impl_name = {
29-
'numpy': ['cumprod','cummax', 'cummin'],
30-
'purepy': ['cumprod','cummax', 'cummin'],
31-
'numba': ('array', 'list'),
32-
'pandas': ('array', 'list'),
33-
'weave': ('argmin', 'argmax', 'array', 'list', 'cumsum',
34-
'<lambda>', 'func_preserve_order', 'func_arbitrary')}
29+
'numpy': ('cumprod','cummax', 'cummin'),
30+
'purepy': ('cumsum', 'cumprod','cummax', 'cummin', 'sumofsquares'),
31+
'numba': ('array', 'list', 'sort'),
32+
'pandas': ('array', 'list', 'sort', 'sumofsquares', 'nansumofsquares'),
33+
'weave': ('argmin', 'argmax', 'array', 'list', 'sort', 'cumsum', 'cummax', 'cummin',
34+
'nanargmin', 'nanargmax', 'sumofsquares', 'nansumofsquares',
35+
'<lambda>', 'custom_callable'),
36+
'ufunc': 'NO_CHECK'}
37+
3538

3639
def _wrap_notimplemented_xfail(impl, name=None):
3740
"""Some implementations lack some functionality. That's ok, let's xfail that instead of raising errors."""
3841

39-
def _try_xfail(*args, **kwargs):
42+
def try_xfail(*args, **kwargs):
4043
try:
4144
return impl(*args, **kwargs)
4245
except NotImplementedError as e:
46+
impl_name = impl.__module__.split('_')[-1]
4347
func = kwargs.pop('func', None)
4448
if callable(func):
4549
func = func.__name__
46-
wrap_funcs = _not_implemented_by_impl_name.get(func, None)
47-
if wrap_funcs is None or func in wrap_funcs:
50+
not_implemented_ok = _not_implemented_by_impl_name.get(impl_name, [])
51+
if not_implemented_ok == 'NO_CHECK' or func in not_implemented_ok:
4852
pytest.xfail("Functionality not implemented")
4953
else:
5054
raise e
5155
if name:
52-
_try_xfail.__name__ = name
56+
try_xfail.__name__ = name
5357
else:
54-
_try_xfail.__name__ = impl.__name__
55-
return _try_xfail
58+
try_xfail.__name__ = impl.__name__
59+
return try_xfail
5660

5761

5862
func_list = ('sum', 'prod', 'min', 'max', 'all', 'any', 'mean', 'std', 'var', 'len',

numpy_groupies/tests/test_compare.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ def aggregate_cmp(request, seed=100):
6767
return AttrDict(locals())
6868

6969

70+
def _deselect_purepy(aggregate_cmp, *args, **kwargs):
71+
# purepy implementation does not handle ndim arrays
72+
# This is a won't fix and should be deselected instead of skipped
73+
return aggregate_cmp.endswith('py')
74+
75+
76+
def _deselect_purepy_nanfuncs(aggregate_cmp, func, *args, **kwargs):
77+
# purepy implementation does not handle nan values correctly
78+
# This is a won't fix and should be deselected instead of skipped
79+
return 'nan' in getattr(func, '__name__', func) and aggregate_cmp.endswith('py')
80+
81+
7082
def func_arbitrary(iterator):
7183
tmp = 0
7284
for x in iterator:
@@ -81,14 +93,10 @@ def func_preserve_order(iterator):
8193
return tmp
8294

8395

84-
def _deselect_purepy_nanfuncs(aggregate_cmp, func, fill_value):
85-
# purepy implementation does not handle nan values correctly
86-
return 'nan' in getattr(func, '__name__', func) and aggregate_cmp.endswith('py')
87-
88-
96+
@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed")
8997
@pytest.mark.deselect_if(func=_deselect_purepy_nanfuncs)
90-
@pytest.mark.parametrize(["func", "fill_value"], product(func_list, [0, 1, np.nan]),
91-
ids=lambda x: getattr(x, '__name__', x))
98+
@pytest.mark.parametrize("fill_value", [0, 1, np.nan])
99+
@pytest.mark.parametrize("func", func_list, ids=lambda x: getattr(x, '__name__', x))
92100
def test_cmp(aggregate_cmp, func, fill_value, decimal=10):
93101
is_nanfunc = 'nan' in getattr(func, '__name__', func)
94102
a = aggregate_cmp.nana if is_nanfunc else aggregate_cmp.a
@@ -107,9 +115,16 @@ def test_cmp(aggregate_cmp, func, fill_value, decimal=10):
107115
raise
108116
if isinstance(ref, np.ndarray):
109117
assert res.dtype == ref.dtype
110-
np.testing.assert_allclose(res, ref, rtol=10**-decimal)
118+
try:
119+
np.testing.assert_allclose(res, ref, rtol=10**-decimal)
120+
except AssertionError:
121+
if 'arg' in func and aggregate_cmp.test_pair.startswith('pandas'):
122+
pytest.xfail("pandas doesn't fill indices for all-nan groups with fill_value, but with -inf instead")
123+
else:
124+
raise
111125

112126

127+
@pytest.mark.deselect_if(func=_deselect_purepy)
113128
@pytest.mark.parametrize(["ndim", "order"], product([2, 3], ["C", "F"]))
114129
def test_cmp_ndim(aggregate_cmp, ndim, order, outsize=100, decimal=14):
115130
nindices = int(outsize ** ndim)

numpy_groupies/tests/test_generic.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@ def aggregate_all(request):
1717
return _wrap_notimplemented_xfail(impl.aggregate, 'aggregate_' + name)
1818

1919

20+
def _deselect_purepy(aggregate_all, *args, **kwargs):
21+
# The purepy implementation does not handle nan values and ndim arrays correctly.
22+
# So it needs to be excluded from several tests.
23+
return aggregate_all.__name__.endswith('purepy')
24+
25+
26+
def _deselect_purepy_and_invalid_axis(aggregate_all, size, axis, *args, **kwargs):
27+
if axis >= len(size):
28+
return True
29+
if aggregate_all.__name__.endswith('purepy'):
30+
# purepy does not handle axis parameter
31+
return True
32+
33+
2034
def test_preserve_missing(aggregate_all):
2135
res = aggregate_all(np.array([0, 1, 3, 1, 3]), np.arange(101, 106, dtype=int))
2236
np.testing.assert_array_equal(res, np.array([101, 206, 0, 208]))
@@ -100,12 +114,14 @@ def test_array_ordering(aggregate_all, order, size=10):
100114
assert aggregate_all(np.zeros(size, dtype=int), mat[0, :], order=order)[0] == sum(range(size))
101115

102116

117+
@pytest.mark.deselect_if(func=_deselect_purepy)
103118
@pytest.mark.parametrize("size", [None, (10, 2)])
104119
def test_ndim_group_idx(aggregate_all, size):
105120
group_idx = np.vstack((np.repeat(np.arange(10), 10), np.repeat([0, 1], 50)))
106121
aggregate_all(group_idx, 1, size=size)
107122

108123

124+
@pytest.mark.deselect_if(func=_deselect_purepy)
109125
@pytest.mark.parametrize(["ndim", "order"], itertools.product([1, 2, 3], ["C", "F"]))
110126
def test_ndim_indexing(aggregate_all, ndim, order, outsize=10):
111127
nindices = int(outsize ** ndim)
@@ -242,9 +258,8 @@ def test_argmin_argmax_nonans(aggregate_all):
242258
np.testing.assert_array_equal(res, [3, -1, -1, 5])
243259

244260

261+
@pytest.mark.deselect_if(func=_deselect_purepy)
245262
def test_argmin_argmax_nans(aggregate_all):
246-
if aggregate_all.__name__.endswith('purepy'):
247-
pytest.xfail("purepy doesn't handle nan values correctly")
248263
if aggregate_all.__name__.endswith('pandas'):
249264
pytest.xfail("pandas always ignores nans")
250265

@@ -258,9 +273,10 @@ def test_argmin_argmax_nans(aggregate_all):
258273
np.testing.assert_array_equal(res, [3, -1, -1, -1])
259274

260275

276+
@pytest.mark.deselect_if(func=_deselect_purepy)
261277
def test_nanargmin_nanargmax_nans(aggregate_all):
262-
if aggregate_all.__name__.endswith('purepy'):
263-
pytest.xfail("purepy doesn't handle nan values correctly")
278+
if aggregate_all.__name__.endswith('pandas'):
279+
pytest.xfail("pandas doesn't fill indices for all-nan groups with fill_value but with -inf instead")
264280

265281
group_idx = np.array([0, 0, 0, 0, 3, 3, 3, 3])
266282
a = np.array([4, 4, np.nan, 1, np.nan, np.nan, np.nan, np.nan])
@@ -339,12 +355,8 @@ def test_list_ordering(aggregate_all, order):
339355
a = a[::-1]
340356
ref = a[:4]
341357

342-
try:
343-
res = aggregate_all(group_idx, a, func=list)
344-
except NotImplementedError:
345-
pytest.xfail("Function not yet implemented")
346-
else:
347-
np.testing.assert_array_equal(np.array(res[0]), ref)
358+
res = aggregate_all(group_idx, a, func=list)
359+
np.testing.assert_array_equal(np.array(res[0]), ref)
348360

349361

350362
@pytest.mark.parametrize("order", ["normal", "reverse"])
@@ -360,14 +372,6 @@ def test_sort(aggregate_all, order):
360372
np.testing.assert_array_equal(res, ref)
361373

362374

363-
def _deselect_purepy_and_invalid_axis(aggregate_all, func, size, axis):
364-
if axis >= len(size):
365-
return True
366-
if aggregate_all.__name__.endswith('purepy'):
367-
# purepy does not handle axis parameter
368-
return True
369-
370-
371375
@pytest.mark.deselect_if(func=_deselect_purepy_and_invalid_axis)
372376
@pytest.mark.parametrize("axis", (0, 1))
373377
@pytest.mark.parametrize("size", ((12,), (12, 5)))
@@ -424,6 +428,7 @@ def test_along_axis(aggregate_all, func, size, axis):
424428
np.testing.assert_allclose(actual.squeeze(), expected)
425429

426430

431+
@pytest.mark.deselect_if(func=_deselect_purepy)
427432
def test_not_last_axis_reduction(aggregate_all):
428433
group_idx = np.array([1, 2, 2, 0, 1])
429434
a = np.array([
@@ -442,8 +447,9 @@ def test_not_last_axis_reduction(aggregate_all):
442447
np.testing.assert_allclose(expected, actual)
443448

444449

450+
@pytest.mark.deselect_if(func=_deselect_purepy)
445451
def test_custom_callable(aggregate_all):
446-
def sum_(x):
452+
def custom_callable(x):
447453
return x.sum()
448454

449455
size = (10,)
@@ -453,12 +459,13 @@ def sum_(x):
453459
a = np.random.randn(*size)
454460

455461
expected = a.sum(axis=axis, keepdims=True)
456-
actual = aggregate_all(group_idx, a, axis=axis, func=sum_, fill_value=0)
462+
actual = aggregate_all(group_idx, a, axis=axis, func=custom_callable, fill_value=0)
457463
assert actual.ndim == a.ndim
458464

459465
np.testing.assert_allclose(actual, expected)
460466

461467

468+
@pytest.mark.deselect_if(func=_deselect_purepy)
462469
def test_argreduction_nD_array_1D_idx(aggregate_all):
463470
# https://github.com/ml31415/numpy-groupies/issues/41
464471
group_idx = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0], dtype=int)
@@ -468,6 +475,7 @@ def test_argreduction_nD_array_1D_idx(aggregate_all):
468475
np.testing.assert_equal(actual, expected)
469476

470477

478+
@pytest.mark.deselect_if(func=_deselect_purepy)
471479
def test_argreduction_negative_fill_value(aggregate_all):
472480
if aggregate_all.__name__.endswith('pandas'):
473481
pytest.xfail("pandas always skips nan values")
@@ -479,6 +487,7 @@ def test_argreduction_negative_fill_value(aggregate_all):
479487
np.testing.assert_equal(actual, expected)
480488

481489

490+
@pytest.mark.deselect_if(func=_deselect_purepy)
482491
@pytest.mark.parametrize("nan_inds", (None, tuple([[1, 4, 5], Ellipsis]), tuple((1, (0, 1, 2, 3)))))
483492
@pytest.mark.parametrize("ddof", (0, 1))
484493
@pytest.mark.parametrize("func", ("nanvar", "nanstd"))

0 commit comments

Comments
 (0)