Skip to content

Commit 7d4f1ca

Browse files
Implement dpnp.common_type() (#2391)
This PR suggests adding an implementation of `dpnp.common_type()` and updates cupy `TestCommonType` tests.
1 parent 771f9b2 commit 7d4f1ca

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

dpnp/dpnp_iface_types.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,25 @@
3232
This module provides public type interface file for the library
3333
"""
3434

35+
import functools
36+
37+
import dpctl
3538
import dpctl.tensor as dpt
3639
import numpy
3740

41+
import dpnp
42+
3843
from .dpnp_array import dpnp_array
3944

45+
# pylint: disable=no-name-in-module
46+
from .dpnp_utils import get_usm_allocations
47+
4048
__all__ = [
4149
"bool",
4250
"bool_",
4351
"byte",
4452
"cdouble",
53+
"common_type",
4554
"complex128",
4655
"complex64",
4756
"complexfloating",
@@ -145,6 +154,67 @@
145154
pi = numpy.pi
146155

147156

157+
def common_type(*arrays):
158+
"""
159+
Return a scalar type which is common to the input arrays.
160+
161+
The return type will always be an inexact (i.e. floating point or complex)
162+
scalar type, even if all the arrays are integer arrays.
163+
If one of the inputs is an integer array, the minimum precision type
164+
that is returned is the default floating point data type for the device
165+
where the input arrays are allocated.
166+
167+
For full documentation refer to :obj:`numpy.common_type`.
168+
169+
Parameters
170+
----------
171+
arrays: {dpnp.ndarray, usm_ndarray}
172+
Input arrays.
173+
174+
Returns
175+
-------
176+
out: data type
177+
Data type object.
178+
179+
See Also
180+
--------
181+
:obj:`dpnp.dtype` : Create a data type object.
182+
183+
Examples
184+
--------
185+
>>> import dpnp as np
186+
>>> np.common_type(np.arange(2, dtype=np.float32))
187+
numpy.float32
188+
>>> np.common_type(np.arange(2, dtype=np.float32), np.arange(2))
189+
numpy.float64 # may vary
190+
>>> np.common_type(np.arange(4), np.array([45, 6.j]), np.array([45.0]))
191+
numpy.complex128 # may vary
192+
193+
"""
194+
195+
if len(arrays) == 0:
196+
return (
197+
dpnp.float16
198+
if dpctl.select_default_device().has_aspect_fp16
199+
else dpnp.float32
200+
)
201+
202+
dpnp.check_supported_arrays_type(*arrays)
203+
204+
_, exec_q = get_usm_allocations(arrays)
205+
default_float_dtype = dpnp.default_float_type(sycl_queue=exec_q)
206+
dtypes = []
207+
for a in arrays:
208+
if not dpnp.issubdtype(a.dtype, dpnp.number):
209+
raise TypeError("can't get common type for non-numeric array")
210+
if dpnp.issubdtype(a.dtype, dpnp.integer):
211+
dtypes.append(default_float_dtype)
212+
else:
213+
dtypes.append(a.dtype)
214+
215+
return functools.reduce(numpy.promote_types, dtypes).type
216+
217+
148218
# pylint: disable=redefined-outer-name
149219
def finfo(dtype):
150220
"""

dpnp/tests/third_party/cupy/test_type_routines.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
import dpnp as cupy
7-
from dpnp.tests.helper import has_support_aspect64
7+
from dpnp.tests.helper import has_support_aspect16, has_support_aspect64
88
from dpnp.tests.third_party.cupy import testing
99

1010

@@ -47,13 +47,17 @@ def test_can_cast(self, xp, from_dtype, to_dtype):
4747
return ret
4848

4949

50-
@pytest.mark.skip("dpnp.common_type() is not implemented yet")
5150
class TestCommonType(unittest.TestCase):
5251

5352
@testing.numpy_cupy_equal()
5453
def test_common_type_empty(self, xp):
5554
ret = xp.common_type()
5655
assert type(ret) is type
56+
# NumPy always returns float16 for empty input,
57+
# but dpnp returns float32 if the device does not support
58+
# 16-bit precision floating point operations
59+
if xp is numpy and not has_support_aspect16():
60+
return xp.float32
5761
return ret
5862

5963
@testing.for_all_dtypes(no_bool=True)
@@ -62,6 +66,11 @@ def test_common_type_single_argument(self, xp, dtype):
6266
array = _generate_type_routines_input(xp, dtype, "array")
6367
ret = xp.common_type(array)
6468
assert type(ret) is type
69+
# NumPy promotes integer types to float64,
70+
# but dpnp may return float32 if the device does not support
71+
# 64-bit precision floating point operations.
72+
if xp is numpy and not has_support_aspect64():
73+
return xp.float32
6574
return ret
6675

6776
@testing.for_all_dtypes_combination(
@@ -73,6 +82,8 @@ def test_common_type_two_arguments(self, xp, dtype1, dtype2):
7382
array2 = _generate_type_routines_input(xp, dtype2, "array")
7483
ret = xp.common_type(array1, array2)
7584
assert type(ret) is type
85+
if xp is numpy and not has_support_aspect64():
86+
return xp.float32
7687
return ret
7788

7889
@testing.for_all_dtypes()

0 commit comments

Comments
 (0)