Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cupyx/jit/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,9 +664,9 @@ def _transpile_stmt(
if isinstance(stmt, ast.Pass):
return [';']
if isinstance(stmt, ast.Break):
raise NotImplementedError('Not implemented.')
return ['break;']
if isinstance(stmt, ast.Continue):
raise NotImplementedError('Not implemented.')
return ['continue;']
assert False


Expand Down
71 changes: 71 additions & 0 deletions tests/cupyx_tests/jit_tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,77 @@ def f(x, m):
y[:mask] += 1
assert bool((x == y).all())

def test_loop_continue(self):
@jit.rawkernel()
def f(x, y, z):
tid = jit.grid(1)

for i in range(10):
# adds 0-9, except for 5.
# Sum is 40
if i == 5:
continue
x[tid] += i

i2 = 0
while i2 < 9:
# adds 1-9 in a while loop, except for 6,
# should equal 39
i2 += 1
if i2 == 6:
continue
y[tid] += i2

for i in range(11):
# adds 0-10, but skips if the sum is greater than 3*i,
# skips 8 and 9, but not 10 (28 < 3*10), sum is 38
if z[tid] > 3*i:
continue
z[tid] += i

x = cupy.zeros(32, dtype=int)
y = cupy.zeros(32, dtype=int)
z = cupy.zeros(32, dtype=int)
f[1, 32](x, y, z)
assert bool((x == 40).all())
assert bool((y == 39).all())
assert bool((z == 38).all())

def test_loop_break(self):
@jit.rawkernel()
def f(x, y, z):
tid = jit.grid(1)

for i in range(10):
# adds 0-4,
# break at 5. Sum is 10
if i == 5:
break
x[tid] += i

i2 = 0
while i2 < 9:
# adds 1-5 in a while loop,
# breaks at 6, should equal 15
i2 += 1
if i2 == 6:
break
y[tid] += i2
for i in range(11):
# adds 0-10, but stops once the sum is greater than 3*i,
# breaks at 8 (28 > 3*8), sum is 28
if z[tid] > 3*i:
break
z[tid] += i

x = cupy.zeros(32, dtype=int)
y = cupy.zeros(32, dtype=int)
z = cupy.zeros(32, dtype=int)
f[1, 32](x, y, z)
assert bool((x == 10).all())
assert bool((y == 15).all())
assert bool((z == 28).all())

def test_shared_memory_static(self):
@jit.rawkernel()
def f(x, y):
Expand Down