Skip to content

Implement dpnp.common_type() #2391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions dpnp/dpnp_iface_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,25 @@
This module provides public type interface file for the library
"""

import functools

import dpctl
import dpctl.tensor as dpt
import numpy

import dpnp

from .dpnp_array import dpnp_array

# pylint: disable=no-name-in-module
from .dpnp_utils import get_usm_allocations

__all__ = [
"bool",
"bool_",
"byte",
"cdouble",
"common_type",
"complex128",
"complex64",
"complexfloating",
Expand Down Expand Up @@ -145,6 +154,67 @@
pi = numpy.pi


def common_type(*arrays):
"""
Return a scalar type which is common to the input arrays.

The return type will always be an inexact (i.e. floating point or complex)
scalar type, even if all the arrays are integer arrays.
If one of the inputs is an integer array, the minimum precision type
that is returned is the default floating point data type for the device
where the input arrays are allocated.

For full documentation refer to :obj:`numpy.common_type`.

Parameters
----------
arrays: {dpnp.ndarray, usm_ndarray}
Input arrays.

Returns
-------
out: data type
Data type object.

See Also
--------
:obj:`dpnp.dtype` : Create a data type object.

Examples
--------
>>> import dpnp as np
>>> np.common_type(np.arange(2, dtype=np.float32))
numpy.float32
>>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2))
numpy.float64 # may vary
>>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0]))
numpy.complex128 # may vary

"""

if len(arrays) == 0:
return (
dpnp.float16
if dpctl.select_default_device().has_aspect_fp16
else dpnp.float32
)

dpnp.check_supported_arrays_type(*arrays)

_, exec_q = get_usm_allocations(arrays)
default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q)
dtypes = []
for a in arrays:
if not dpnp.issubdtype(a.dtype, dpnp.number):
raise TypeError("can't get common type for non-numeric array")
if dpnp.issubdtype(a.dtype, dpnp.integer):
dtypes.append(default_float_dtype)
else:
dtypes.append(a.dtype)

return functools.reduce(numpy.promote_types, dtypes).type


# pylint: disable=redefined-outer-name
def finfo(dtype):
"""
Expand Down
15 changes: 13 additions & 2 deletions dpnp/tests/third_party/cupy/test_type_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import dpnp as cupy
from dpnp.tests.helper import has_support_aspect64
from dpnp.tests.helper import has_support_aspect16, has_support_aspect64
from dpnp.tests.third_party.cupy import testing


Expand Down Expand Up @@ -47,13 +47,17 @@ def test_can_cast(self, xp, from_dtype, to_dtype):
return ret


@pytest.mark.skip("dpnp.common_type() is not implemented yet")
class TestCommonType(unittest.TestCase):

@testing.numpy_cupy_equal()
def test_common_type_empty(self, xp):
ret = xp.common_type()
assert type(ret) is type
# NumPy always returns float16 for empty input,
# but dpnp returns float32 if the device does not support
# 16-bit precision floating point operations
if xp is numpy and not has_support_aspect16():
return xp.float32
return ret

@testing.for_all_dtypes(no_bool=True)
Expand All @@ -62,6 +66,11 @@ def test_common_type_single_argument(self, xp, dtype):
array = _generate_type_routines_input(xp, dtype, "array")
ret = xp.common_type(array)
assert type(ret) is type
# NumPy promotes integer types to float64,
# but dpnp may return float32 if the device does not support
# 64-bit precision floating point operations.
if xp is numpy and not has_support_aspect64():
return xp.float32
return ret

@testing.for_all_dtypes_combination(
Expand All @@ -73,6 +82,8 @@ def test_common_type_two_arguments(self, xp, dtype1, dtype2):
array2 = _generate_type_routines_input(xp, dtype2, "array")
ret = xp.common_type(array1, array2)
assert type(ret) is type
if xp is numpy and not has_support_aspect64():
return xp.float32
return ret

@testing.for_all_dtypes()
Expand Down
Loading