Skip to content

Commit e1da898

Browse files
author
Diptorup Deb
committed
WIP changes...
1 parent 1e6a80f commit e1da898

File tree

11 files changed

+118
-104
lines changed

11 files changed

+118
-104
lines changed

numba_dppy/compiler.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import dpctl.program as dpctl_prog
2929
import numpy as np
3030

31-
from . import spirv_generator
31+
from . import spirv_generator, target
3232

3333
from numba.core.compiler import DefaultPassBuilder, CompilerBase
3434
from numba_dppy.dppy_parfor_diagnostics import ExtendedParforDiagnostics
@@ -143,15 +143,15 @@ def compile_kernel(sycl_queue, pyfunc, args, access_types, debug=False):
143143
if not sycl_queue:
144144
# This will be get_current_queue
145145
sycl_queue = dpctl.get_current_queue()
146-
breakpoint()
147146
cres = compile_with_dppy(pyfunc, None, args, debug=debug)
148-
func = cres.library.get_function(cres.fndesc.llvm_func_name)
149-
kernel = cres.target_context.prepare_ocl_kernel(func, cres.signature.args)
150-
# The kernel objet should have a reference to the target context it is compiled for.
151-
# This is needed as we intend to shape the behavior of the kernel down the line
152-
# depending on the target context. For example, we want to link our kernel object
153-
# with implementation containing atomic operations only when atomic operations
154-
# are being used in the kernel.
147+
kernel = cres.library.get_function(cres.fndesc.llvm_func_name)
148+
breakpoint()
149+
target.set_dppy_kernel(kernel)
150+
# kernel = cres.target_context.prepare_ocl_kernel(func, cres.signature.args)
151+
# A reference to the target context is stored in the DPPYKernel to
152+
# reference the context later in code generation. For example, we link
153+
# the kernel object with a spir_func defining atomic operations only
154+
# when atomic operations are used in the kernel.
155155
oclkern = DPPYKernel(
156156
context=cres.target_context,
157157
sycl_queue=sycl_queue,
@@ -170,7 +170,6 @@ def compile_kernel_parfor(sycl_queue, func_ir, args, args_with_addrspaces, debug
170170
print(a, type(a))
171171
if isinstance(a, types.npytypes.Array):
172172
print("addrspace:", a.addrspace)
173-
174173
cres = compile_with_dppy(func_ir, None, args_with_addrspaces, debug=debug)
175174
func = cres.library.get_function(cres.fndesc.llvm_func_name)
176175

@@ -269,7 +268,7 @@ def __init__(self, cres):
269268
def _ensure_valid_work_item_grid(val, sycl_queue):
270269

271270
if not isinstance(val, (tuple, list, int)):
272-
error_message = "Cannot create work item dimension from " "provided argument"
271+
error_message = "Cannot create work item dimension from provided argument"
273272
raise ValueError(error_message)
274273

275274
if isinstance(val, int):
@@ -290,7 +289,7 @@ def _ensure_valid_work_item_grid(val, sycl_queue):
290289
def _ensure_valid_work_group_size(val, work_item_grid):
291290

292291
if not isinstance(val, (tuple, list, int)):
293-
error_message = "Cannot create work item dimension from " "provided argument"
292+
error_message = "Cannot create work item dimension from provided argument"
294293
raise ValueError(error_message)
295294

296295
if isinstance(val, int):

numba_dppy/dppy_array_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ class DPPYArrayModel(StructModel):
8484
def __init__(self, dmm, fe_type):
8585
ndim = fe_type.ndim
8686
members = [
87-
("meminfo", types.MemInfoPointer(fe_type.dtype)),
88-
("parent", types.pyobject),
87+
("meminfo", types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace)),
88+
("parent", types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace)),
8989
("nitems", types.intp),
9090
("itemsize", types.intp),
9191
("data", types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace)),

numba_dppy/dppy_lowerer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
import dpctl
5757
from numba_dppy.target import DPPYTargetContext
5858
from numba_dppy.dppy_array_type import DPPYArray
59-
from numba_dppy.utils.constants import address_space
59+
from numba_dppy.utils import address_space
6060

6161

6262
def _print_block(block):
@@ -358,7 +358,7 @@ def addrspace_from(params, def_addr):
358358
addrspaces.append(None)
359359
return addrspaces
360360

361-
addrspaces = addrspace_from(parfor_params, address_space.SPIR_GLOBAL)
361+
addrspaces = addrspace_from(parfor_params, address_space.GLOBAL)
362362

363363
if config.DEBUG_ARRAY_OPT >= 1:
364364
print("parfor_params = ", parfor_params, type(parfor_params))

numba_dppy/ocl/ocldecl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from numba import types
16-
from numba.core.typing.npydecl import register_number_classes, parse_dtype, parse_shape
16+
from numba.core.typing.npydecl import parse_dtype, parse_shape
1717
from numba.core.typing.templates import (
1818
AttributeTemplate,
1919
ConcreteTemplate,
@@ -22,8 +22,7 @@
2222
signature,
2323
Registry,
2424
)
25-
import numba_dppy, numba_dppy as dppy
26-
from numba_dppy import target
25+
import numba_dppy as dppy
2726
from numba_dppy.dppy_array_type import DPPYArray
2827
from numba_dppy.utils import address_space
2928

@@ -169,7 +168,7 @@ def typer(shape, dtype):
169168
dtype=nb_dtype,
170169
ndim=ndim,
171170
layout="C",
172-
addrspace=address_space.SPIR_LOCAL,
171+
addrspace=address_space.LOCAL,
173172
)
174173

