Skip to content

Fix eager compilation with signature for dppy.kernel #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jun 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions numba_dppy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def compile_kernel(sycl_queue, pyfunc, args, access_types, debug=False):
print("compile_kernel", args)
debug = True
if not sycl_queue:
# This will be get_current_queue
sycl_queue = dpctl.get_current_queue()
# We expect the sycl_queue to be provided when this function is called
raise ValueError("SYCL queue is required for compiling a kernel")

cres = compile_with_dppy(pyfunc, None, args, debug=debug)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
Expand Down Expand Up @@ -561,32 +561,42 @@ def check_for_invalid_access_type(self, access_type):


class JitDPPYKernel(DPPYKernelBase):
def __init__(self, func, access_types):
def __init__(self, func, debug, access_types):

super(JitDPPYKernel, self).__init__()

self.py_func = func
self.definitions = {}
self.debug = debug
self.access_types = access_types

from .descriptor import dppy_target

self.typingctx = dppy_target.typing_context

def _get_argtypes(self, *args):
"""
Convenience function to get the type of each argument.
"""
return tuple([self.typingctx.resolve_argument_type(a) for a in args])

def __call__(self, *args, **kwargs):
assert not kwargs, "Keyword Arguments are not supported"
if self.sycl_queue is None:
try:
self.sycl_queue = dpctl.get_current_queue()
except:
_raise_no_device_found_error()
try:
current_queue = dpctl.get_current_queue()
except:
_raise_no_device_found_error()

kernel = self.specialize(*args)
argtypes = self._get_argtypes(*args)
kernel = self.specialize(argtypes, current_queue)
cfg = kernel.configure(self.sycl_queue, self.global_size, self.local_size)
cfg(*args)

def specialize(self, *args):
argtypes = tuple([self.typingctx.resolve_argument_type(a) for a in args])
def specialize(self, argtypes, queue):
# We specialize for argtypes and queue. These two are used as key for
# caching as well.
assert queue is not None

sycl_ctx = None
kernel = None
# we were previously using the _env_ptr of the device_env, the sycl_queue
Expand All @@ -598,11 +608,9 @@ def specialize(self, *args):
if result:
sycl_ctx, kernel = result

if sycl_ctx and sycl_ctx == self.sycl_queue.sycl_context:
if sycl_ctx and sycl_ctx == queue.sycl_context:
return kernel
else:
kernel = compile_kernel(
self.sycl_queue, self.py_func, argtypes, self.access_types
)
self.definitions[key_definitions] = (self.sycl_queue.sycl_context, kernel)
kernel = compile_kernel(queue, self.py_func, argtypes, self.access_types)
self.definitions[key_definitions] = (queue.sycl_context, kernel)
return kernel
13 changes: 11 additions & 2 deletions numba_dppy/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dpctl
from numba.core import sigutils, types

