Skip to content

Update documentation and clean up implementation of indexing functions #1913

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9c64980
Remove limitations from dpnp.take implementation
antonwolfy Jul 5, 2024
3c7cebc
Add more test to cover specail cases and increase code coverage
antonwolfy Jul 5, 2024
a89136e
Applied pre-commit hook
antonwolfy Jul 5, 2024
f054818
Corrected test_over_index
antonwolfy Jul 5, 2024
aa3d9ba
Merge branch 'master' into impl-take
antonwolfy Jul 7, 2024
7e87cef
Update docsctrings with resolving typos
antonwolfy Jul 7, 2024
cbb7188
Use dpnp.reshape() to change shape and create dpnp array from usm_nda…
antonwolfy Jul 9, 2024
c170c19
Remove limitations from dpnp.place implementation
antonwolfy Jul 7, 2024
2d87bfc
Update relating tests
antonwolfy Jul 7, 2024
b527f77
Roll back changed in dpnp.vander
antonwolfy Jul 7, 2024
dc5e776
Remove data sync at the end of function
antonwolfy Jul 9, 2024
5bac010
Update indexing functions
antonwolfy Jul 7, 2024
f53ab2d
Add missing test scenario
antonwolfy Jul 7, 2024
157de7f
Updated docstring in put_along_axis() and take_along_axis() and rolle…
antonwolfy Jul 9, 2024
d0915b1
Remove data synchronization for dpnp.put()
antonwolfy Jul 9, 2024
e300914
Remove data synchronization for dpnp.nonzero()
antonwolfy Jul 9, 2024
36f2972
Remove data synchronization for dpnp.indices()
antonwolfy Jul 9, 2024
fef9f51
Remove data synchronization for dpnp.extract()
antonwolfy Jul 9, 2024
33b1707
Merge branch 'master' into adopt-indexing-if-to-asynchronous-dpctl-ex…
antonwolfy Jul 10, 2024
23fb47f
Remove data sync in dpnp.get_result_array()
antonwolfy Jul 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 87 additions & 48 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):

See also
--------
:obj:`diag_indices_from` : Return the indices to access the main
diagonal of an n-dimensional array.
:obj:`dpnp.diag_indices_from` : Return the indices to access the main
diagonal of an n-dimensional array.

Examples
--------
Expand Down Expand Up @@ -276,7 +276,7 @@ def diag_indices_from(arr):
Parameters
----------
arr : {dpnp.ndarray, usm_ndarray}
Array at least 2-D
Array at least 2-D.

Returns
-------
Expand All @@ -285,8 +285,8 @@ def diag_indices_from(arr):

See also
--------
:obj:`diag_indices` : Return the indices to access the main
diagonal of an array.
:obj:`dpnp.diag_indices` : Return the indices to access the main diagonal
of an array.

Examples
--------
Expand Down Expand Up @@ -570,14 +570,17 @@ def extract(condition, a):

usm_res = dpt.extract(usm_cond, usm_a)

dpnp.synchronize_array_data(usm_res)
return dpnp_array._create_from_usm_ndarray(usm_res)


def fill_diagonal(a, val, wrap=False):
"""
Fill the main diagonal of the given array of any dimensionality.

For an array `a` with ``a.ndim >= 2``, the diagonal is the list of values
``a[i, ..., i]`` with indices ``i`` all identical. This function modifies
the input array in-place without returning a value.

For full documentation refer to :obj:`numpy.fill_diagonal`.

Parameters
Expand Down Expand Up @@ -678,11 +681,12 @@ def fill_diagonal(a, val, wrap=False):

"""

dpnp.check_supported_arrays_type(a)
dpnp.check_supported_arrays_type(val, scalar_type=True, all_scalars=True)
usm_a = dpnp.get_usm_ndarray(a)
usm_val = dpnp.get_usm_ndarray_or_scalar(val)

if a.ndim < 2:
raise ValueError("array must be at least 2-d")

