Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@
tanh,
trunc,
)
from ._reduction import argmax, argmin, max, min, sum
from ._reduction import argmax, argmin, max, min, prod, sum
from ._testing import allclose

__all__ = [
Expand Down Expand Up @@ -313,4 +313,5 @@
"min",
"argmax",
"argmin",
"prod",
]
63 changes: 62 additions & 1 deletion dpctl/tensor/_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _reduction_over_axis(
def sum(x, axis=None, dtype=None, keepdims=False):
"""sum(x, axis=None, dtype=None, keepdims=False)

Calculates the sum of the input array `x`.
Calculates the sum of elements in the input array `x`.

Args:
x (usm_ndarray):
Expand Down Expand Up @@ -202,6 +202,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
)


def prod(x, axis=None, dtype=None, keepdims=False):
"""prod(x, axis=None, dtype=None, keepdims=False)

Calculates the product of elements in the input array `x`.

Args:
x (usm_ndarray):
input array.
axis (Optional[int, Tuple[int,...]]):
axis or axes along which sums must be computed. If a tuple
of unique integers, sums are computed over multiple axes.
If `None`, the sum is computed over the entire array.
Default: `None`.
dtype (Optional[dtype]):
data type of the returned array. If `None`, the default data
type is inferred from the "kind" of the input array data type.
* If `x` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
* If x` has signed integral data type, the returned array
will have the default signed integral type for the device
where input array `x` is allocated.
* If `x` has unsigned integral data type, the returned array
will have the default unsigned integral type for the device
where input array `x` is allocated.
* If `x` has a complex-valued floating-point data typee,
the returned array will have the default complex-valued
floating-pointer data type for the device where input
array `x` is allocated.
* If `x` has a boolean data type, the returned array will
have the default signed integral type for the device
where input array `x` is allocated.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array elements are cast to the
specified data type before computing the sum. Default: `None`.
keepdims (Optional[bool]):
if `True`, the reduced axes (dimensions) are included in the result
as singleton dimensions, so that the returned array remains
compatible with the input arrays according to Array Broadcasting
rules. Otherwise, if `False`, the reduced axes are not included in
the returned array. Default: `False`.
Returns:
usm_ndarray:
an array containing the products. If the product was computed over
the entire array, a zero-dimensional array is returned. The returned
array has the data type as described in the `dtype` parameter
description above.
"""
return _reduction_over_axis(
x,
axis,
dtype,
keepdims,
ti._prod_over_axis,
ti._prod_over_axis_dtype_supported,
_default_reduction_dtype,
_identity=1,
)


def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
Expand Down
Loading