Skip to content

Commit 762d477

Browse files
authored
Add implementation of dpnp.argwhere (#2000)
* Add implementation of dpnp.argwhere() * Added new tests and updated existing ones * Applied pre-commit hooks * Fix broken link in description
1 parent 256ce60 commit 762d477

File tree

5 files changed

+123
-4
lines changed

5 files changed

+123
-4
lines changed

dpnp/dpnp_iface_searching.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from .dpnp_array import dpnp_array
4646
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
4747

48-
__all__ = ["argmax", "argmin", "searchsorted", "where"]
48+
__all__ = ["argmax", "argmin", "argwhere", "searchsorted", "where"]
4949

5050

5151
def _get_search_res_dt(a, _dtype, out):
@@ -244,6 +244,62 @@ def argmin(a, axis=None, out=None, *, keepdims=False):
244244
)
245245

246246

247+
def argwhere(a):
248+
"""
249+
Find the indices of array elements that are non-zero, grouped by element.
250+
251+
For full documentation refer to :obj:`numpy.argwhere`.
252+
253+
Parameters
254+
----------
255+
a : {dpnp.ndarray, usm_ndarray}
256+
Input array.
257+
258+
Returns
259+
-------
260+
out : dpnp.ndarray
261+
Indices of elements that are non-zero. Indices are grouped by element.
262+
This array will have shape ``(N, a.ndim)`` where ``N`` is the number of
263+
non-zero items.
264+
265+
See Also
266+
--------
267+
:obj:`dpnp.where` : Returns elements chosen from input arrays depending on
268+
a condition.
269+
:obj:`dpnp.nonzero` : Return the indices of the elements that are non-zero.
270+
271+
Notes
272+
-----
273+
``dpnp.argwhere(a)`` is almost the same as
274+
``dpnp.transpose(dpnp.nonzero(a))``, but produces a result of the correct
275+
shape for a 0D array.
276+
The output of :obj:`dpnp.argwhere` is not suitable for indexing arrays.
277+
For this purpose use :obj:`dpnp.nonzero` instead.
278+
279+
Examples
280+
--------
281+
>>> import dpnp as np
282+
>>> x = np.arange(6).reshape(2, 3)
283+
>>> x
284+
array([[0, 1, 2],
285+
[3, 4, 5]])
286+
>>> np.argwhere(x > 1)
287+
array([[0, 2],
288+
[1, 0],
289+
[1, 1],
290+
[1, 2]])
291+
292+
"""
293+
294+
dpnp.check_supported_arrays_type(a)
295+
if a.ndim == 0:
296+
# nonzero does not behave well on 0d, so promote to 1d
297+
a = dpnp.atleast_1d(a)
298+
# and then remove the added dimension
299+
return dpnp.argwhere(a)[:, :0]
300+
return dpnp.stack(dpnp.nonzero(a)).T
301+
302+
247303
def searchsorted(a, v, side="left", sorter=None):
248304
"""
249305
Find indices where elements should be inserted to maintain order.

tests/test_search.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import dpctl.tensor as dpt
22
import numpy
33
import pytest
4-
from numpy.testing import assert_allclose, assert_array_equal, assert_raises
4+
from numpy.testing import (
5+
assert_allclose,
6+
assert_array_equal,
7+
assert_equal,
8+
assert_raises,
9+
)
510

611
import dpnp
712

@@ -99,6 +104,64 @@ def test_ndarray(self, axis, keepdims):
99104
assert_dtype_allclose(dpnp_res, np_res)
100105

101106

107+
class TestArgwhere:
108+
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))
109+
def test_basic(self, dt):
110+
a = numpy.array([4, 0, 2, 1, 3], dtype=dt)
111+
ia = dpnp.array(a)
112+
113+
result = dpnp.argwhere(ia)
114+
expected = numpy.argwhere(a)
115+
assert_equal(result, expected)
116+
117+
@pytest.mark.parametrize("ndim", [0, 1, 2])
118+
def test_ndim(self, ndim):
119+
# get an nd array with multiple elements in every dimension
120+
a = numpy.empty((2,) * ndim)
121+
122+
# none
123+
a[...] = False
124+
ia = dpnp.array(a)
125+
126+
result = dpnp.argwhere(ia)
127+
expected = numpy.argwhere(a)
128+
assert_equal(result, expected)
129+
130+
# only one
131+
a[...] = False
132+
a.flat[0] = True
133+
ia = dpnp.array(a)
134+
135+
result = dpnp.argwhere(ia)
136+
expected = numpy.argwhere(a)
137+
assert_equal(result, expected)
138+
139+
# all but one
140+
a[...] = True
141+
a.flat[0] = False
142+
ia = dpnp.array(a)
143+
144+
result = dpnp.argwhere(ia)
145+
expected = numpy.argwhere(a)
146+
assert_equal(result, expected)
147+
148+
# all
149+
a[...] = True
150+
ia = dpnp.array(a)
151+
152+
result = dpnp.argwhere(ia)
153+
expected = numpy.argwhere(a)
154+
assert_equal(result, expected)
155+
156+
def test_2d(self):
157+
a = numpy.arange(6).reshape((2, 3))
158+
ia = dpnp.array(a)
159+
160+
result = dpnp.argwhere(ia > 1)
161+
expected = numpy.argwhere(a > 1)
162+
assert_array_equal(result, expected)
163+
164+
102165
class TestWhere:
103166
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
104167
def test_basic(self, dtype):

tests/test_sycl_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def test_meshgrid(device):
437437
pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]),
438438
pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]),
439439
pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]),
440+
pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]),
440441
pytest.param("cbrt", [1.0, 8.0, 27.0]),
441442
pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
442443
pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]),

tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ def test_norm(usm_type, ord, axis):
557557
pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]),
558558
pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]),
559559
pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]),
560+
pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]),
560561
pytest.param("cbrt", [1, 8, 27]),
561562
pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
562563
pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]),

tests/third_party/cupy/sorting_tests/test_search.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,6 @@ def test_flatnonzero(self, xp, dtype):
399399
{"array": numpy.empty((0, 2, 0))},
400400
_ids=False, # Do not generate ids from randomly generated params
401401
)
402-
@pytest.mark.skip("argwhere isn't implemented yet")
403402
class TestArgwhere:
404403
@testing.for_all_dtypes()
405404
@testing.numpy_cupy_array_equal()
@@ -412,7 +411,6 @@ def test_argwhere(self, xp, dtype):
412411
{"value": 0},
413412
{"value": 3},
414413
)
415-
@pytest.mark.skip("argwhere isn't implemented yet")
416414
@testing.with_requires("numpy>=1.18")
417415
class TestArgwhereZeroDimension:
418416
@testing.for_all_dtypes()

0 commit comments

Comments
 (0)