Skip to content

Commit cde375f

Browse files
authored
Fix eager compilation with signature for dppy.kernel (#291)
* Enable AOT with signature for dppy.kernel * Added queue to specialize * Skip when Windows and level0 * Skip test for windows regardless of backend * Update import order * Update platform check * Make function internal
1 parent 04529af commit cde375f

File tree

6 files changed

+92
-39
lines changed

6 files changed

+92
-39
lines changed

numba_dppy/compiler.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ def compile_kernel(sycl_queue, pyfunc, args, access_types, debug=False):
142142
print("compile_kernel", args)
143143
debug = True
144144
if not sycl_queue:
145-
# This will be get_current_queue
146-
sycl_queue = dpctl.get_current_queue()
145+
# We expect the sycl_queue to be provided when this function is called
146+
raise ValueError("SYCL queue is required for compiling a kernel")
147147

148148
cres = compile_with_dppy(pyfunc, None, args, debug=debug)
149149
func = cres.library.get_function(cres.fndesc.llvm_func_name)
@@ -561,32 +561,42 @@ def check_for_invalid_access_type(self, access_type):
561561

562562

563563
class JitDPPYKernel(DPPYKernelBase):
564-
def __init__(self, func, access_types):
564+
def __init__(self, func, debug, access_types):
565565

566566
super(JitDPPYKernel, self).__init__()
567567

568568
self.py_func = func
569569
self.definitions = {}
570+
self.debug = debug
570571
self.access_types = access_types
571572

572573
from .descriptor import dppy_target
573574

574575
self.typingctx = dppy_target.typing_context
575576

577+
def _get_argtypes(self, *args):
578+
"""
579+
Convenience function to get the type of each argument.
580+
"""
581+
return tuple([self.typingctx.resolve_argument_type(a) for a in args])
582+
576583
def __call__(self, *args, **kwargs):
577584
assert not kwargs, "Keyword Arguments are not supported"
578-
if self.sycl_queue is None:
579-
try:
580-
self.sycl_queue = dpctl.get_current_queue()
581-
except:
582-
_raise_no_device_found_error()
585+
try:
586+
current_queue = dpctl.get_current_queue()
587+
except:
588+
_raise_no_device_found_error()
583589

584-
kernel = self.specialize(*args)
590+
argtypes = self._get_argtypes(*args)
591+
kernel = self.specialize(argtypes, current_queue)
585592
cfg = kernel.configure(self.sycl_queue, self.global_size, self.local_size)
586593
cfg(*args)
587594

588-
def specialize(self, *args):
589-
argtypes = tuple([self.typingctx.resolve_argument_type(a) for a in args])
595+
def specialize(self, argtypes, queue):
596+
# We specialize for argtypes and queue. These two are used as key for
597+
# caching as well.
598+
assert queue is not None
599+
590600
sycl_ctx = None
591601
kernel = None
592602
# we were previously using the _env_ptr of the device_env, the sycl_queue
@@ -598,11 +608,9 @@ def specialize(self, *args):
598608
if result:
599609
sycl_ctx, kernel = result
600610

601-
if sycl_ctx and sycl_ctx == self.sycl_queue.sycl_context:
611+
if sycl_ctx and sycl_ctx == queue.sycl_context:
602612
return kernel
603613
else:
604-
kernel = compile_kernel(
605-
self.sycl_queue, self.py_func, argtypes, self.access_types
606-
)
607-
self.definitions[key_definitions] = (self.sycl_queue.sycl_context, kernel)
614+
kernel = compile_kernel(queue, self.py_func, argtypes, self.access_types)
615+
self.definitions[key_definitions] = (queue.sycl_context, kernel)
608616
return kernel

numba_dppy/decorators.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import dpctl
1516
from numba.core import sigutils, types
17+
1618
from .compiler import (
1719
compile_kernel,
1820
JitDPPYKernel,
@@ -40,20 +42,27 @@ def kernel(signature=None, access_types=None, debug=False):
4042
def autojit(debug=False, access_types=None):
4143
def _kernel_autojit(pyfunc):
4244
ordered_arg_access_types = get_ordered_arg_access_types(pyfunc, access_types)
43-
return JitDPPYKernel(pyfunc, ordered_arg_access_types)
45+
return JitDPPYKernel(pyfunc, debug, ordered_arg_access_types)
4446

4547
return _kernel_autojit
4648

4749

4850
def _kernel_jit(signature, debug, access_types):
4951
argtypes, restype = sigutils.normalize_signature(signature)
52+
5053
if restype is not None and restype != types.void:
5154
msg = "DPPY kernel must have void return type but got {restype}"
5255
raise TypeError(msg.format(restype=restype))
5356

5457
def _wrapped(pyfunc):
58+
current_queue = dpctl.get_current_queue()
5559
ordered_arg_access_types = get_ordered_arg_access_types(pyfunc, access_types)
56-
return compile_kernel(None, pyfunc, argtypes, ordered_arg_access_types, debug)
60+
# We create an instance of JitDPPYKernel to make sure at call time
61+
# we are going through the caching mechanism.
62+
dppy_kernel = JitDPPYKernel(pyfunc, debug, ordered_arg_access_types)
63+
# This will make sure we are compiling eagerly.
64+
dppy_kernel.specialize(argtypes, current_queue)
65+
return dppy_kernel
5766

5867
return _wrapped
5968

numba_dppy/tests/kernel_tests/test_atomic_op.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,10 @@ def f(a):
200200
LLVM_SPIRV_ROOT_old_val = dppy.config.LLVM_SPIRV_ROOT
201201
dppy.config.LLVM_SPIRV_ROOT = LLVM_SPIRV_ROOT
202202

203-
with dpctl.device_context(filter_str):
204-
kern = kernel[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a)
203+
with dpctl.device_context(filter_str) as sycl_queue:
204+
kern = kernel[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
205+
kernel._get_argtypes(a), sycl_queue
206+
)
205207
if filter_str != "opencl:cpu:0":
206208
assert "__spirv_AtomicFAddEXT" in kern.assembly
207209
else:
@@ -212,6 +214,8 @@ def f(a):
212214

213215
# To bypass caching
214216
kernel = dppy.kernel(f)
215-
with dpctl.device_context(filter_str):
216-
kern = kernel[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a)
217+
with dpctl.device_context(filter_str) as sycl_queue:
218+
kern = kernel[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
219+
kernel._get_argtypes(a), sycl_queue
220+
)
217221
assert "__spirv_AtomicFAddEXT" not in kern.assembly

numba_dppy/tests/kernel_tests/test_barrier.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import platform
16+
1517
import numpy as np
16-
import numba_dppy as dppy
1718
import pytest
1819
import dpctl
20+
21+
import numba_dppy as dppy
1922
from numba_dppy.tests._helper import skip_test
2023

2124

@@ -31,11 +34,26 @@ def filter_str(request):
3134
return request.param
3235

3336

37+
def skip_if_win():
38+
return platform.system == "Windows"
39+
40+
3441
def test_proper_lowering(filter_str):
3542
if skip_test(filter_str):
3643
pytest.skip()
37-
# @dppy.kernel("void(float32[::1])")
38-
@dppy.kernel
44+
45+
# We perform eager compilation at the site of
46+
# @dppy.kernel. This takes the default dpctl
47+
# queue which is level_zero backed. Level_zero
48+
# is not yet supported on Windows platform and
49+
# hence we skip these tests if the platform is
50+
# Windows regardless of which backend filter_str
51+
# specifies.
52+
if skip_if_win():
53+
pytest.skip()
54+
55+
# This will trigger eager compilation
56+
@dppy.kernel("void(float32[::1])")
3957
def twice(A):
4058
i = dppy.get_global_id(0)
4159
d = A[i]
@@ -56,8 +74,11 @@ def twice(A):
5674
def test_no_arg_barrier_support(filter_str):
5775
if skip_test(filter_str):
5876
pytest.skip()
59-
# @dppy.kernel("void(float32[::1])")
60-
@dppy.kernel
77+
78+
if skip_if_win():
79+
pytest.skip()
80+
81+
@dppy.kernel("void(float32[::1])")
6182
def twice(A):
6283
i = dppy.get_global_id(0)
6384
d = A[i]
@@ -81,8 +102,10 @@ def test_local_memory(filter_str):
81102
pytest.skip()
82103
blocksize = 10
83104

84-
# @dppy.kernel("void(float32[::1])")
85-
@dppy.kernel
105+
if skip_if_win():
106+
pytest.skip()
107+
108+
@dppy.kernel("void(float32[::1])")
86109
def reverse_array(A):
87110
lm = dppy.local.array(shape=10, dtype=np.float32)
88111
i = dppy.get_global_id(0)

numba_dppy/tests/kernel_tests/test_caching.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,16 @@ def data_parallel_sum(a, b, c):
5050
b = np.array(np.random.random(N), dtype=np.float32)
5151
c = np.ones_like(a)
5252

53-
with dpctl.device_context(filter_str):
53+
with dpctl.device_context(filter_str) as gpu_queue:
5454
func = dppy.kernel(data_parallel_sum)
55-
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
55+
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
56+
func._get_argtypes(a, b, c), gpu_queue
57+
)
5658

5759
for i in range(10):
58-
_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
60+
_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
61+
func._get_argtypes(a, b, c), gpu_queue
62+
)
5963
assert _kernel == cached_kernel
6064

