From 0955458dda9eb05026b08aeacfcb1fbfc918a2c8 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 19 Aug 2024 10:08:58 +0200 Subject: [PATCH 1/4] Add implementation of dpnp.argwhere() --- dpnp/dpnp_iface_searching.py | 58 +++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 455f862e614..e356bb96ffe 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -45,7 +45,7 @@ from .dpnp_array import dpnp_array from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call -__all__ = ["argmax", "argmin", "searchsorted", "where"] +__all__ = ["argmax", "argmin", "argwhere", "searchsorted", "where"] def _get_search_res_dt(a, _dtype, out): @@ -244,6 +244,62 @@ def argmin(a, axis=None, out=None, *, keepdims=False): ) +def argwhere(a): + """ + Find the indices of array elements that are non-zero, grouped by element. + + For full documentation refer to :obj:`numpy.argwhere`. + + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array. + + Returns + ------- + out : dpnp.ndarray + Indices of elements that are non-zero. Indices are grouped by element. + This array will have shape ``(N, a.ndim)`` where ``N`` is the number of + non-zero items. + + See Also + -------- + :obj:`dpnp.where` : Returns elements chosen from input arrays depending on + a condition. + :obj:`dpnp.nonzero` : Return the indices of the elements that are non-zero. + + Notes + ----- + ``dpnp.argwhere(a)`` is almost the same as + ``dpnp.transpose(dpnp.nonzero(a))``, but produces a result of the correct + shape for a 0D array. + The output of :obj:`dpnp.argwhere` is not suitable for indexing arrays. + For this purpose use :obj:`dpnp.nonzero(a)` instead. + + Examples + -------- + >>> import dpnp as np + >>> x = np.arange(6).reshape(2, 3) + >>> x + array([[0, 1, 2], + [3, 4, 5]]) + >>> np.argwhere(x > 1) + array([[0, 2], + [1, 0], + [1, 1], + [1, 2]]) + + """ + + dpnp.check_supported_arrays_type(a) + if a.ndim == 0: + # nonzero does not behave well on 0d, so promote to 1d + a = dpnp.atleast_1d(a) + # and then remove the added dimension + return dpnp.argwhere(a)[:, :0] + return dpnp.stack(dpnp.nonzero(a)).T + + def searchsorted(a, v, side="left", sorter=None): """ Find indices where elements should be inserted to maintain order. From 7fd08055a17247a6435ca7ee41849c08a2bd7619 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 19 Aug 2024 10:09:21 +0200 Subject: [PATCH 2/4] Added new tests and updated existing ones --- tests/test_search.py | 60 ++++++++++++++++++- tests/test_sycl_queue.py | 1 + tests/test_usm_type.py | 1 + .../cupy/sorting_tests/test_search.py | 2 - 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/tests/test_search.py b/tests/test_search.py index e2a7396d3e0..2fd9cb02061 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,7 +1,7 @@ import dpctl.tensor as dpt import numpy import pytest -from numpy.testing import assert_allclose, assert_array_equal, assert_raises +from numpy.testing import assert_allclose, assert_array_equal, assert_equal, assert_raises import dpnp @@ -99,6 +99,64 @@ def test_ndarray(self, axis, keepdims): assert_dtype_allclose(dpnp_res, np_res) +class TestArgwhere: + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True)) + def test_basic(self, dt): + a = numpy.array([4, 0, 2, 1, 3], dtype=dt) + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + @pytest.mark.parametrize('ndim', [0, 1, 2]) + def test_ndim(self, ndim): + # get an nd array with multiple elements in every dimension + a = numpy.empty((2,)*ndim) + + # none + a[...] = False + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + # only one + a[...] = False + a.flat[0] = True + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + # all but one + a[...] = True + a.flat[0] = False + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + # all + a[...] = True + ia = dpnp.array(a) + + result = dpnp.argwhere(ia) + expected = numpy.argwhere(a) + assert_equal(result, expected) + + def test_2d(self): + a = numpy.arange(6).reshape((2, 3)) + ia = dpnp.array(a) + + result = dpnp.argwhere(ia > 1) + expected = numpy.argwhere(a > 1) + assert_array_equal(result, expected) + + class TestWhere: @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_basic(self, dtype): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 55570edcf71..507e9c8fe54 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -437,6 +437,7 @@ def test_meshgrid(device): pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]), pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]), pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]), + pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]), pytest.param("cbrt", [1.0, 8.0, 27.0]), pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]), diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index a2ecdda66fa..c65e8f55ac2 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -557,6 +557,7 @@ def test_norm(usm_type, ord, axis): pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]), pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]), pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]), + pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]), pytest.param("cbrt", [1, 8, 27]), pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]), diff --git a/tests/third_party/cupy/sorting_tests/test_search.py b/tests/third_party/cupy/sorting_tests/test_search.py index bb4093fabb0..2e3288c6bf0 100644 --- a/tests/third_party/cupy/sorting_tests/test_search.py +++ b/tests/third_party/cupy/sorting_tests/test_search.py @@ -399,7 +399,6 @@ def test_flatnonzero(self, xp, dtype): {"array": numpy.empty((0, 2, 0))}, _ids=False, # Do not generate ids from randomly generated params ) -@pytest.mark.skip("argwhere isn't implemented yet") class TestArgwhere: @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() @@ -412,7 +411,6 @@ def test_argwhere(self, xp, dtype): {"value": 0}, {"value": 3}, ) -@pytest.mark.skip("argwhere isn't implemented yet") @testing.with_requires("numpy>=1.18") class TestArgwhereZeroDimension: @testing.for_all_dtypes() From 9e9d9140bd00d740070c80af504d813f8894bf3c Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 19 Aug 2024 10:10:26 +0200 Subject: [PATCH 3/4] Applied pre-commit hooks --- tests/test_search.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_search.py b/tests/test_search.py index 2fd9cb02061..8578657baee 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,7 +1,12 @@ import dpctl.tensor as dpt import numpy import pytest -from numpy.testing import assert_allclose, assert_array_equal, assert_equal, assert_raises +from numpy.testing import ( + assert_allclose, + assert_array_equal, + assert_equal, + assert_raises, +) import dpnp @@ -109,10 +114,10 @@ def test_basic(self, dt): expected = numpy.argwhere(a) assert_equal(result, expected) - @pytest.mark.parametrize('ndim', [0, 1, 2]) + @pytest.mark.parametrize("ndim", [0, 1, 2]) def test_ndim(self, ndim): # get an nd array with multiple elements in every dimension - a = numpy.empty((2,)*ndim) + a = numpy.empty((2,) * ndim) # none a[...] = False From daf40dd2d42f611073c42fe94617b3f1409c95a0 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 19 Aug 2024 10:51:36 +0200 Subject: [PATCH 4/4] Fix broken link in description --- dpnp/dpnp_iface_searching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index e356bb96ffe..4c4576dc996 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -274,7 +274,7 @@ def argwhere(a): ``dpnp.transpose(dpnp.nonzero(a))``, but produces a result of the correct shape for a 0D array. The output of :obj:`dpnp.argwhere` is not suitable for indexing arrays. - For this purpose use :obj:`dpnp.nonzero(a)` instead. + For this purpose use :obj:`dpnp.nonzero` instead. Examples --------