175174
return typer

numba_dppy/ocl/oclimpl.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from numba_dppy.codegen import SPIR_DATA_LAYOUT
3030
from numba_dppy.dppy_array_type import DPPYArray
3131
from numba_dppy.ocl.atomics import atomic_helper
32-
from numba_dppy.utils.constants import address_space
32+
from numba_dppy.utils import address_space
3333

3434
from . import stubs
3535

@@ -207,14 +207,14 @@ def insert_and_call_atomic_fn(
207207
else:
208208
raise TypeError("Atomic operation is not supported for type %s" % (dtype.name))
209209

210-
if addrspace == address_space.SPIR_LOCAL:
210+
if addrspace == address_space.LOCAL:
211211
name = name + "_local"
212212
else:
213213
name = name + "_global"
214214

215215
assert ll_p != None
216216
assert name != ""
217-
ll_p.addrspace = address_space.SPIR_GENERIC
217+
ll_p.addrspace = address_space.GENERIC
218218

219219
mod = builder.module
220220
if sig.return_type == types.void:
@@ -228,7 +228,7 @@ def insert_and_call_atomic_fn(
228228
fn = mod.get_or_insert_function(fnty, name)
229229
fn.calling_convention = target.CC_SPIR_FUNC
230230

231-
generic_ptr = context.addrspacecast(builder, ptr, address_space.SPIR_GENERIC)
231+
generic_ptr = context.addrspacecast(builder, ptr, address_space.GENERIC)
232232

233233
return builder.call(fn, [generic_ptr, val])
234234

@@ -272,10 +272,10 @@ def native_atomic_add(context, builder, sig, args):
272272

273273
ptr_type = context.get_value_type(dtype).as_pointer()
274274
if not hasattr(aryty, "addrspace"):
275-
ptr_type.addrspace = address_space.SPIR_GLOBAL
276-
ptr = context.addrspacecast(builder, ptr, address_space.SPIR_GLOBAL)
275+
ptr_type.addrspace = address_space.GLOBAL
276+
ptr = context.addrspacecast(builder, ptr, address_space.GLOBAL)
277277
else:
278-
ptr_type.addrspace = address_space.SPIR_LOCAL
278+
ptr_type.addrspace = address_space.LOCAL
279279
retty = context.get_value_type(sig.return_type)
280280
spirv_fn_arg_types = [
281281
ptr_type,
@@ -409,7 +409,7 @@ def atomic_add(context, builder, sig, args, name):
409409
lary = context.make_array(aryty)(context, builder, ary)
410410
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices)
411411

412-
if isinstance(aryty, DPPYArray) and aryty.addrspace == address_space.SPIR_LOCAL:
412+
if isinstance(aryty, DPPYArray) and aryty.addrspace == address_space.LOCAL:
413413
return insert_and_call_atomic_fn(
414414
context,
415415
builder,
@@ -418,7 +418,7 @@ def atomic_add(context, builder, sig, args, name):
418418
dtype,
419419
ptr,
420420
val,
421-
address_space.SPIR_LOCAL,
421+
address_space.LOCAL,
422422
)
423423
else:
424424
return insert_and_call_atomic_fn(
@@ -429,7 +429,7 @@ def atomic_add(context, builder, sig, args, name):
429429
dtype,
430430
ptr,
431431
val,
432-
address_space.SPIR_GLOBAL,
432+
address_space.GLOBAL,
433433
)
434434
else:
435435
raise ImportError("Atomic support is not present, can not perform atomic_add")
@@ -445,7 +445,7 @@ def dppy_local_array_integer(context, builder, sig, args):
445445
shape=(length,),
446446
dtype=dtype,
447447
symbol_name="_dppy_lmem",
448-
addrspace=address_space.SPIR_LOCAL,
448+
addrspace=address_space.LOCAL,
449449
)
450450

451451

@@ -460,7 +460,7 @@ def dppy_local_array_tuple(context, builder, sig, args):
460460
shape=shape,
461461
dtype=dtype,
462462
symbol_name="_dppy_lmem",
463-
addrspace=address_space.SPIR_LOCAL,
463+
addrspace=address_space.LOCAL,
464464
)
465465

466466

@@ -473,7 +473,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
473473
lldtype = context.get_data_type(dtype)
474474
laryty = Type.array(lldtype, elemcount)
475475

476-
if addrspace == address_space.SPIR_LOCAL:
476+
if addrspace == address_space.LOCAL:
477477
lmod = builder.module
478478

479479
# Create global variable in the requested address-space
@@ -506,7 +506,7 @@ def _make_array(
506506
dtype,
507507
shape,
508508
layout="C",
509-
addrspace=address_space.SPIR_GENERIC,
509+
addrspace=address_space.GENERIC,
510510
):
511511
ndim = len(shape)
512512
# Create array object

numba_dppy/printimpl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
from numba.core import types, typing, cgutils
2020
from numba.core.imputils import Registry
2121

22-
from numba_dppy.utils.constants import address_space
22+
from numba_dppy.utils import address_space
2323

2424
registry = Registry()
2525
lower = registry.lower
2626

2727

2828
def declare_print(lmod):
29-
voidptrty = lc.Type.pointer(lc.Type.int(8), addrspace=address_space.SPIR_GENERIC)
29+
voidptrty = lc.Type.pointer(lc.Type.int(8), addrspace=address_space.GENERIC)
3030
printfty = lc.Type.function(lc.Type.int(), [voidptrty], var_arg=True)
3131
printf = lmod.get_or_insert_function(printfty, "printf")
3232
return printf

0 commit comments

Comments
 (0)