46
46
import dpctl .utils as dpu
47
47
import numpy
48
48
from dpctl .tensor ._copy_utils import _nonzero_impl
49
+ from dpctl .tensor ._indexing_functions import _get_indexing_mode
49
50
from dpctl .tensor ._numpy_helper import normalize_axis_index
50
51
51
52
import dpnp
@@ -160,14 +161,13 @@ def choose(x1, choices, out=None, mode="raise"):
160
161
return call_origin (numpy .choose , x1 , choices , out , mode )
161
162
162
163
163
- def _take_1d_index (x , inds , axis , q , usm_type , out = None ):
164
+ def _take_index (x , inds , axis , q , usm_type , out = None , mode = 0 ):
164
165
# arg validation assumed done by caller
165
166
x_sh = x .shape
166
- ind0 = inds [0 ]
167
167
axis_end = axis + 1
168
- if 0 in x_sh [axis :axis_end ] and ind0 .size != 0 :
168
+ if 0 in x_sh [axis :axis_end ] and inds .size != 0 :
169
169
raise IndexError ("cannot take non-empty indices from an empty axis" )
170
- res_sh = x_sh [:axis ] + ind0 .shape + x_sh [axis_end :]
170
+ res_sh = x_sh [:axis ] + inds .shape + x_sh [axis_end :]
171
171
172
172
orig_out = None
173
173
if out is not None :
@@ -201,13 +201,12 @@ def _take_1d_index(x, inds, axis, q, usm_type, out=None):
201
201
_manager = dpu .SequentialOrderManager [q ]
202
202
dep_evs = _manager .submitted_events
203
203
204
- # always use wrap mode here
205
204
h_ev , take_ev = ti ._take (
206
205
src = x ,
207
- ind = inds ,
206
+ ind = ( inds ,) ,
208
207
dst = out ,
209
208
axis_start = axis ,
210
- mode = 0 ,
209
+ mode = mode ,
211
210
sycl_queue = q ,
212
211
depends = dep_evs ,
213
212
)
@@ -318,7 +317,8 @@ def compress(condition, a, axis=None, out=None):
318
317
inds = _nonzero_impl (cond_ary )
319
318
320
319
return dpnp .get_result_array (
321
- _take_1d_index (a_ary , inds , axis , exec_q , res_usm_type , out ), out = out
320
+ _take_index (a_ary , inds [0 ], axis , exec_q , res_usm_type , out = out ),
321
+ out = out ,
322
322
)
323
323
324
324
@@ -1902,8 +1902,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
1902
1902
1903
1903
"""
1904
1904
1905
- if mode not in ( "wrap" , "clip" ):
1906
- raise ValueError ( f"` mode` must be 'wrap' or 'clip', but got ` { mode } `." )
1905
+ # sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1906
+ mode = _get_indexing_mode ( mode )
1907
1907
1908
1908
usm_a = dpnp .get_usm_ndarray (a )
1909
1909
if not dpnp .is_supported_array_type (indices ):
@@ -1913,34 +1913,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
1913
1913
else :
1914
1914
usm_ind = dpnp .get_usm_ndarray (indices )
1915
1915
1916
+ res_usm_type , exec_q = get_usm_allocations ([usm_a , usm_ind ])
1917
+
1916
1918
a_ndim = a .ndim
1917
1919
if axis is None :
1918
- res_shape = usm_ind .shape
1919
-
1920
1920
if a_ndim > 1 :
1921
- # dpt.take requires flattened input array
1921
+ # flatten input array
1922
1922
usm_a = dpt .reshape (usm_a , - 1 )
1923
+ axis = 0
1923
1924
elif a_ndim == 0 :
1924
1925
axis = normalize_axis_index (operator .index (axis ), 1 )
1925
- res_shape = usm_ind .shape
1926
1926
else :
1927
1927
axis = normalize_axis_index (operator .index (axis ), a_ndim )
1928
- a_sh = a .shape
1929
- res_shape = a_sh [:axis ] + usm_ind .shape + a_sh [axis + 1 :]
1930
-
1931
- if usm_ind .ndim != 1 :
1932
- # dpt.take supports only 1-D array of indices
1933
- usm_ind = dpt .reshape (usm_ind , - 1 )
1934
1928
1935
1929
if not dpnp .issubdtype (usm_ind .dtype , dpnp .integer ):
1936
1930
# dpt.take supports only integer dtype for array of indices
1937
1931
usm_ind = dpt .astype (usm_ind , dpnp .intp , copy = False , casting = "safe" )
1938
1932
1939
- usm_res = dpt .take (usm_a , usm_ind , axis = axis , mode = mode )
1933
+ usm_res = _take_index (
1934
+ usm_a , usm_ind , axis , exec_q , res_usm_type , out = out , mode = mode
1935
+ )
1940
1936
1941
- # need to reshape the result if shape of indices array was changed
1942
- result = dpnp .reshape (usm_res , res_shape )
1943
- return dpnp .get_result_array (result , out )
1937
+ return dpnp .get_result_array (usm_res , out = out )
1944
1938
1945
1939
1946
1940
def take_along_axis (a , indices , axis , mode = "wrap" ):
0 commit comments