Skip to content

Commit efaf290

Browse files
committed
Calculate whether an inlining is safe
It's safe if all top level arguments are used at most once, meaning that there's no risk of duplication.
1 parent 251d77b commit efaf290

File tree

6 files changed

+206
-31
lines changed

6 files changed

+206
-31
lines changed

libs/prelude/Prelude/Interfaces.idr

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ public export
103103
($>) fa b = map (const b) fa
104104

105105
||| Run something for effects, throwing away the return value.
106+
%inline
106107
public export
107108
ignore : Functor f => f a -> f ()
108109
ignore = map (const ())

src/TTImp/Elab/Case.idr

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module TTImp.Elab.Case
22

3+
import Core.CaseTree
34
import Core.Context
45
import Core.Context.Log
56
import Core.Core
@@ -226,13 +227,6 @@ caseBlock {vars} rigc elabinfo fc nest env scr scrtm scrty caseRig alts expected
226227
setFlag fc (Resolved cidx) (SetTotal PartialOK)
227228
let caseRef : Term vars = Ref fc Func (Resolved cidx)
228229

229-
-- If there's no duplication of the scrutinee in the block,
230-
-- inline it.
231-
-- This will be the case either if the scrutinee is a variable, in
232-
-- which case the duplication won't hurt, or if (TODO) none of the
233-
-- case patterns in alts are just a variable
234-
maybe (pure ()) (const (setFlag fc casen Inline)) splitOn
235-
236230
let applyEnv = applyToFull fc caseRef env
237231
let appTm : Term vars
238232
= maybe (App fc applyEnv scrtm)
@@ -254,6 +248,17 @@ caseBlock {vars} rigc elabinfo fc nest env scr scrtm scrty caseRig alts expected
254248
let olddelayed = delayedElab ust
255249
put UST (record { delayedElab = [] } ust)
256250
processDecl [InCase] nest' [] (IDef fc casen alts')
251+
252+
-- If there's no duplication of the scrutinee in the block,
253+
-- flag it as inlinable.
254+
-- This will be the case either if the scrutinee is a variable, in
255+
-- which case the duplication won't hurt, or if there's no variable
256+
-- duplicated in the body (what ghc calls W-safe)
257+
-- We'll check that second condition later, after generating the
258+
-- runtime (erased) case trees
259+
let inlineOK = maybe False (const True) splitOn
260+
when inlineOK $ setFlag fc casen Inline
261+
257262
ust <- get UST
258263
put UST (record { delayedElab = olddelayed } ust)
259264

src/TTImp/Elab/Utils.idr

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module TTImp.Elab.Utils
22

