Skip to content

Commit 73a0794

Browse files
Python script adopted from IntelPython/numba_dpex#147
This is an adaptation of pipelining technique shared by @mbecker in https://github.com/IntelPython/numbda_dpex/issues/147 This is built to work with async-ref-count-increment branch IntelPython/dpctl#1395 which implements asynchronous memcpy, asynchronous submit and asynchronous keep_arg_alve task submission.
1 parent 8855115 commit 73a0794

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed

gh_147_pipeline.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import numpy as np
2+
import sys
3+
import time
4+
5+
import dpctl
6+
import dpctl.program
7+
import dpctl.tensor as dpt
8+
import ctypes
9+
10+
spirv_file = "./increment_by_one.spv"
11+
with open(spirv_file, "rb") as fin:
12+
spirv = fin.read()
13+
program_cache = dict()
14+
15+
def increment_by_one(an_array, gws, lws):
16+
q = an_array.sycl_queue
17+
if q.sycl_context in program_cache:
18+
prog = program_cache[q.sycl_context]
19+
else:
20+
global spirv
21+
prog = dpctl.program.create_program_from_spirv(q, spirv)
22+
krn = prog.get_sycl_kernel("increment_by_one")
23+
24+
args = [an_array.usm_data, ctypes.c_uint32(an_array.size),]
25+
return q.submit_async(krn, args, [gws,], [lws,])
26+
27+
28+
def run_serial(a, gws, lws, n_itr):
29+
q = dpctl.SyclQueue(property=["in_order", "enable_profiling"])
30+
31+
timer_t = dpctl.SyclTimer()
32+
timer_c = dpctl.SyclTimer()
33+
34+
a_host = dpt.asarray(a, usm_type="host", sycl_queue=q)
35+
a_host_data = a_host.usm_data
36+
37+
t0 = time.time()
38+
for _ in range(n_itr):
39+
with timer_t(q):
40+
_a = dpt.empty(a_host.shape, usm_type="device", sycl_queue=q)
41+
_a_data = _a.usm_data
42+
e_copy = q.memcpy_async(_a.usm_data, a_host_data, a_host_data.nbytes)
43+
44+
with timer_c(q):
45+
e_compute = increment_by_one(_a, gws, lws)
46+
47+
q.wait()
48+
dt = time.time() - t0
49+
50+
return dt, timer_t.dt, timer_c.dt
51+
52+
53+
def run_pipeline(a, gws, lws, n_itr):
54+
q_a = dpctl.SyclQueue(property=["in_order", "enable_profiling"])
55+
q_b = dpctl.SyclQueue(property=["in_order", "enable_profiling"])
56+
57+
timer_t = dpctl.SyclTimer()
58+
timer_c = dpctl.SyclTimer()
59+
60+
a_host = dpt.asarray(a, usm_type="host", sycl_queue=q_a)
61+
a_host_data = a_host.usm_data
62+
63+
t0 = time.time()
64+
with timer_t(q_a):
65+
_a = dpt.empty(a_host.shape, usm_type="device", sycl_queue=q_a)
66+
_a_data = _a.usm_data
67+
e_copy_a = q_a.memcpy_async(_a_data, a_host_data, a_host_data.nbytes)
68+
69+
for i in range(n_itr-1):
70+
if i % 2 == 0:
71+
with timer_t(q_b):
72+
_b = dpt.empty(a_host.shape, usm_type="device", sycl_queue=q_b)
73+
_b_data = _b.usm_data
74+
e_copy_b = q_b.memcpy_async(_b_data, a_host_data, a_host_data.nbytes)
75+
76+
with timer_c(q_a):
77+
e_compute_a = increment_by_one(_a, gws, lws)
78+
79+
else:
80+
with timer_t(q_a):
81+
_a = dpt.empty(a_host.shape, usm_type="device", sycl_queue=q_a)
82+
_a_data = _a.usm_data
83+
e_copy_a = q_a.memcpy_async(_a_data, a_host_data, a_host_data.nbytes)
84+
85+
with timer_c(q_b):
86+
e_compute_b = increment_by_one(_b, gws, lws)
87+
88+
if n_itr % 2 == 0:
89+
with timer_c(q_b):
90+
e_compute_b = increment_by_one(_b, gws, lws)
91+
else:
92+
with timer_c(q_a):
93+
e_compute_a = increment_by_one(_a, gws, lws)
94+
95+
q_a.wait()
96+
q_b.wait()
97+
dt = time.time() - t0
98+
99+
return dt, timer_t.dt, timer_c.dt
100+
101+
102+
if len(sys.argv) > 1:
103+
n = int(sys.argv[1])
104+
else:
105+
n = 2_000_000
106+
107+
if len(sys.argv) > 2:
108+
n_itr = int(sys.argv[2])
109+
else:
110+
n_itr = 100
111+
112+
113+
print("timing %d elements for %d iterations" % (n, n_itr), flush=True)
114+
115+
print("using %f MB of memory" % (n * 4 /1024/1024), flush=True)
116+
117+
a = np.arange(n, dtype=np.float32)
118+
119+
lws = 32
120+
gws = ((a.size + (lws - 1)) // lws) * lws
121+
122+
dtp = run_pipeline(a, gws, lws, n_itr)
123+
print(f"pipeline time tot|pci|cmp|speedup: {dtp}", flush=True)
124+
125+
dts = run_serial(a, gws, lws, n_itr)
126+
print(f"serial time tot|pci|cmp|speedup: {dts}", flush=True)

0 commit comments

Comments
 (0)