end = a.size
if a.ndim == 2:
step = a.shape[1] + 1
Expand All @@ -695,18 +699,21 @@ def fill_diagonal(a, val, wrap=False):

# TODO: implement flatiter for slice key
# a.flat[:end:step] = val
# but need to consider use case when `a` is usm_ndarray also
a_sh = a.shape
tmp_a = dpnp.ravel(a)
if dpnp.isscalar(val):
tmp_a[:end:step] = val
tmp_a = dpt.reshape(usm_a, -1)
if dpnp.isscalar(usm_val):
tmp_a[:end:step] = usm_val
else:
flat_val = val.ravel()
usm_val = dpt.reshape(usm_val, -1)

# Setitem can work only if index size equal val size.
# Using loop for general case without dependencies of val size.
for i in range(0, flat_val.size):
tmp_a[step * i : end : step * (i + 1)] = flat_val[i]
tmp_a = dpnp.reshape(tmp_a, a_sh)
a[:] = tmp_a
for i in range(0, usm_val.size):
tmp_a[step * i : end : step * (i + 1)] = usm_val[i]

tmp_a = dpt.reshape(tmp_a, a_sh)
usm_a[:] = tmp_a


def indices(
Expand Down Expand Up @@ -758,6 +765,13 @@ def indices(
with grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)
with dimensions[i] in the i-th place.

See Also
--------
:obj:`dpnp.mgrid` : Return a dense multi-dimensional “meshgrid”.
:obj:`dpnp.ogrid` : Return an open multi-dimensional “meshgrid”.
:obj:`dpnp.meshgrid` : Return a tuple of coordinate matrices from
coordinate vectors.

Examples
--------
>>> import dpnp as np
Expand Down Expand Up @@ -800,6 +814,7 @@ def indices(
dimensions = tuple(dimensions)
n = len(dimensions)
shape = (1,) * n

if sparse:
res = ()
else:
Expand All @@ -810,6 +825,7 @@ def indices(
usm_type=usm_type,
sycl_queue=sycl_queue,
)

for i, dim in enumerate(dimensions):
idx = dpnp.arange(
dim,
Expand All @@ -818,6 +834,7 @@ def indices(
usm_type=usm_type,
sycl_queue=sycl_queue,
).reshape(shape[:i] + (dim,) + shape[i + 1 :])

if sparse:
res = res + (idx,)
else:
Expand Down Expand Up @@ -927,10 +944,12 @@ def nonzero(a):
"""
Return the indices of the elements that are non-zero.

Returns a tuple of arrays, one for each dimension of `a`,
containing the indices of the non-zero elements in that
dimension. The values in `a` are always tested and returned in
row-major, C-style order.
Returns a tuple of arrays, one for each dimension of `a`, containing
the indices of the non-zero elements in that dimension. The values in `a`
are always tested and returned in row-major, C-style order.

To group the indices by element, rather than dimension, use
:obj:`dpnp.argwhere`, which returns a row for each non-zero element.

For full documentation refer to :obj:`numpy.nonzero`.

Expand Down Expand Up @@ -1005,9 +1024,9 @@ def nonzero(a):

"""

usx_a = dpnp.get_usm_ndarray(a)
usm_a = dpnp.get_usm_ndarray(a)
return tuple(
dpnp_array._create_from_usm_ndarray(y) for y in dpt.nonzero(usx_a)
dpnp_array._create_from_usm_ndarray(y) for y in dpt.nonzero(usm_a)
)


Expand Down Expand Up @@ -1139,47 +1158,60 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):

"""

dpnp.check_supported_arrays_type(a)

if not dpnp.is_supported_array_type(ind):
ind = dpnp.asarray(
ind, dtype=dpnp.intp, sycl_queue=a.sycl_queue, usm_type=a.usm_type
)
elif not dpnp.issubdtype(ind.dtype, dpnp.integer):
ind = dpnp.astype(ind, dtype=dpnp.intp, casting="safe")
ind = dpnp.ravel(ind)

if not dpnp.is_supported_array_type(v):
v = dpnp.asarray(
v, dtype=a.dtype, sycl_queue=a.sycl_queue, usm_type=a.usm_type
)
if v.size == 0:
return
usm_a = dpnp.get_usm_ndarray(a)

if not (axis is None or isinstance(axis, int)):
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")

in_a = a
if axis is None and a.ndim > 1:
a = dpnp.ravel(in_a)

if mode not in ("wrap", "clip"):
raise ValueError(
f"clipmode must be one of 'clip' or 'wrap' (got '{mode}')"
)

usm_a = dpnp.get_usm_ndarray(a)
usm_ind = dpnp.get_usm_ndarray(ind)
usm_v = dpnp.get_usm_ndarray(v)
usm_v = dpnp.as_usm_ndarray(
v,
dtype=usm_a.dtype,
usm_type=usm_a.usm_type,
sycl_queue=usm_a.sycl_queue,
)
if usm_v.size == 0:
return

usm_ind = dpnp.as_usm_ndarray(
ind,
dtype=dpnp.intp,
usm_type=usm_a.usm_type,
sycl_queue=usm_a.sycl_queue,
)

if usm_ind.ndim != 1:
# dpt.put supports only 1-D array of indices
usm_ind = dpt.reshape(usm_ind, -1, copy=False)

if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer):
# dpt.put supports only integer dtype for array of indices
usm_ind = dpt.astype(usm_ind, dpnp.intp, casting="safe")

in_usm_a = usm_a
if axis is None and usm_a.ndim > 1:
usm_a = dpt.reshape(usm_a, -1)

dpt.put(usm_a, usm_ind, usm_v, axis=axis, mode=mode)
if in_a is not a:
in_a[:] = a.reshape(in_a.shape, copy=False)
if in_usm_a._pointer != usm_a._pointer: # pylint: disable=protected-access
in_usm_a[:] = dpt.reshape(usm_a, in_usm_a.shape, copy=False)


def put_along_axis(a, ind, values, axis):
"""
Put values into the destination array by matching 1d index and data slices.

This iterates over matching 1d slices oriented along the specified axis in
the index and data arrays, and uses the former to place values into the
latter. These slices can be different lengths.

Functions returning an index along an `axis`, like :obj:`dpnp.argsort` and
:obj:`dpnp.argpartition`, produce suitable indices for this function.

For full documentation refer to :obj:`numpy.put_along_axis`.

Parameters
Expand Down Expand Up @@ -1415,6 +1447,13 @@ def take_along_axis(a, indices, axis):
"""
Take values from the input array by matching 1d index and data slices.

This iterates over matching 1d slices oriented along the specified axis in
the index and data arrays, and uses the former to look up values in the
latter. These slices can be different lengths.

Functions returning an index along an `axis`, like :obj:`dpnp.argsort` and
:obj:`dpnp.argpartition`, produce suitable indices for this function.

For full documentation refer to :obj:`numpy.take_along_axis`.

Parameters
Expand All @@ -1428,7 +1467,7 @@ def take_along_axis(a, indices, axis):
axis : int
The axis to take 1d slices along. If axis is ``None``, the input
array is treated as if it had first been flattened to 1d,
for consistency with `sort` and `argsort`.
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.

Returns
-------
Expand Down
3 changes: 2 additions & 1 deletion tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,9 @@ def test_fill_diagonal(array, val):

@pytest.mark.parametrize(
"dimension",
[(1,), (2,), (1, 2), (2, 3), (3, 2), [1], [2], [1, 2], [2, 3], [3, 2]],
[(), (1,), (2,), (1, 2), (2, 3), (3, 2), [1], [2], [1, 2], [2, 3], [3, 2]],
ids=[
"()",
"(1, )",
"(2, )",
"(1, 2)",
Expand Down
Loading