Skip to content

Add implementation of dpnp.argwhere #2000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .dpnp_array import dpnp_array
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call

__all__ = ["argmax", "argmin", "searchsorted", "where"]
__all__ = ["argmax", "argmin", "argwhere", "searchsorted", "where"]


def _get_search_res_dt(a, _dtype, out):
Expand Down Expand Up @@ -244,6 +244,62 @@ def argmin(a, axis=None, out=None, *, keepdims=False):
)


def argwhere(a):
"""
Find the indices of array elements that are non-zero, grouped by element.

For full documentation refer to :obj:`numpy.argwhere`.

Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Input array.

Returns
-------
out : dpnp.ndarray
Indices of elements that are non-zero. Indices are grouped by element.
This array will have shape ``(N, a.ndim)`` where ``N`` is the number of
non-zero items.

See Also
--------
:obj:`dpnp.where` : Returns elements chosen from input arrays depending on
a condition.
:obj:`dpnp.nonzero` : Return the indices of the elements that are non-zero.

Notes
-----
``dpnp.argwhere(a)`` is almost the same as
``dpnp.transpose(dpnp.nonzero(a))``, but produces a result of the correct
shape for a 0D array.
The output of :obj:`dpnp.argwhere` is not suitable for indexing arrays.
For this purpose use :obj:`dpnp.nonzero` instead.

Examples
--------
>>> import dpnp as np
>>> x = np.arange(6).reshape(2, 3)
>>> x
array([[0, 1, 2],
[3, 4, 5]])
>>> np.argwhere(x > 1)
array([[0, 2],
[1, 0],
[1, 1],
[1, 2]])

"""

dpnp.check_supported_arrays_type(a)
if a.ndim == 0:
# nonzero does not behave well on 0d, so promote to 1d
a = dpnp.atleast_1d(a)
# and then remove the added dimension
return dpnp.argwhere(a)[:, :0]
return dpnp.stack(dpnp.nonzero(a)).T


def searchsorted(a, v, side="left", sorter=None):
"""
Find indices where elements should be inserted to maintain order.
Expand Down
65 changes: 64 additions & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import dpctl.tensor as dpt
import numpy
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_raises
from numpy.testing import (
assert_allclose,
assert_array_equal,
assert_equal,
assert_raises,
)

import dpnp

Expand Down Expand Up @@ -99,6 +104,64 @@ def test_ndarray(self, axis, keepdims):
assert_dtype_allclose(dpnp_res, np_res)


class TestArgwhere:
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))
def test_basic(self, dt):
a = numpy.array([4, 0, 2, 1, 3], dtype=dt)
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

@pytest.mark.parametrize("ndim", [0, 1, 2])
def test_ndim(self, ndim):
# get an nd array with multiple elements in every dimension
a = numpy.empty((2,) * ndim)

# none
a[...] = False
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

# only one
a[...] = False
a.flat[0] = True
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

# all but one
a[...] = True
a.flat[0] = False
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

# all
a[...] = True
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

def test_2d(self):
a = numpy.arange(6).reshape((2, 3))
ia = dpnp.array(a)

result = dpnp.argwhere(ia > 1)
expected = numpy.argwhere(a > 1)
assert_array_equal(result, expected)


class TestWhere:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_basic(self, dtype):
Expand Down
1 change: 1 addition & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def test_meshgrid(device):
pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]),
pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]),
pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]),
pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]),
pytest.param("cbrt", [1.0, 8.0, 27.0]),
pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]),
Expand Down
1 change: 1 addition & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def test_norm(usm_type, ord, axis):
pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]),
pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]),
pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]),
pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]),
pytest.param("cbrt", [1, 8, 27]),
pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]),
Expand Down
2 changes: 0 additions & 2 deletions tests/third_party/cupy/sorting_tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,6 @@ def test_flatnonzero(self, xp, dtype):
{"array": numpy.empty((0, 2, 0))},
_ids=False, # Do not generate ids from randomly generated params
)
@pytest.mark.skip("argwhere isn't implemented yet")
class TestArgwhere:
@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
Expand All @@ -412,7 +411,6 @@ def test_argwhere(self, xp, dtype):
{"value": 0},
{"value": 3},
)
@pytest.mark.skip("argwhere isn't implemented yet")
@testing.with_requires("numpy>=1.18")
class TestArgwhereZeroDimension:
@testing.for_all_dtypes()
Expand Down
Loading