diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 455f862e614..4c4576dc996 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -45,7 +45,7 @@ from .dpnp_array import dpnp_array from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call -__all__ = ["argmax", "argmin", "searchsorted", "where"] +__all__ = ["argmax", "argmin", "argwhere", "searchsorted", "where"] def _get_search_res_dt(a, _dtype, out): @@ -244,6 +244,62 @@ def argmin(a, axis=None, out=None, *, keepdims=False): ) +def argwhere(a): + """ + Find the indices of array elements that are non-zero, grouped by element. + + For full documentation refer to :obj:`numpy.argwhere`. + + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array. + + Returns + ------- + out : dpnp.ndarray + Indices of elements that are non-zero. Indices are grouped by element. + This array will have shape ``(N, a.ndim)`` where ``N`` is the number of + non-zero items. + + See Also + -------- + :obj:`dpnp.where` : Returns elements chosen from input arrays depending on + a condition. + :obj:`dpnp.nonzero` : Return the indices of the elements that are non-zero. + + Notes + ----- + ``dpnp.argwhere(a)`` is almost the same as + ``dpnp.transpose(dpnp.nonzero(a))``, but produces a result of the correct + shape for a 0D array. + The output of :obj:`dpnp.argwhere` is not suitable for indexing arrays. + For this purpose use :obj:`dpnp.nonzero` instead. + + Examples + -------- + >>> import dpnp as np + >>> x = np.arange(6).reshape(2, 3) + >>> x + array([[0, 1, 2], + [3, 4, 5]]) + >>> np.argwhere(x > 1) + array([[0, 2], + [1, 0], + [1, 1], + [1, 2]]) + + """ + + dpnp.check_supported_arrays_type(a) + if a.ndim == 0: + # nonzero does not behave well on 0d, so promote to 1d + a = dpnp.atleast_1d(a) + # and then remove the added dimension + return dpnp.argwhere(a)[:, :0] + return dpnp.stack(dpnp.nonzero(a)).T + + def searchsorted(a, v, side="left", sorter=None): """ Find indices where elements should be inserted to maintain order. diff --git a/tests/test_search.py b/tests/test_search.py index e2a7396d3e0..8578657baee 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,7 +1,12 @@ import dpctl.tensor as dpt import numpy import pytest -from numpy.testing import assert_allclose, assert_array_equal, assert_raises +from numpy.testing import ( + assert_allclose, + assert_array_equal, + assert_equal, + assert_raises, +) import dpnp @@ -99,6 +104,64 @@ def test_ndarray(self, axis, keepdims): assert_dtype_allclose(dpnp_res, np_res) +class TestArgwhere: + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True)) + def test_basic(self, dt): + a = numpy.array([4, 0, 2, 1, 3], dtype=dt) + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + @pytest.mark.parametrize("ndim", [0, 1, 2]) + def test_ndim(self, ndim): + # get an nd array with multiple elements in every dimension + a = numpy.empty((2,) * ndim) + + # none + a[...] = False + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + # only one + a[...] = False + a.flat[0] = True + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + # all but one + a[...] = True + a.flat[0] = False + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + # all + a[...] = True + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + def test_2d(self): + a = numpy.arange(6).reshape((2, 3)) + ia = dpnp.array(a) + + result = dpnp.argwhere(ia > 1) + expected = numpy.argwhere(a > 1) + assert_array_equal(result, expected) + + class TestWhere: @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_basic(self, dtype): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 55570edcf71..507e9c8fe54 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -437,6 +437,7 @@ def test_meshgrid(device): pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]), pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]), pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]), + pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]), pytest.param("cbrt", [1.0, 8.0, 27.0]), pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]), diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index a2ecdda66fa..c65e8f55ac2 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -557,6 +557,7 @@ def test_norm(usm_type, ord, axis): pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]), pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]), pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]), + pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]), pytest.param("cbrt", [1, 8, 27]), pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]), diff --git a/tests/third_party/cupy/sorting_tests/test_search.py b/tests/third_party/cupy/sorting_tests/test_search.py index bb4093fabb0..2e3288c6bf0 100644 --- a/tests/third_party/cupy/sorting_tests/test_search.py +++ b/tests/third_party/cupy/sorting_tests/test_search.py @@ -399,7 +399,6 @@ def test_flatnonzero(self, xp, dtype): {"array": numpy.empty((0, 2, 0))}, _ids=False, # Do not generate ids from randomly generated params ) -@pytest.mark.skip("argwhere isn't implemented yet") class TestArgwhere: @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() @@ -412,7 +411,6 @@ def test_argwhere(self, xp, dtype): {"value": 0}, {"value": 3}, ) -@pytest.mark.skip("argwhere isn't implemented yet") @testing.with_requires("numpy>=1.18") class TestArgwhereZeroDimension: @testing.for_all_dtypes()