Skip to content

implement dpnp.apply_along_axis #2169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/reference/fft.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
FFT Functions
=============
Discrete Fourier Transform
==========================

.. https://numpy.org/doc/stable/reference/routines.fft.html

Expand Down
14 changes: 14 additions & 0 deletions doc/reference/functional.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Functional programming
======================

.. https://numpy.org/doc/stable/reference/routines.functional.html

.. autosummary::
:toctree: generated/
:nosignatures:

dpnp.apply_along_axis
dpnp.apply_over_axes
dpnp.vectorize
dpnp.frompyfunc
dpnp.piecewise
2 changes: 1 addition & 1 deletion doc/reference/linalg.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Linear Algebra
Linear algebra
==============

.. https://numpy.org/doc/stable/reference/routines.linalg.html
Expand Down
2 changes: 1 addition & 1 deletion doc/reference/logic.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Logic Functions
Logic functions
===============

.. https://numpy.org/doc/stable/reference/routines.logic.html
Expand Down
2 changes: 1 addition & 1 deletion doc/reference/manipulation.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Array Manipulation Routines
Array manipulation routines
===========================

.. https://numpy.org/doc/stable/reference/routines.array-manipulation.html
Expand Down
2 changes: 1 addition & 1 deletion doc/reference/random.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Random Sampling (``dpnp.random``)
Random sampling (``dpnp.random``)
=================================

.. https://numpy.org/doc/stable/reference/random/legacy.html
Expand Down
5 changes: 3 additions & 2 deletions doc/reference/routines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Routines

The following pages describe NumPy-compatible routines.
These functions cover a subset of
`NumPy routines <https://docs.scipy.org/doc/numpy/reference/routines.html>`_.
`NumPy routines <https://numpy.org/doc/stable/reference/routines.html>`_.

.. currentmodule:: dpnp

Expand All @@ -13,10 +13,11 @@ These functions cover a subset of

creation
manipulation
indexing
binary
dtype
fft
functional
indexing
linalg
logic
math
Expand Down
2 changes: 1 addition & 1 deletion doc/reference/sorting.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Sorting, Searching, and Counting
Sorting, searching, and counting
================================

.. https://numpy.org/doc/stable/reference/routines.sort.html
Expand Down
4 changes: 2 additions & 2 deletions doc/reference/statistics.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Statistical Functions
=====================
Statistics
==========

.. https://numpy.org/doc/stable/reference/routines.statistics.html

Expand Down
3 changes: 3 additions & 0 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
from dpnp.dpnp_iface_bitwise import __all__ as __all__bitwise
from dpnp.dpnp_iface_counting import *
from dpnp.dpnp_iface_counting import __all__ as __all__counting
from dpnp.dpnp_iface_functional import *
from dpnp.dpnp_iface_functional import __all__ as __all__functional
from dpnp.dpnp_iface_histograms import *
from dpnp.dpnp_iface_histograms import __all__ as __all__histograms
from dpnp.dpnp_iface_indexing import *
Expand Down Expand Up @@ -116,6 +118,7 @@
__all__ += __all__arraycreation
__all__ += __all__bitwise
__all__ += __all__counting
__all__ += __all__functional
__all__ += __all__histograms
__all__ += __all__indexing
__all__ += __all__libmath
Expand Down
187 changes: 187 additions & 0 deletions dpnp/dpnp_iface_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# *****************************************************************************
# Copyright (c) 2024, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# - Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************

"""
Interface of the functional programming routines part of the DPNP

Notes
-----
This module is a face or public interface file for the library
it contains:
- Interface functions
- documentation for the functions
- The functions parameters check

"""


import numpy
from dpctl.tensor._numpy_helper import normalize_axis_index

import dpnp

__all__ = ["apply_along_axis"]


def apply_along_axis(func1d, axis, arr, *args, **kwargs):
"""
Apply a function to 1-D slices along the given axis.

Execute ``func1d(a, *args, **kwargs)`` where `func1d` operates on
1-D arrays and `a` is a 1-D slice of `arr` along `axis`.

This is equivalent to (but faster than) the following use of
:obj:`dpnp.ndindex` and :obj:`dpnp.s_`, which sets each of
``ii``, ``jj``, and ``kk`` to a tuple of indices::

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
for kk in ndindex(Nk):
f = func1d(arr[ii + s_[:,] + kk])
Nj = f.shape
for jj in ndindex(Nj):
out[ii + jj + kk] = f[jj]

Equivalently, eliminating the inner loop, this can be expressed as::

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
for kk in ndindex(Nk):
out[ii + s_[...,] + kk] = func1d(arr[ii + s_[:,] + kk])

For full documentation refer to :obj:`numpy.apply_along_axis`.

Parameters
----------
func1d : function (M,) -> (Nj...)
This function should accept 1-D arrays. It is applied to 1-D
slices of `arr` along the specified axis.
axis : int
Axis along which `arr` is sliced.
arr : {dpnp.ndarray, usm_ndarray} (Ni..., M, Nk...)
Input array.
args : any
Additional arguments to `func1d`.
kwargs : any
Additional named arguments to `func1d`.

Returns
-------
out : dpnp.ndarray (Ni..., Nj..., Nk...)
The output array. The shape of `out` is identical to the shape of
`arr`, except along the `axis` dimension. This axis is removed, and
replaced with new dimensions equal to the shape of the return value
of `func1d`.

See Also
--------
:obj:`dpnp.apply_over_axes` : Apply a function repeatedly over
multiple axes.

Examples
--------
>>> import dpnp as np
>>> def my_func(a): # Average first and last element of a 1-D array
... return (a[0] + a[-1]) * 0.5
>>> b = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> np.apply_along_axis(my_func, 0, b)
array([4., 5., 6.])
>>> np.apply_along_axis(my_func, 1, b)
array([2., 5., 8.])

For a function that returns a 1D array, the number of dimensions in
`out` is the same as `arr`.

>>> b = np.array([[8, 1, 7], [4, 3, 9], [5, 2, 6]])
>>> np.apply_along_axis(sorted, 1, b)
array([[1, 7, 8],
[3, 4, 9],
[2, 5, 6]])

For a function that returns a higher dimensional array, those dimensions
are inserted in place of the `axis` dimension.

>>> b = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> np.apply_along_axis(np.diag, -1, b)
array([[[1, 0, 0],
[0, 2, 0],
[0, 0, 3]],
[[4, 0, 0],
[0, 5, 0],
[0, 0, 6]],
[[7, 0, 0],
[0, 8, 0],
[0, 0, 9]]])

"""

