Skip to content

Commit 6a689ae

Browse files
authored
Merge pull request #46 from dcherian/faster-input-validation
Faster input validation
2 parents 7a31d1f + d6f1bf1 commit 6a689ae

File tree

1 file changed

+47
-18
lines changed

1 file changed

+47
-18
lines changed

numpy_groupies/utils_numpy.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,53 @@ def check_group_idx(group_idx, a=None, check_min=True):
189189
raise ValueError("group_idx contains negative indices")
190190

191191

192+
def _ravel_group_idx(group_idx, a, axis, size, order, method="ravel"):
193+
ndim_a = a.ndim
194+
# Create the broadcast-ready multidimensional indexing.
195+
# Note the user could do this themselves, so this is
196+
# very much just a convenience.
197+
size_in = int(np.max(group_idx)) + 1 if size is None else size
198+
group_idx_in = group_idx
199+
group_idx = []
200+
size = []
201+
for ii, s in enumerate(a.shape):
202+
if method == "ravel":
203+
ii_idx = group_idx_in if ii == axis else np.arange(s)
204+
ii_shape = [1] * ndim_a
205+
ii_shape[ii] = s
206+
group_idx.append(ii_idx.reshape(ii_shape))
207+
size.append(size_in if ii == axis else s)
208+
# Use the indexing, and return. It's a bit simpler than
209+
# using trying to keep all the logic below happy
210+
if method == "ravel":
211+
group_idx = np.ravel_multi_index(group_idx, size, order=order,
212+
mode='raise')
213+
elif method == "offset":
214+
group_idx = offset_labels(group_idx_in, a.shape, axis, order, size_in)
215+
return group_idx, size
216+
217+
def offset_labels(group_idx, inshape, axis, order, size):
218+
"""
219+
Offset group labels by dimension. This is used when we
220+
reduce over a subset of the dimensions of by. It assumes that the reductions
221+
dimensions have been flattened in the last dimension
222+
Copied from
223+
https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy
224+
"""
225+
if axis not in (-1, len(inshape) - 1):
226+
newshape = (s for idx, s in enumerate(inshape) if idx != axis) + (inshape[axis],)
227+
else:
228+
newshape = inshape
229+
group_idx = np.broadcast_to(group_idx, newshape)
230+
group_idx: np.ndarray = (
231+
group_idx
232+
+ np.arange(np.prod(group_idx.shape[:-1]), dtype=int).reshape((*group_idx.shape[:-1], -1))
233+
* size
234+
)
235+
return group_idx.reshape(inshape).ravel()
236+
192237
def input_validation(group_idx, a, size=None, order='C', axis=None,
193-
ravel_group_idx=True, check_bounds=True):
238+
ravel_group_idx=True, check_bounds=True, method="ravel"):
194239
""" Do some fairly extensive checking of group_idx and a, trying to
195240
give the user as much help as possible with what is wrong. Also,
196241
convert ndim-indexing to 1d indexing.
@@ -230,23 +275,7 @@ def input_validation(group_idx, a, size=None, order='C', axis=None,
230275
raise NotImplementedError("when using axis arg, size must be"
231276
"None or scalar.")
232277
else:
233-
# Create the broadcast-ready multidimensional indexing.
234-
# Note the user could do this themselves, so this is
235-
# very much just a convenience.
236-
size_in = int(np.max(group_idx)) + 1 if size is None else size
237-
group_idx_in = group_idx
238-
group_idx = []
239-
size = []
240-
for ii, s in enumerate(a.shape):
241-
ii_idx = group_idx_in if ii == axis else np.arange(s)
242-
ii_shape = [1] * ndim_a
243-
ii_shape[ii] = s
244-
group_idx.append(ii_idx.reshape(ii_shape))
245-
size.append(size_in if ii == axis else s)
246-
# Use the indexing, and return. It's a bit simpler than
247-
# using trying to keep all the logic below happy
248-
group_idx = np.ravel_multi_index(group_idx, size, order=order,
249-
mode='raise')
278+
group_idx, size = _ravel_group_idx(group_idx, a, axis, size, order, method=method)
250279
flat_size = np.prod(size)
251280
ndim_idx = ndim_a
252281
return group_idx.ravel(), a.ravel(), flat_size, ndim_idx, size

0 commit comments

Comments
 (0)