3+
import Core.CaseTree
34
import Core.Context
45
import Core.Core
56
import Core.Env
@@ -116,3 +117,169 @@ bindReq {vs = n :: _} fc (b :: env) (KeepCons p) ns tm
116117
(Bind fc _ (Pi (binderLoc b) (multiplicity b) Explicit (binderType b')) tm)
117118
bindReq fc (b :: env) (DropCons p) ns tm
118119
= bindReq fc env p ns tm
120+
121+
-- This machinery is to calculate whether any top level argument is used
122+
-- more than once in a case block, in which case inlining wouldn't be safe
123+
-- since it might duplicate work.
124+
125+
data ArgUsed = Used1 -- been used
126+
| Used0 -- not used
127+
| LocalVar -- don't care if it's used
128+
129+
data Usage : List Name -> Type where
130+
Nil : Usage []
131+
(::) : ArgUsed -> Usage xs -> Usage (x :: xs)
132+
133+
initUsed : (xs : List Name) -> Usage xs
134+
initUsed [] = []
135+
initUsed (x :: xs) = Used0 :: initUsed xs
136+
137+
initUsedCase : (xs : List Name) -> Usage xs
138+
initUsedCase [] = []
139+
initUsedCase [x] = [Used0]
140+
initUsedCase (x :: xs) = LocalVar :: initUsedCase xs
141+
142+
setUsedVar : {idx : _} ->
143+
(0 _ : IsVar n idx xs) -> Usage xs -> Usage xs
144+
setUsedVar First (Used0 :: us) = Used1 :: us
145+
setUsedVar (Later p) (x :: us) = x :: setUsedVar p us
146+
setUsedVar First us = us
147+
148+
isUsed : {idx : _} ->
149+
(0 _ : IsVar n idx xs) -> Usage xs -> Bool
150+
isUsed First (Used1 :: us) = True
151+
isUsed First (_ :: us) = False
152+
isUsed (Later p) (_ :: us) = isUsed p us
153+
154+
data Used : Type where
155+
156+
setUsed : {idx : _} ->
157+
{auto u : Ref Used (Usage vars)} ->
158+
(0 _ : IsVar n idx vars) -> Core ()
159+
setUsed p
160+
= do used <- get Used
161+
put Used (setUsedVar p used)
162+
163+
extendUsed : ArgUsed -> (new : List Name) -> Usage vars -> Usage (new ++ vars)
164+
extendUsed a [] x = x
165+
extendUsed a (y :: xs) x = a :: extendUsed a xs x
166+
167+
dropUsed : (new : List Name) -> Usage (new ++ vars) -> Usage vars
168+
dropUsed [] x = x
169+
dropUsed (x :: xs) (u :: us) = dropUsed xs us
170+
171+
inExtended : ArgUsed -> (new : List Name) ->
172+
{auto u : Ref Used (Usage vars)} ->
173+
(Ref Used (Usage (new ++ vars)) -> Core a) ->
174+
Core a
175+
inExtended a new sc
176+
= do used <- get Used
177+
u' <- newRef Used (extendUsed a new used)
178+
res <- sc u'
179+
put Used (dropUsed new !(get Used @{u'}))
180+
pure res
181+
182+
termInlineSafe : {vars : _} ->
183+
{auto u : Ref Used (Usage vars)} ->
184+
Term vars -> Core Bool
185+
termInlineSafe (Local fc isLet idx p)
186+
= if isUsed p !(get Used)
187+
then pure False
188+
else do setUsed p
189+
pure True
190+
termInlineSafe (Meta fc x y xs)
191+
= allInlineSafe xs
192+
where
193+
allInlineSafe : List (Term vars) -> Core Bool
194+
allInlineSafe [] = pure True
195+
allInlineSafe (x :: xs)
196+
= do xok <- termInlineSafe x
197+
if xok
198+
then allInlineSafe xs
199+
else pure False
200+
termInlineSafe (Bind fc x b scope)
201+
= do bok <- binderInlineSafe b
202+
if bok
203+
then inExtended LocalVar [x] (\u' => termInlineSafe scope)
204+
else pure False
205+
where
206+
binderInlineSafe : Binder (Term vars) -> Core Bool
207+
binderInlineSafe (Let _ _ val _) = termInlineSafe val
208+
binderInlineSafe _ = pure True
209+
termInlineSafe (App fc fn arg)
210+
= do fok <- termInlineSafe fn
211+
if fok
212+
then termInlineSafe arg
213+
else pure False
214+
termInlineSafe (As fc x as pat) = termInlineSafe pat
215+
termInlineSafe (TDelayed fc x ty) = termInlineSafe ty
216+
termInlineSafe (TDelay fc x ty arg) = termInlineSafe arg
217+
termInlineSafe (TForce fc x val) = termInlineSafe val
218+
termInlineSafe _ = pure True
219+
220+
mutual
221+
caseInlineSafe : {vars : _} ->
222+
{auto u : Ref Used (Usage vars)} ->
223+
CaseTree vars -> Core Bool
224+
caseInlineSafe (Case idx p scTy xs)
225+
= if isUsed p !(get Used)
226+
then pure False
227+
else do setUsed p
228+
altsSafe xs
229+
where
230+
altsSafe : List (CaseAlt vars) -> Core Bool
231+
altsSafe [] = pure True
232+
altsSafe (a :: as)
233+
= do u <- get Used
234+
aok <- caseAltInlineSafe a
235+
if aok
236+
then do -- We can reset the usage information, because we're
237+
-- only going to use one alternative at a time
238+
put Used u
239+
altsSafe as
240+
else pure False
241+
caseInlineSafe (STerm x tm) = termInlineSafe tm
242+
caseInlineSafe (Unmatched msg) = pure True
243+
caseInlineSafe Impossible = pure True
244+
245+
caseAltInlineSafe : {vars : _} ->
246+
{auto u : Ref Used (Usage vars)} ->
247+
CaseAlt vars -> Core Bool
248+
caseAltInlineSafe (ConCase x tag args sc)
249+
= inExtended Used0 args (\u' => caseInlineSafe sc)
250+
caseAltInlineSafe (DelayCase ty arg sc)
251+
= inExtended Used0 [ty, arg] (\u' => caseInlineSafe sc)
252+
caseAltInlineSafe (ConstCase x sc) = caseInlineSafe sc
253+
caseAltInlineSafe (DefaultCase sc) = caseInlineSafe sc
254+
255+
-- An inlining is safe if no variable is used more than once in the tree,
256+
-- which means that there's no risk of an input being evaluated more than
257+
-- once after the definition is expanded.
258+
export
259+
inlineSafe : {vars : _} ->
260+
CaseTree vars -> Core Bool
261+
inlineSafe t
262+
= do u <- newRef Used (initUsed vars)
263+
caseInlineSafe t
264+
265+
export
266+
canInlineDef : {auto c : Ref Ctxt Defs} ->
267+
Name -> Core Bool
268+
canInlineDef n
269+
= do defs <- get Ctxt
270+
Just (PMDef _ _ _ rtree _) <- lookupDefExact n (gamma defs)
271+
| _ => pure False
272+
inlineSafe rtree
273+
274+
-- This is a special case because the only argument we actually care about
275+
-- is the last one, since the others are just variables passed through from
276+
-- the environment, and duplicating a variable doesn't cost anything.
277+
export
278+
canInlineCaseBlock : {auto c : Ref Ctxt Defs} ->
279+
Name -> Core Bool
280+
canInlineCaseBlock n
281+
= do defs <- get Ctxt
282+
Just (PMDef _ vars _ rtree _) <- lookupDefExact n (gamma defs)
283+
| _ => pure False
284+
u <- newRef Used (initUsedCase vars)
285+
caseInlineSafe rtree

src/TTImp/ProcessDef.idr

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,22 @@ mkRunTime fc n
750750
ignore $ addDef n $
751751
record { definition = PMDef r rargs tree_ct tree_rt pats
752752
} gdef
753+
-- If it's a case block, and not already set as inlinable,
754+
-- check if it's safe to inline
755+
when (caseName !(toFullNames n) && noInline (flags gdef)) $
756+
do inl <- canInlineCaseBlock n
757+
when inl $ setFlag fc n Inline
753758
where
759+
noInline : List DefFlag -> Bool
760+
noInline (Inline :: _) = False
761+
noInline (x :: xs) = noInline xs
762+
noInline _ = True
763+
764+
caseName : Name -> Bool
765+
caseName (CaseBlock _ _) = True
766+
caseName (NS _ n) = caseName n
767+
caseName _ = False
768+
754769
mkCrash : {vars : _} -> String -> Term vars
755770
mkCrash msg
756771
= apply fc (Ref fc Func (NS builtinNS (UN "idris_crash")))

tests/codegen/builtin001/expected

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@ prim__mul_Integer = [{arg:N}, {arg:N}]: (*Integer [!{arg:N}, !{arg:N}])
55
Main.plus = [{arg:N}, {arg:N}]: (%case !{arg:N} [(%constcase 0 !{arg:N})] Just (%let {e:N} (-Integer [!{arg:N}, 1]) (+Integer [1, (Main.plus [!{e:N}, !{arg:N}])])))
66
Main.main = [{ext:N}]: (Main.plus [(+Integer [1, 0]), (+Integer [1, (+Integer [1, 0])])])
77
Builtin.believe_me = [{ext:N}]: (believe_me [___, ___, !{ext:N}])
8-
Prelude.Types.case block in prim__integerToNat = [{arg:N}, {arg:N}]: (%case !{arg:N} [(%constcase 1 (Builtin.believe_me [!{arg:N}])), (%constcase 0 0)] Nothing)
9-
Prelude.Types.prim__integerToNat = [{arg:N}]: (Prelude.Types.case block in prim__integerToNat [!{arg:N}, (%case (<=Integer [0, !{arg:N}]) [(%constcase 0 0)] Just 1)])
10-
PrimIO.case block in unsafePerformIO = [{arg:N}, {arg:N}]: (PrimIO.unsafeDestroyWorld [___, !{arg:N}])
11-
PrimIO.unsafePerformIO = [{arg:N}]: (PrimIO.unsafeCreateWorld [(%lam w (PrimIO.case block in unsafePerformIO [!{arg:N}, (!{arg:N} [!w])]))])
8+
Prelude.Types.prim__integerToNat = [{arg:N}]: (%case (%case (<=Integer [0, !{arg:N}]) [(%constcase 0 0)] Just 1) [(%constcase 1 (Builtin.believe_me [!{arg:N}])), (%constcase 0 0)] Nothing)
9+
PrimIO.unsafePerformIO = [{arg:N}]: (PrimIO.unsafeCreateWorld [(%lam w (PrimIO.unsafeDestroyWorld [___, (!{arg:N} [!w])]))])
1210
PrimIO.unsafeDestroyWorld = [{arg:N}, {arg:N}]: !{arg:N}
1311
PrimIO.unsafeCreateWorld = [{arg:N}]: (!{arg:N} [%MkWorld])
1412

0 commit comments

Comments
 (0)