dpnp.check_supported_arrays_type(arr)
nd = arr.ndim
exec_q = arr.sycl_queue
usm_type = arr.usm_type
axis = normalize_axis_index(axis, nd)

# arr, with the iteration axis at the end
inarr_view = dpnp.moveaxis(arr, axis, -1)

# compute indices for the iteration axes, and append a trailing ellipsis to
# prevent 0d arrays decaying to scalars
# TODO: replace with dpnp.ndindex
inds = numpy.ndindex(inarr_view.shape[:-1])
inds = (ind + (Ellipsis,) for ind in inds)

# invoke the function on the first item
try:
ind0 = next(inds)
except StopIteration:
raise ValueError(
"Cannot apply_along_axis when any iteration dimensions are 0"
) from None
res = dpnp.asanyarray(
func1d(inarr_view[ind0], *args, **kwargs),
sycl_queue=exec_q,
usm_type=usm_type,
)

# build a buffer for storing evaluations of func1d.
# remove the requested axis, and add the new ones on the end.
# laid out so that each write is contiguous.
# for a tuple index inds, buff[inds] = func1d(inarr_view[inds])
buff = dpnp.empty_like(res, shape=inarr_view.shape[:-1] + res.shape)

# save the first result, then compute and save all remaining results
buff[ind0] = res
for ind in inds:
buff[ind] = dpnp.asanyarray(
func1d(inarr_view[ind], *args, **kwargs),
sycl_queue=exec_q,
usm_type=usm_type,
)

# restore the inserted axes back to where they belong
for _ in range(res.ndim):
buff = dpnp.moveaxis(buff, -1, axis)

return buff
2 changes: 1 addition & 1 deletion dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# *****************************************************************************

"""
Interface of the Array manipulation routines part of the DPNP
Interface of the array manipulation routines part of the DPNP

Notes
-----
Expand Down
48 changes: 48 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy
import pytest
from numpy.testing import assert_array_equal, assert_equal, assert_raises

import dpnp

from .helper import get_all_dtypes


class TestApplyAlongAxis:
def test_tuple_func1d(self):
def sample_1d(x):
return x[1], x[0]

a = numpy.array([[1, 2], [3, 4]])
ia = dpnp.array(a)

# 2d insertion along first axis
expected = numpy.apply_along_axis(sample_1d, 1, a)
result = dpnp.apply_along_axis(sample_1d, 1, ia)
assert_array_equal(result, expected)

@pytest.mark.parametrize("stride", [-1, 2, -3])
def test_stride(self, stride):
a = numpy.ones((20, 10), dtype="f")
ia = dpnp.array(a)

expected = numpy.apply_along_axis(len, 0, a[::stride, ::stride])
result = dpnp.apply_along_axis(len, 0, ia[::stride, ::stride])
assert_array_equal(result, expected)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_args(self, dtype):
a = numpy.ones((20, 10))
ia = dpnp.array(a)

# kwargs
expected = numpy.apply_along_axis(
numpy.mean, 0, a, dtype=dtype, keepdims=True
)
result = dpnp.apply_along_axis(
dpnp.mean, 0, ia, dtype=dtype, keepdims=True
)
assert_array_equal(result, expected)

# positional args: axis, dtype, out, keepdims
result = dpnp.apply_along_axis(dpnp.mean, 0, ia, 0, dtype, None, True)
assert_array_equal(result, expected)
12 changes: 12 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,6 +2178,18 @@ def test_split(func, data1, device):
assert_sycl_queue_equal(result[1].sycl_queue, x1.sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_apply_along_axis(device):
x = dpnp.arange(9, device=device).reshape(3, 3)
result = dpnp.apply_along_axis(dpnp.sum, 0, x)

assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)


@pytest.mark.parametrize(
"device_x",
valid_devices,
Expand Down
8 changes: 8 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,14 @@ def test_2in_with_scalar_1out(func, data, scalar, usm_type):
assert z.usm_type == usm_type


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_apply_along_axis(usm_type):
x = dp.arange(9, usm_type=usm_type).reshape(3, 3)
y = dp.apply_along_axis(dp.sum, 0, x)

assert x.usm_type == y.usm_type


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_broadcast_to(usm_type):
x = dp.ones(7, usm_type=usm_type)
Expand Down
Loading
Loading