Skip to content

Commit 7ca0d09

Browse files
committed
Remove branching when condition is an array
Also tweaks to docstring
1 parent 84c9533 commit 7ca0d09

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,10 @@ def _take_1d_index(x, inds, axis, q, usm_type, out=None):
226226

227227
def compress(condition, a, axis=None, out=None):
228228
"""
229-
A copy of `a` without the slices along `axis` for which `condition` is
230-
``False``.
229+
Return selected slices of an array along given axis.
230+
231+
A slice of `a` is returned for each index along `axis` where `condition`
232+
is ``True``.
231233
232234
For full documentation refer to :obj:`numpy.choose`.
233235
@@ -298,15 +300,13 @@ def compress(condition, a, axis=None, out=None):
298300
axis = normalize_axis_index(operator.index(axis), a.ndim)
299301

300302
a_ary = dpnp.get_usm_ndarray(a)
301-
if not dpnp.is_supported_array_type(condition):
302-
cond_ary = dpnp.as_usm_ndarray(
303-
condition,
304-
dtype=dpnp.bool,
305-
usm_type=a_ary.usm_type,
306-
sycl_queue=a_ary.sycl_queue,
307-
)
308-
else:
309-
cond_ary = dpnp.get_usm_ndarray(condition)
303+
cond_ary = dpnp.as_usm_ndarray(
304+
condition,
305+
dtype=dpnp.bool,
306+
usm_type=a_ary.usm_type,
307+
sycl_queue=a_ary.sycl_queue,
308+
)
309+
310310
if not cond_ary.ndim == 1:
311311
raise ValueError(
312312
"`condition` must be a 1-D array or un-nested sequence"

0 commit comments

Comments
 (0)