Skip to content

Commit 8f14bbf

Browse files
committed
Add CuPy support
1 parent 71f119f commit 8f14bbf

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

numpy_groupies/utils_numpy.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _ravel_group_idx(group_idx, a, axis, size, order, method="ravel"):
219219
size = []
220220
for ii, s in enumerate(a.shape):
221221
if method == "ravel":
222-
ii_idx = group_idx_in if ii == axis else np.arange(s)
222+
ii_idx = group_idx_in if ii == axis else np.arange(s, like=group_idx_in)
223223
ii_shape = [1] * ndim_a
224224
ii_shape[ii] = s
225225
group_idx.append(ii_idx.reshape(ii_shape))
@@ -249,10 +249,8 @@ def offset_labels(group_idx, inshape, axis, order, size):
249249
group_idx = np.moveaxis(group_idx, axis, -1)
250250
newshape = group_idx.shape[:-1] + (-1,)
251251

252-
group_idx = (group_idx +
253-
np.arange(np.prod(newshape[:-1]), dtype=int).reshape(newshape)
254-
* size
255-
)
252+
offset_ = np.arange(np.prod(newshape[:-1]), dtype=int, like=group_idx).reshape(newshape)
253+
group_idx = group_idx + offset_ * size
256254
if axis not in (-1, len(inshape) - 1):
257255
return np.moveaxis(group_idx, -1, axis)
258256
else:

0 commit comments

Comments
 (0)