Skip to content

Fix eager compilation with signature for dppy.kernel #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jun 6, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions numba_dppy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def compile_kernel(sycl_queue, pyfunc, args, access_types, debug=False):
print("compile_kernel", args)
debug = True
if not sycl_queue:
# This will be get_current_queue
sycl_queue = dpctl.get_current_queue()
# We expect the sycl_queue to be provided when this function is called
raise ValueError("SYCL queue is required for compiling a kernel")

cres = compile_with_dppy(pyfunc, None, args, debug=debug)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
Expand Down Expand Up @@ -545,32 +545,39 @@ def check_for_invalid_access_type(self, access_type):


class JitDPPYKernel(DPPYKernelBase):
def __init__(self, func, access_types):
def __init__(self, func, debug, access_types):

super(JitDPPYKernel, self).__init__()

self.py_func = func
self.definitions = {}
self.debug = debug
self.access_types = access_types

from .descriptor import dppy_target

self.typingctx = dppy_target.typing_context

def get_argtypes(self, *args):
# Convenience function to get the type of each argument.
return tuple([self.typingctx.resolve_argument_type(a) for a in args])

def __call__(self, *args, **kwargs):
assert not kwargs, "Keyword Arguments are not supported"
if self.sycl_queue is None:
try:
self.sycl_queue = dpctl.get_current_queue()
except:
_raise_no_device_found_error()
try:
current_queue = dpctl.get_current_queue()
except:
_raise_no_device_found_error()

kernel = self.specialize(*args)
argtypes = self.get_argtypes(*args)
kernel = self.specialize(argtypes, current_queue)
cfg = kernel.configure(self.sycl_queue, self.global_size, self.local_size)
cfg(*args)

def specialize(self, *args):
argtypes = tuple([self.typingctx.resolve_argument_type(a) for a in args])
def specialize(self, argtypes, queue):
# We specialize for argtypes and queue. These two are also used as the
# key for caching.
assert queue is not None
q = None
kernel = None
# we were previously using the _env_ptr of the device_env, the sycl_queue
Expand All @@ -582,11 +589,9 @@ def specialize(self, *args):
if result:
q, kernel = result

if q and self.sycl_queue.equals(q):
if q and queue.equals(q):
return kernel
else:
kernel = compile_kernel(
self.sycl_queue, self.py_func, argtypes, self.access_types
)
self.definitions[key_definitions] = (self.sycl_queue, kernel)
kernel = compile_kernel(queue, self.py_func, argtypes, self.access_types)
self.definitions[key_definitions] = (queue, kernel)
return kernel
13 changes: 10 additions & 3 deletions numba_dppy/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function, absolute_import, division
from numba.core import sigutils, types
from .compiler import (
compile_kernel,
Expand All @@ -21,6 +20,7 @@
compile_dppy_func,
get_ordered_arg_access_types,
)
import dpctl


def kernel(signature=None, access_types=None, debug=False):
Expand All @@ -41,20 +41,27 @@ def kernel(signature=None, access_types=None, debug=False):
def autojit(debug=False, access_types=None):
def _kernel_autojit(pyfunc):
ordered_arg_access_types = get_ordered_arg_access_types(pyfunc, access_types)
return JitDPPYKernel(pyfunc, ordered_arg_access_types)
return JitDPPYKernel(pyfunc, debug, ordered_arg_access_types)

return _kernel_autojit


def _kernel_jit(signature, debug, access_types):
argtypes, restype = sigutils.normalize_signature(signature)

if restype is not None and restype != types.void:
msg = "DPPY kernel must have void return type but got {restype}"
raise TypeError(msg.format(restype=restype))

def _wrapped(pyfunc):
current_queue = dpctl.get_current_queue()
ordered_arg_access_types = get_ordered_arg_access_types(pyfunc, access_types)
return compile_kernel(None, pyfunc, argtypes, ordered_arg_access_types, debug)
# We create an instance of JitDPPYKernel so that, at call time, the
# kernel goes through the caching mechanism.
dppy_kernel = JitDPPYKernel(pyfunc, debug, ordered_arg_access_types)
# This will make sure we are compiling eagerly.
dppy_kernel.specialize(argtypes, current_queue)
return dppy_kernel

return _wrapped

Expand Down
68 changes: 41 additions & 27 deletions numba_dppy/tests/kernel_tests/test_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,18 @@ def filter_str(request):
def test_proper_lowering(filter_str):
if skip_test(filter_str):
pytest.skip()
# @dppy.kernel("void(float32[::1])")
@dppy.kernel
def twice(A):
i = dppy.get_global_id(0)
d = A[i]
dppy.barrier(dppy.CLK_LOCAL_MEM_FENCE) # local mem fence
A[i] = d * 2

try:
# This will trigger eager compilation
@dppy.kernel("void(float32[::1])")
def twice(A):
i = dppy.get_global_id(0)
d = A[i]
dppy.barrier(dppy.CLK_LOCAL_MEM_FENCE) # local mem fence
A[i] = d * 2

except:
pytest.skip()

N = 256
arr = np.random.random(N).astype(np.float32)
Expand All @@ -56,14 +61,19 @@ def twice(A):
def test_no_arg_barrier_support(filter_str):
if skip_test(filter_str):
pytest.skip()
# @dppy.kernel("void(float32[::1])")
@dppy.kernel
def twice(A):
i = dppy.get_global_id(0)
d = A[i]
# no argument defaults to global mem fence
dppy.barrier()
A[i] = d * 2

try:

@dppy.kernel("void(float32[::1])")
def twice(A):
i = dppy.get_global_id(0)
d = A[i]
# no argument defaults to global mem fence
dppy.barrier()
A[i] = d * 2

except:
pytest.skip()

N = 256
arr = np.random.random(N).astype(np.float32)
Expand All @@ -81,18 +91,22 @@ def test_local_memory(filter_str):
pytest.skip()
blocksize = 10

# @dppy.kernel("void(float32[::1])")
@dppy.kernel
def reverse_array(A):
lm = dppy.local.array(shape=10, dtype=np.float32)
i = dppy.get_global_id(0)

# preload
lm[i] = A[i]
# barrier local or global will both work as we only have one work group
dppy.barrier(dppy.CLK_LOCAL_MEM_FENCE) # local mem fence
# write
A[i] += lm[blocksize - 1 - i]
try:

@dppy.kernel("void(float32[::1])")
def reverse_array(A):
lm = dppy.local.array(shape=10, dtype=np.float32)
i = dppy.get_global_id(0)

# preload
lm[i] = A[i]
# barrier local or global will both work as we only have one work group
dppy.barrier(dppy.CLK_LOCAL_MEM_FENCE) # local mem fence
# write
A[i] += lm[blocksize - 1 - i]

except:
pytest.skip()

arr = np.arange(blocksize).astype(np.float32)
orig = arr.copy()
Expand Down
6 changes: 4 additions & 2 deletions numba_dppy/tests/kernel_tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def test_caching_kernel(filter_str):

with dpctl.device_context(filter_str) as gpu_queue:
func = dppy.kernel(data_parallel_sum)
caching_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(a, b, c)
caching_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
func.get_argtypes(a, b, c), gpu_queue
)

for i in range(10):
cached_kernel = func[global_size, dppy.DEFAULT_LOCAL_SIZE].specialize(
a, b, c
func.get_argtypes(a, b, c), gpu_queue
)
assert caching_kernel == cached_kernel