6165

@@ -83,9 +87,14 @@ def data_parallel_sum(a, b, c):
8387
# created for that device
8488
dpctl.set_global_queue(filter_str)
8589
func = dppy.kernel(data_parallel_sum)
86-
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
90+
default_queue = dpctl.get_current_queue()
91+
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
92+
func._get_argtypes(a, b, c), default_queue
93+
)
8794
for i in range(0, 10):
8895
# Each iteration create a fresh queue that will share the same context
89-
with dpctl.device_context(filter_str):
90-
_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
96+
with dpctl.device_context(filter_str) as gpu_queue:
97+
_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
98+
func._get_argtypes(a, b, c), gpu_queue
99+
)
91100
assert _kernel == cached_kernel

numba_dppy/tests/test_debug.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def debug_option(request):
3232
return request.param
3333

3434

35-
def get_kernel_ir(fn, sig, debug=False):
36-
kernel = compiler.compile_kernel(fn.sycl_queue, fn.py_func, sig, None, debug=debug)
35+
def get_kernel_ir(sycl_queue, fn, sig, debug=False):
36+
kernel = compiler.compile_kernel(sycl_queue, fn.py_func, sig, None, debug=debug)
3737
return kernel.assembly
3838

3939

@@ -62,9 +62,9 @@ def test_debug_flag_generates_ir_with_debuginfo(offload_device, debug_option):
6262
def foo(x):
6363
return x
6464

65-
with dpctl.device_context(offload_device):
65+
with dpctl.device_context(offload_device) as sycl_queue:
6666
sig = (types.int32,)
67-
kernel_ir = get_kernel_ir(foo, sig, debug=debug_option)
67+
kernel_ir = get_kernel_ir(sycl_queue, foo, sig, debug=debug_option)
6868

6969
expect = debug_option
7070
got = make_check(kernel_ir)

0 commit comments

Comments
 (0)