Skip to content

Commit f9a77a2

Browse files
author
Diptorup Deb
authored
Merge pull request #997 from chudur-budur/feature/dpnp.full_like
Implementation for dpnp.full_like()
2 parents 3ded18e + 3a7ea5e commit f9a77a2

File tree

3 files changed

+298
-11
lines changed

3 files changed

+298
-11
lines changed

numba_dpex/dpnp_iface/_intrinsic.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,71 @@ def codegen(context, builder, sig, llargs):
295295
return sig, codegen
296296

297297

298+
@intrinsic
299+
def impl_dpnp_full(
300+
ty_context,
301+
ty_shape,
302+
ty_fill_value,
303+
ty_dtype,
304+
ty_order,
305+
ty_like,
306+
ty_device,
307+
ty_usm_type,
308+
ty_sycl_queue,
309+
ty_retty_ref,
310+
):
311+
"""A numba "intrinsic" function to inject code for dpnp.full().
312+
313+
Args:
314+
ty_context (numba.core.typing.context.Context): The typing context
315+
for the codegen.
316+
ty_shape (numba.core.types.scalars.Integer or
317+
numba.core.types.containers.UniTuple): Numba type for the shape
318+
of the array.
319+
ty_fill_value (numba.core.types.scalars): One of the Numba scalar
320+
types.
321+
ty_dtype (numba.core.types.functions.NumberClass): Numba type for
322+
dtype.
323+
ty_order (numba.core.types.misc.UnicodeType): UnicodeType
324+
from numba for strings.
325+
ty_like (numba.core.types.npytypes.Array): Numba type for array.
326+
ty_device (numba.core.types.misc.UnicodeType): UnicodeType
327+
from numba for strings.
328+
ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType
329+
from numba for strings.
330+
ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType
331+
from numba for strings.
332+
ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to
333+
a type from numba, used when a type is passed as a value.
334+
335+
Returns:
336+
tuple(numba.core.typing.templates.Signature, function): A tuple of
337+
numba function signature type and a function object.
338+
"""
339+
340+
ty_retty = ty_retty_ref.instance_type
341+
signature = ty_retty(
342+
ty_shape,
343+
ty_fill_value,
344+
ty_dtype,
345+
ty_order,
346+
ty_like,
347+
ty_device,
348+
ty_usm_type,
349+
ty_sycl_queue,
350+
ty_retty_ref,
351+
)
352+
353+
def codegen(context, builder, sig, args):
354+
fill_value = context.get_argument_value(builder, sig.args[1], args[1])
355+
ary, _ = fill_arrayobj(
356+
context, builder, sig, args, fill_value, is_like=False
357+
)
358+
return ary._getvalue()
359+
360+
return signature, codegen
361+
362+
298363
@intrinsic
299364
def impl_dpnp_empty_like(
300365
ty_context,
@@ -490,33 +555,36 @@ def codegen(context, builder, sig, llargs):
490555

491556

492557
@intrinsic
493-
def impl_dpnp_full(
558+
def impl_dpnp_full_like(
494559
ty_context,
495-
ty_shape,
560+
ty_x1,
496561
ty_fill_value,
497562
ty_dtype,
498563
ty_order,
499-
ty_like,
564+
ty_subok,
565+
ty_shape,
500566
ty_device,
501567
ty_usm_type,
502568
ty_sycl_queue,
503569
ty_retty_ref,
504570
):
505-
"""A numba "intrinsic" function to inject code for dpnp.full().
571+
"""A numba "intrinsic" function to inject code for dpnp.full_like().
506572
507573
Args:
508574
ty_context (numba.core.typing.context.Context): The typing context
509575
for the codegen.
510-
ty_shape (numba.core.types.scalars.Integer or
511-
numba.core.types.containers.UniTuple): Numba type for the shape
512-
of the array.
576+
ty_x1 (numba.core.types.npytypes.Array): Numba type class for ndarray.
513577
ty_fill_value (numba.core.types.scalars): One of the Numba scalar
514578
types.
515579
ty_dtype (numba.core.types.functions.NumberClass): Numba type for
516580
dtype.
517581
ty_order (numba.core.types.misc.UnicodeType): UnicodeType
518582
from numba for strings.
519-
ty_like (numba.core.types.npytypes.Array): Numba type for array.
583+
ty_subok (numba.core.types.scalars.Boolean): Numba type class for
584+
subok.
585+
ty_shape (numba.core.types.scalars.Integer or
586+
numba.core.types.containers.UniTuple): Numba type for the shape
587+
of the array. Not supported.
520588
ty_device (numba.core.types.misc.UnicodeType): UnicodeType
521589
from numba for strings.
522590
ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType
@@ -533,11 +601,12 @@ def impl_dpnp_full(
533601

534602
ty_retty = ty_retty_ref.instance_type
535603
signature = ty_retty(
536-
ty_shape,
604+
ty_x1,
537605
ty_fill_value,
538606
ty_dtype,
539607
ty_order,
540-
ty_like,
608+
ty_subok,
609+
ty_shape,
541610
ty_device,
542611
ty_usm_type,
543612
ty_sycl_queue,
@@ -547,7 +616,7 @@ def impl_dpnp_full(
547616
def codegen(context, builder, sig, args):
548617
fill_value = context.get_argument_value(builder, sig.args[1], args[1])
549618
ary, _ = fill_arrayobj(
550-
context, builder, sig, args, fill_value, is_like=False
619+
context, builder, sig, args, fill_value, is_like=True
551620
)
552621
return ary._getvalue()
553622

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
impl_dpnp_empty,
1616
impl_dpnp_empty_like,
1717
impl_dpnp_full,
18+
impl_dpnp_full_like,
1819
impl_dpnp_ones,
1920
impl_dpnp_ones_like,
2021
impl_dpnp_zeros,
@@ -814,6 +815,118 @@ def impl(
814815
)
815816

816817

818+
@overload(dpnp.full_like, prefer_literal=True)
819+
def ol_dpnp_full_like(
820+
x1,
821+
fill_value,
822+
dtype=None,
823+
order="C",
824+
subok=None,
825+
shape=None,
826+
device=None,
827+
usm_type=None,
828+
sycl_queue=None,
829+
):
830+
"""Creates `usm_ndarray` from USM allocation initialized with values
831+
specified by the `fill_value`.
832+
833+
This is an overloaded function implementation for dpnp.full_like().
834+
835+
Args:
836+
x1 (numba.core.types.npytypes.Array): Input array from which to
837+
derive the output array shape.
838+
fill_value (numba.core.types.scalars): One of the
839+
numba.core.types.scalar types for the value to
840+
be filled.
841+
dtype (numba.core.types.functions.NumberClass, optional):
842+
Data type of the array. Can be typestring, a `numpy.dtype`
843+
object, `numpy` char string, or a numpy scalar type.
844+
Default: None.
845+
order (str, optional): memory layout for the array "C" or "F".
846+
Default: "C".
847+
subok ('numba.core.types.scalars.BooleanLiteral', optional): A
848+
boolean literal type for the `subok` parameter defined in
849+
NumPy. If True, then the newly created array will use the
850+
sub-class type of prototype, otherwise it will be a
851+
base-class array. Defaults to False.
852+
shape (numba.core.types.containers.UniTuple, optional): The shape
853+
to override the shape of the given array. Not supported.
854+
Default: `None`
855+
device (numba.core.types.misc.StringLiteral, optional): array API
856+
concept of device where the output array is created. `device`
857+
can be `None`, a oneAPI filter selector string, an instance of
858+
:class:`dpctl.SyclDevice` corresponding to a non-partitioned
859+
SYCL device, an instance of :class:`dpctl.SyclQueue`, or a
860+
`Device` object returnedby`dpctl.tensor.usm_array.device`.
861+
Default: `None`.
862+
usm_type (numba.core.types.misc.StringLiteral or str, optional):
863+
The type of SYCL USM allocation for the output array.
864+
Allowed values are "device"|"shared"|"host".
865+
Default: `"device"`.
866+
sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported.
867+
868+
Raises:
869+
errors.TypingError: If couldn't parse input types to dpnp.full_like().
870+
errors.TypingError: If shape is provided.
871+
872+
Returns:
873+
function: Local function `impl_dpnp_full_like()`.
874+
"""
875+
876+
if shape:
877+
raise errors.TypingError(
878+
"The parameter shape is not supported "
879+
+ "inside overloaded dpnp.full_like() function."
880+
)
881+
_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim is not None else 0
882+
_dtype = _parse_dtype(dtype, data=x1)
883+
_order = x1.layout if order is None else order
884+
_usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device"
885+
_device = (
886+
_parse_device_filter_string(device) if device is not None else "unknown"
887+
)
888+
ret_ty = build_dpnp_ndarray(
889+
_ndim,
890+
layout=_order,
891+
dtype=_dtype,
892+
usm_type=_usm_type,
893+
device=_device,
894+
queue=sycl_queue,
895+
)
896+
if ret_ty:
897+
898+
def impl(
899+
x1,
900+
fill_value,
901+
dtype=None,
902+
order="C",
903+
subok=None,
904+
shape=None,
905+
device=None,
906+
usm_type=None,
907+
sycl_queue=None,
908+
):
909+
return impl_dpnp_full_like(
910+
x1,
911+
fill_value,
912+
_dtype,
913+
_order,
914+
subok,
915+
shape,
916+
_device,
917+
_usm_type,
918+
sycl_queue,
919+
ret_ty,
920+
)
921+
922+
return impl
923+
else:
924+
raise errors.TypingError(
925+
"Cannot parse input types to "
926+
+ f"function dpnp.full_like({x1}, {fill_value}, {dtype}, ...)."
927+
)
928+
929+
817930
@overload(dpnp.full, prefer_literal=True)
818931
def ol_dpnp_full(
819932
shape,
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Tests for dpnp ndarray constructors."""
6+
7+
import math
8+
9+
import dpctl
10+
import dpctl.tensor as dpt
11+
import dpnp
12+
import numpy
13+
import pytest
14+
from numba import errors
15+
16+
from numba_dpex import dpjit
17+
18+
shapes = [11, (3, 7)]
19+
dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64]
20+
usm_types = ["device", "shared", "host"]
21+
devices = ["cpu", "unknown"]
22+
fill_values = [
23+
7,
24+
-7,
25+
7.1,
26+
-7.1,
27+
math.pi,
28+
math.e,
29+
4294967295,
30+
4294967295.0,
31+
3.4028237e38,
32+
]
33+
34+
35+
@pytest.mark.parametrize("shape", shapes)
36+
@pytest.mark.parametrize("fill_value", fill_values)
37+
@pytest.mark.parametrize("dtype", dtypes)
38+
@pytest.mark.parametrize("usm_type", usm_types)
39+
@pytest.mark.parametrize("device", devices)
40+
def test_dpnp_full_like(shape, fill_value, dtype, usm_type, device):
41+
@dpjit
42+
def func(a, v):
43+
c = dpnp.full_like(a, v, dtype=dtype, usm_type=usm_type, device=device)
44+
return c
45+
46+
if isinstance(shape, int):
47+
NZ = numpy.random.rand(shape)
48+
else:
49+
NZ = numpy.random.rand(*shape)
50+
51+
try:
52+
c = func(NZ, fill_value)
53+
except Exception:
54+
pytest.fail("Calling dpnp.full_like inside dpjit failed")
55+
56+
C = numpy.full_like(NZ, fill_value, dtype=dtype)
57+
58+
if len(c.shape) == 1:
59+
assert c.shape[0] == NZ.shape[0]
60+
else:
61+
assert c.shape == NZ.shape
62+
63+
assert c.dtype == dtype
64+
assert c.usm_type == usm_type
65+
if device != "unknown":
66+
assert (
67+
c.sycl_device.filter_string
68+
== dpctl.SyclDevice(device).filter_string
69+
)
70+
else:
71+
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string
72+
73+
assert numpy.array_equal(dpt.asnumpy(c._array_obj), C)
74+
75+
76+
def test_dpnp_full_like_exceptions():
77+
@dpjit
78+
def func1(a):
79+
c = dpnp.full_like(a, shape=(3, 3))
80+
return c
81+
82+
try:
83+
func1(numpy.random.rand(5, 5))
84+
except Exception as e:
85+
assert isinstance(e, errors.TypingError)
86+
assert (
87+
"No implementation of function Function(<function full_like"
88+
in str(e)
89+
)
90+
91+
queue = dpctl.SyclQueue()
92+
93+
@dpjit
94+
def func2(a):
95+
c = dpnp.full_like(a, sycl_queue=queue)
96+
return c
97+
98+
try:
99+
func2(numpy.random.rand(5, 5))
100+
except Exception as e:
101+
assert isinstance(e, errors.TypingError)
102+
assert (
103+
"No implementation of function Function(<function full_like"
104+
in str(e)
105+
)

0 commit comments

Comments
 (0)