Skip to content

Commit 2ecfc94

Browse files
committed
add loop pragmas
1 parent 1adbff5 commit 2ecfc94

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

base/Base.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ include("util.jl")
347347

348348
include("asyncmap.jl")
349349

350+
# various loop pragmas
351+
include("pragmas.jl")
352+
350353
# deprecated functions
351354
include("deprecated.jl")
352355

base/pragma.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
module Pragma
2+
3+
export @unroll
4+
5+
##
6+
# Uses the loopinfo expr node to attach LLVM loopinfo to loops
7+
# the full list of supported metadata nodes is available at
8+
# https://llvm.org/docs/LangRef.html#llvm-loop
9+
# TODO:
10+
# - Figure out how to deal with compile-time constants in `@unroll(N, expr)`
11+
# so constants that come from `Val{N}` but are not parse time constant.
12+
# - Difference between `unroll_enable` and `unroll_full`
13+
# - ? Expose `unroll_disable`
14+
# - ? Expose `jam_disable`
15+
##
16+
17+
module MD
18+
disable_nonforced() = (Symbol("llvm.loop.disable_nonforced"),)
19+
interleave(n) = (Symbol("llvm.loop.interleave.count"), convert(Int, n))
20+
vectorize_enable(flag) = (Symbol("llvm.loop.vectorize.enable"), convert(Bool, flag))
21+
vectorize_width(n) = (Symbol("llvm.loop.vectorize.width"), convert(Int, n))
22+
# ‘llvm.loop.vectorize.followup_vectorized’
23+
# ‘llvm.loop.vectorize.followup_epilogue’
24+
# ‘llvm.loop.vectorize.followup_all’
25+
unroll_count(n) = (Symbol("llvm.loop.unroll.count"), convert(Int, n))
26+
unroll_disable() = (Symbol("llvm.loop.unroll.disable"),)
27+
unroll_enable() = (Symbol("llvm.loop.unroll.enable"),)
28+
unroll_full() = (Symbol("llvm.loop.unroll.full"),)
29+
# ‘llvm.loop.unroll.followup’
30+
# ‘llvm.loop.unroll.followup_remainder’
31+
jam_count(n) = (Symbol("llvm.loop.unroll_and_jam.count"), convert(Int, n))
32+
jam_disable() = (Symbol("llvm.loop.unroll_and_jam.disable"),)
33+
jam_enable() = (Symbol("llvm.loop.unroll_and_jam.enable"),)
34+
# ‘llvm.loop.unroll_and_jam.followup_outer’
35+
# ‘llvm.loop.unroll_and_jam.followup_inner’
36+
# ‘llvm.loop.unroll_and_jam.followup_remainder_outer’
37+
# ‘llvm.loop.unroll_and_jam.followup_remainder_inner’
38+
# ‘llvm.loop.unroll_and_jam.followup_all’
39+
# ‘llvm.loop.licm_versioning.disable’
40+
# ‘llvm.loop.distribute.enable’
41+
# ‘llvm.loop.distribute.followup_coincident’
42+
# ‘llvm.loop.distribute.followup_sequential’
43+
# ‘llvm.loop.distribute.followup_fallback’
44+
# ‘llvm.loop.distribute.followup_all’
45+
end
46+
47+
function loopinfo(name, expr, nodes...)
48+
if expr.head != :for
49+
error("Syntax error: pragma $name needs a for loop")
50+
end
51+
push!(expr.args[2].args, Expr(:loopinfo, nodes...))
52+
return expr
53+
end
54+
55+
"""
56+
@unroll expr
57+
58+
Takes a for loop as `expr` and informs the LLVM unroller to fully unroll it, if
59+
it is safe to do so and the loop count is known.
60+
"""
61+
macro unroll(expr)
62+
expr = loopinfo("@unroll", expr, MD.unroll_full())
63+
return esc(expr)
64+
end
65+
66+
"""
67+
@unroll N expr
68+
69+
Takes a for loop as `expr` and informs the LLVM unroller to unroll it `N` times,
70+
if it is safe to do so.
71+
"""
72+
macro unroll(N, expr)
73+
if !(N isa Integer)
74+
error("Syntax error: `@unroll N expr` needs a constant integer N")
75+
end
76+
expr = loopinfo("@unroll", expr, MD.unroll_count(N))
77+
return esc(expr)
78+
end
79+
80+
macro jam(N, expr)
81+
if !(N isa Integer)
82+
error("Syntax error: `@jam N expr` needs a constant integer N")
83+
end
84+
expr = loopinfo("@jam", expr, MD.jam_count(N))
85+
return esc(expr)
86+
end
87+
88+
macro jam(expr)
89+
expr = loopinfo("@jam", expr, MD.jam_enable())
90+
return esc(expr)
91+
end
92+
93+
end #module

0 commit comments

Comments
 (0)