Skip to content

Commit f8886cd

Browse files
committed
Re-use _take_index for dpnp.take
Should slightly improve efficiency by escaping an additional copy where `out` is not `None` and flattening of indices
1 parent cfc664e commit f8886cd

File tree

1 file changed

+18
-24
lines changed

1 file changed

+18
-24
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import dpctl.utils as dpu
4747
import numpy
4848
from dpctl.tensor._copy_utils import _nonzero_impl
49+
from dpctl.tensor._indexing_functions import _get_indexing_mode
4950
from dpctl.tensor._numpy_helper import normalize_axis_index
5051

5152
import dpnp
@@ -160,14 +161,13 @@ def choose(x1, choices, out=None, mode="raise"):
160161
return call_origin(numpy.choose, x1, choices, out, mode)
161162

162163

163-
def _take_1d_index(x, inds, axis, q, usm_type, out=None):
164+
def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
164165
# arg validation assumed done by caller
165166
x_sh = x.shape
166-
ind0 = inds[0]
167167
axis_end = axis + 1
168-
if 0 in x_sh[axis:axis_end] and ind0.size != 0:
168+
if 0 in x_sh[axis:axis_end] and inds.size != 0:
169169
raise IndexError("cannot take non-empty indices from an empty axis")
170-
res_sh = x_sh[:axis] + ind0.shape + x_sh[axis_end:]
170+
res_sh = x_sh[:axis] + inds.shape + x_sh[axis_end:]
171171

172172
orig_out = None
173173
if out is not None:
@@ -201,13 +201,12 @@ def _take_1d_index(x, inds, axis, q, usm_type, out=None):
201201
_manager = dpu.SequentialOrderManager[q]
202202
dep_evs = _manager.submitted_events
203203

204-
# always use wrap mode here
205204
h_ev, take_ev = ti._take(
206205
src=x,
207-
ind=inds,
206+
ind=(inds,),
208207
dst=out,
209208
axis_start=axis,
210-
mode=0,
209+
mode=mode,
211210
sycl_queue=q,
212211
depends=dep_evs,
213212
)
@@ -318,7 +317,8 @@ def compress(condition, a, axis=None, out=None):
318317
inds = _nonzero_impl(cond_ary)
319318

320319
return dpnp.get_result_array(
321-
_take_1d_index(a_ary, inds, axis, exec_q, res_usm_type, out), out=out
320+
_take_index(a_ary, inds[0], axis, exec_q, res_usm_type, out=out),
321+
out=out,
322322
)
323323

324324

@@ -1902,8 +1902,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
19021902
19031903
"""
19041904

1905-
if mode not in ("wrap", "clip"):
1906-
raise ValueError(f"`mode` must be 'wrap' or 'clip', but got `{mode}`.")
1905+
# sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1906+
mode = _get_indexing_mode(mode)
19071907

19081908
usm_a = dpnp.get_usm_ndarray(a)
19091909
if not dpnp.is_supported_array_type(indices):
@@ -1913,34 +1913,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
19131913
else:
19141914
usm_ind = dpnp.get_usm_ndarray(indices)
19151915

1916+
res_usm_type, exec_q = get_usm_allocations([usm_a, usm_ind])
1917+
19161918
a_ndim = a.ndim
19171919
if axis is None:
1918-
res_shape = usm_ind.shape
1919-
19201920
if a_ndim > 1:
1921-
# dpt.take requires flattened input array
1921+
# flatten input array
19221922
usm_a = dpt.reshape(usm_a, -1)
1923+
axis = 0
19231924
elif a_ndim == 0:
19241925
axis = normalize_axis_index(operator.index(axis), 1)
1925-
res_shape = usm_ind.shape
19261926
else:
19271927
axis = normalize_axis_index(operator.index(axis), a_ndim)
1928-
a_sh = a.shape
1929-
res_shape = a_sh[:axis] + usm_ind.shape + a_sh[axis + 1 :]
1930-
1931-
if usm_ind.ndim != 1:
1932-
# dpt.take supports only 1-D array of indices
1933-
usm_ind = dpt.reshape(usm_ind, -1)
19341928

19351929
if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer):
19361930
# dpt.take supports only integer dtype for array of indices
19371931
usm_ind = dpt.astype(usm_ind, dpnp.intp, copy=False, casting="safe")
19381932

1939-
usm_res = dpt.take(usm_a, usm_ind, axis=axis, mode=mode)
1933+
usm_res = _take_index(
1934+
usm_a, usm_ind, axis, exec_q, res_usm_type, out=out, mode=mode
1935+
)
19401936

1941-
# need to reshape the result if shape of indices array was changed
1942-
result = dpnp.reshape(usm_res, res_shape)
1943-
return dpnp.get_result_array(result, out)
1937+
return dpnp.get_result_array(usm_res, out=out)
19441938

19451939

19461940
def take_along_axis(a, indices, axis, mode="wrap"):

0 commit comments

Comments
 (0)