from .compiler import (
compile_kernel,
JitDPPYKernel,
Expand Down Expand Up @@ -40,20 +42,27 @@ def kernel(signature=None, access_types=None, debug=False):
def autojit(debug=False, access_types=None):
def _kernel_autojit(pyfunc):
ordered_arg_access_types = get_ordered_arg_access_types(pyfunc, access_types)
return JitDPPYKernel(pyfunc, ordered_arg_access_types)
return JitDPPYKernel(pyfunc, debug, ordered_arg_access_types)

return _kernel_autojit


def _kernel_jit(signature, debug, access_types):
argtypes, restype = sigutils.normalize_signature(signature)

if restype is not None and restype != types.void:
msg = "DPPY kernel must have void return type but got {restype}"
raise TypeError(msg.format(restype=restype))

def _wrapped(pyfunc):
current_queue = dpctl.get_current_queue()
ordered_arg_access_types = get_ordered_arg_access_types(pyfunc, access_types)
return compile_kernel(None, pyfunc, argtypes, ordered_arg_access_types, debug)
# We create an instance of JitDPPYKernel to make sure at call time
# we are going through the caching mechanism.
dppy_kernel = JitDPPYKernel(pyfunc, debug, ordered_arg_access_types)
# This will make sure we are compiling eagerly.
dppy_kernel.specialize(argtypes, current_queue)
return dppy_kernel

return _wrapped

Expand Down
12 changes: 8 additions & 4 deletions numba_dppy/tests/kernel_tests/test_atomic_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,10 @@ def f(a):
LLVM_SPIRV_ROOT_old_val = dppy.config.LLVM_SPIRV_ROOT
dppy.config.LLVM_SPIRV_ROOT = LLVM_SPIRV_ROOT

with dpctl.device_context(filter_str):
kern = kernel[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a)
with dpctl.device_context(filter_str) as sycl_queue:
kern = kernel[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
kernel._get_argtypes(a), sycl_queue
)
if filter_str != "opencl:cpu:0":
assert "__spirv_AtomicFAddEXT" in kern.assembly
else:
Expand All @@ -212,6 +214,8 @@ def f(a):

# To bypass caching
kernel = dppy.kernel(f)
with dpctl.device_context(filter_str):
kern = kernel[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a)
with dpctl.device_context(filter_str) as sycl_queue:
kern = kernel[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
kernel._get_argtypes(a), sycl_queue
)
assert "__spirv_AtomicFAddEXT" not in kern.assembly
37 changes: 30 additions & 7 deletions numba_dppy/tests/kernel_tests/test_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import platform

import numpy as np
import numba_dppy as dppy
import pytest
import dpctl

import numba_dppy as dppy
from numba_dppy.tests._helper import skip_test


Expand All @@ -31,11 +34,26 @@ def filter_str(request):
return request.param


def skip_if_win():
    """Report whether the current platform is Windows.

    Used by the tests in this module to skip kernels that are compiled
    eagerly at decoration time when running on Windows.

    Returns:
        bool: True when ``platform.system()`` reports ``"Windows"``,
        False otherwise.
    """
    # platform.system is a function; comparing the function object itself
    # against a string is always False, so it must be called.
    return platform.system() == "Windows"


def test_proper_lowering(filter_str):
if skip_test(filter_str):
pytest.skip()
# @dppy.kernel("void(float32[::1])")
@dppy.kernel

# We perform eager compilation at the site of
# @dppy.kernel. This takes the default dpctl
# queue which is level_zero backed. Level_zero
# is not yet supported on Windows platform and
# hence we skip these tests if the platform is
# Windows regardless of which backend filter_str
# specifies.
if skip_if_win():
pytest.skip()

# This will trigger eager compilation
@dppy.kernel("void(float32[::1])")
def twice(A):
i = dppy.get_global_id(0)
d = A[i]
Expand All @@ -56,8 +74,11 @@ def twice(A):
def test_no_arg_barrier_support(filter_str):
if skip_test(filter_str):
pytest.skip()
# @dppy.kernel("void(float32[::1])")
@dppy.kernel

if skip_if_win():
pytest.skip()

@dppy.kernel("void(float32[::1])")
def twice(A):
i = dppy.get_global_id(0)
d = A[i]
Expand All @@ -81,8 +102,10 @@ def test_local_memory(filter_str):
pytest.skip()
blocksize = 10

# @dppy.kernel("void(float32[::1])")
@dppy.kernel
if skip_if_win():
pytest.skip()

@dppy.kernel("void(float32[::1])")
def reverse_array(A):
lm = dppy.local.array(shape=10, dtype=np.float32)
i = dppy.get_global_id(0)
Expand Down
21 changes: 15 additions & 6 deletions numba_dppy/tests/kernel_tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,16 @@ def data_parallel_sum(a, b, c):
b = np.array(np.random.random(N), dtype=np.float32)
c = np.ones_like(a)

with dpctl.device_context(filter_str):
with dpctl.device_context(filter_str) as gpu_queue:
func = dppy.kernel(data_parallel_sum)
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
func._get_argtypes(a, b, c), gpu_queue
)

for i in range(10):
_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
func._get_argtypes(a, b, c), gpu_queue
)
assert _kernel == cached_kernel


Expand Down Expand Up @@ -83,9 +87,14 @@ def data_parallel_sum(a, b, c):
# created for that device
dpctl.set_global_queue(filter_str)
func = dppy.kernel(data_parallel_sum)
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
default_queue = dpctl.get_current_queue()
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
func._get_argtypes(a, b, c), default_queue
)
for i in range(0, 10):
# Each iteration create a fresh queue that will share the same context
with dpctl.device_context(filter_str):
_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
with dpctl.device_context(filter_str) as gpu_queue:
_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
func._get_argtypes(a, b, c), gpu_queue
)
assert _kernel == cached_kernel
8 changes: 4 additions & 4 deletions numba_dppy/tests/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def debug_option(request):
return request.param


def get_kernel_ir(fn, sig, debug=False):
kernel = compiler.compile_kernel(fn.sycl_queue, fn.py_func, sig, None, debug=debug)
def get_kernel_ir(sycl_queue, fn, sig, debug=False):
kernel = compiler.compile_kernel(sycl_queue, fn.py_func, sig, None, debug=debug)
return kernel.assembly


Expand Down Expand Up @@ -62,9 +62,9 @@ def test_debug_flag_generates_ir_with_debuginfo(offload_device, debug_option):
def foo(x):
return x

with dpctl.device_context(offload_device):
with dpctl.device_context(offload_device) as sycl_queue:
sig = (types.int32,)
kernel_ir = get_kernel_ir(foo, sig, debug=debug_option)
kernel_ir = get_kernel_ir(sycl_queue, foo, sig, debug=debug_option)

expect = debug_option
got = make_check(kernel_ir)
Expand Down