Skip to content

Commit 3b3e3b4

Browse files
Test numba and numba-dppy API with GDB (#566)
* Dppy vs numba debug test * Example test for GDB using pexpect * Skip of pexpect not installed * Fix using pathlib * Add API selection Co-authored-by: akharche <[email protected]>
1 parent 16d73ca commit 3b3e3b4

File tree

3 files changed

+163
-0
lines changed

3 files changed

+163
-0
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies:
2626
- black==20.8b1
2727
- pytest-cov
2828
- pytest-xdist
29+
- pexpect
2930
variables:
3031
CHANNELS: -c defaults -c numba -c intel -c numba/label/dev -c dppy/label/dev --override-channels
3132
CHANNELS_DEV: -c dppy/label/dev -c defaults -c numba -c intel -c numba/label/dev --override-channels
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2020, 2021 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
17+
import dpctl
18+
import numba
19+
import numpy as np
20+
21+
import numba_dppy as dppy
22+
23+
24+
def func(param_a, param_b):
25+
param_c = param_a + 10 # Set breakpoint
26+
param_d = param_b * 0.5
27+
result = param_c + param_d
28+
return result
29+
30+
31+
dppy_func = dppy.func(debug=True)(func)
32+
numba_func = numba.njit(debug=True)(func)
33+
34+
35+
@dppy.kernel(debug=True)
36+
def dppy_kernel(a_in_kernel, b_in_kernel, c_in_kernel):
37+
i = dppy.get_global_id(0)
38+
c_in_kernel[i] = dppy_func(a_in_kernel[i], b_in_kernel[i])
39+
40+
41+
@numba.njit(debug=True)
42+
def numba_func_driver(a, b, c):
43+
for i in range(len(c)):
44+
c[i] = numba_func(a[i], b[i])
45+
46+
47+
def main():
48+
parser = argparse.ArgumentParser()
49+
parser.add_argument(
50+
"--api",
51+
required=False,
52+
default="numba",
53+
choices=["numba", "numba-dppy"],
54+
help="Start the version of functions using numba or numba-dppy API",
55+
)
56+
57+
args = parser.parse_args()
58+
59+
print("Using API:", args.api)
60+
61+
global_size = 10
62+
N = global_size
63+
64+
a = np.arange(N, dtype=np.float32)
65+
b = np.arange(N, dtype=np.float32)
66+
c = np.empty_like(a)
67+
68+
if args.api == "numba-dppy":
69+
device = dpctl.select_default_device()
70+
with dppy.offload_to_sycl_device(device):
71+
dppy_kernel[global_size, dppy.DEFAULT_LOCAL_SIZE](a, b, c)
72+
else:
73+
numba_func_driver(a, b, c)
74+
75+
print("Done...")
76+
77+
78+
if __name__ == "__main__":
79+
main()
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#! /usr/bin/env python
2+
# Copyright 2021 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
import pathlib
18+
import shutil
19+
import sys
20+
21+
import pytest
22+
23+
import numba_dppy
24+
25+
pexpect = pytest.importorskip("pexpect")
26+
27+
pytestmark = pytest.mark.skipif(
28+
not shutil.which("gdb-oneapi"),
29+
reason="Intel® Distribution for GDB* is not available",
30+
)
31+
32+
33+
# TODO: go to helper
34+
class gdb:
35+
def __init__(self):
36+
self.spawn()
37+
self.setup_gdb()
38+
39+
def __del__(self):
40+
self.teardown_gdb()
41+
42+
def spawn(self):
43+
env = os.environ.copy()
44+
env["NUMBA_OPT"] = "0"
45+
46+
self.child = pexpect.spawn("gdb-oneapi -q python", env=env, encoding="utf-8")
47+
# self.child.logfile = sys.stdout
48+
49+
def setup_gdb(self):
50+
self.child.expect("(gdb)", timeout=5)
51+
self.child.sendline("set breakpoint pending on")
52+
self.child.expect("(gdb)", timeout=5)
53+
self.child.sendline("set style enabled off") # disable colors symbols
54+
55+
def teardown_gdb(self):
56+
self.child.expect("(gdb)", timeout=5)
57+
self.child.sendline("quit")
58+
self.child.expect("Quit anyway?", timeout=5)
59+
self.child.sendline("y")
60+
61+
def breakpoint(self, breakpoint):
62+
self.child.expect("(gdb)", timeout=5)
63+
self.child.sendline("break " + breakpoint)
64+
65+
def run(self, script):
66+
self.child.expect("(gdb)", timeout=5)
67+
self.child.sendline("run " + self.script_path(script))
68+
69+
@staticmethod
70+
def script_path(script):
71+
package_path = pathlib.Path(numba_dppy.__file__).parent
72+
return str(package_path / "examples/debug" / script)
73+
74+
75+
@pytest.mark.parametrize("api", ["numba", "numba-dppy"])
76+
def test_breakpoint_row_number(api):
77+
app = gdb()
78+
79+
app.breakpoint("dppy_numba_basic.py:25")
80+
app.run("dppy_numba_basic.py --api={api}".format(api=api))
81+
82+
app.child.expect(r"Thread .* hit Breakpoint .* at dppy_numba_basic.py:25")
83+
app.child.expect(r"25\s+param_c = param_a \+ 10")

0 commit comments

Comments
 (0)