Skip to content

Commit c119870

Browse files
authored
Fix custom callable reductions. (#59)
1 parent a3ea616 commit c119870

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

numpy_groupies/tests/test_generic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,24 @@ def test_agg_along_axis(aggregate_all, size, func, axis):
353353
# instead we squeeze out the extra dims in actual.
354354
np.testing.assert_allclose(actual.squeeze(), expected)
355355

356+
357+
def test_custom_callable(aggregate_all):
358+
def sum_(array):
359+
return array.sum()
360+
361+
size = (10,)
362+
axis = -1
363+
364+
group_idx = np.zeros(size, dtype=int)
365+
array = np.random.randn(*size)
366+
367+
expected = array.sum(axis=axis, keepdims=True)
368+
actual = aggregate_all(group_idx, array, axis=axis, func=sum_, fill_value=0)
369+
assert actual.ndim == array.ndim
370+
371+
np.testing.assert_allclose(actual, expected)
372+
373+
356374
def test_argreduction_nD_array_1D_idx(aggregate_all):
357375
# regression test for GH41
358376
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0], dtype=int)

numpy_groupies/utils_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def input_validation(group_idx, a, size=None, order='C', axis=None,
283283
else:
284284
is_form_3 = group_idx.ndim == 1 and a.ndim > 1 and axis is not None
285285
orig_shape = a.shape if is_form_3 else group_idx.shape
286-
if "arg" in func:
286+
if isinstance(func, str) and "arg" in func:
287287
unravel_shape = orig_shape
288288
else:
289289
unravel_shape = None

0 commit comments

Comments
 (0)