diff --git a/llvm/include/llvm/Transforms/Utils/Cloning.h b/llvm/include/llvm/Transforms/Utils/Cloning.h
index 6b56230a6e1d4..c8c15edb55c56 100644
--- a/llvm/include/llvm/Transforms/Utils/Cloning.h
+++ b/llvm/include/llvm/Transforms/Utils/Cloning.h
@@ -23,6 +23,7 @@
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/FMF.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
@@ -87,6 +88,8 @@ struct ClonedCodeInfo {
   /// check whether the main VMap mapping involves simplification or not.
   DenseMap<const Value *, const Value *> OrigVMap;
 
+  FastMathFlags FMFs;
+
   ClonedCodeInfo() = default;
 
   bool isSimplified(const Value *From, const Value *To) const {
diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp
index fccb73a36b182..3442fd040735f 100644
--- a/llvm/lib/Transforms/Utils/CloneFunction.cpp
+++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp
@@ -700,6 +700,51 @@ void PruningFunctionCloner::CloneBlock(
   }
 }
 
+/// Propagate the call-site fast-math flags \p FMFs to the users of OldFunc's
+/// arguments in the cloned function, if applicable.
+static void propagateFastMathFlags(const Function *OldFunc,
+                                   ValueToValueMapTy &VMap,
+                                   const FastMathFlags &FMFs) {
+  if (!FMFs.any())
+    return;
+
+  // Visit all instructions reachable from the arguments of OldFunc. This
+  // ensures we only visit instructions in the original function. The
+  // call-site fast-math flags FMFs are set on all applicable instructions
+  // in the new function (retrieved via VMap).
+
+  DenseSet<const Value *> Visited;
+  SmallVector<const Instruction *> Worklist;
+  for (const Argument &Arg : OldFunc->args()) {
+    Visited.insert(&Arg);
+    for (const User *U : Arg.users()) {
+      if (Visited.insert(U).second)
+        Worklist.push_back(cast<Instruction>(U));
+    }
+  }
+
+  while (!Worklist.empty()) {
+    const Instruction *CurrentOld = Worklist.pop_back_val();
+    Instruction *Current = dyn_cast<Instruction>(VMap.lookup(CurrentOld));
+    if (!Current || !isa<FPMathOperator>(Current))
+      continue;
+
+    // TODO: Assumes all FP ops propagate the flags from args to the result, if
+    // all operands have the same flags.
+    if (!all_of(CurrentOld->operands(),
+                [&Visited](Value *V) { return Visited.contains(V); }))
+      continue;
+
+    Current->setFastMathFlags(Current->getFastMathFlags() | FMFs);
+
+    // Add all users of this instruction to the worklist.
+    for (const User *U : CurrentOld->users()) {
+      if (Visited.insert(U).second)
+        Worklist.push_back(cast<Instruction>(U));
+    }
+  }
+}
+
 /// This works like CloneAndPruneFunctionInto, except that it does not clone the
 /// entire function. Instead it starts at an instruction provided by the caller
 /// and copies (and prunes) only the code reachable from that instruction.
@@ -996,6 +1041,8 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc,
        I != E; ++I)
     if (ReturnInst *RI = dyn_cast<ReturnInst>(I->getTerminator()))
       Returns.push_back(RI);
+
+  propagateFastMathFlags(OldFunc, VMap, CodeInfo->FMFs);
 }
 
 /// This works exactly like CloneFunctionInto,
diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index 7df5e9958182c..a3af09d124e1f 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -2697,6 +2697,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
     // have no dead or constant instructions leftover after inlining occurs
     // (which can happen, e.g., because an argument was constant), but we'll be
     // happy with whatever the cloner can do.
+    InlinedFunctionInfo.FMFs =
+        isa<FPMathOperator>(&CB) ? CB.getFastMathFlags() : FastMathFlags();
     CloneAndPruneFunctionInto(Caller, CalledFunc, VMap,
                               /*ModuleLevelChanges=*/false, Returns, ".i",
                               &InlinedFunctionInfo);
diff --git a/llvm/test/Transforms/Inline/propagate-fast-math-flags-to-inlined-instructions.ll b/llvm/test/Transforms/Inline/propagate-fast-math-flags-to-inlined-instructions.ll
new file mode 100644
index 0000000000000..1d5becea14014
--- /dev/null
+++ b/llvm/test/Transforms/Inline/propagate-fast-math-flags-to-inlined-instructions.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals none --version 5
+; RUN: opt -p inline -S %s | FileCheck %s
+
+@g = external global float
+
+define float @add(float %a, float %b) {
+; CHECK-LABEL: define float @add(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = fadd float [[A]], [[B]]
+; CHECK-NEXT:    ret float [[ADD]]
+;
+  %add = fadd float %a, %b
+  ret float %add
+}
+
+define float @caller1(float %a, float %b) {
+; CHECK-LABEL: define float @caller1(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc float [[A]], [[B]]
+; CHECK-NEXT:    ret float [[ADD_I]]
+;
+  %r = call reassoc float @add(float %a, float %b)
+  ret float %r
+}
+
+define float @add_with_unrelated_fp_math(float %a, float %b) {
+; CHECK-LABEL: define float @add_with_unrelated_fp_math(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[L:%.*]] = load float, ptr @g, align 4
+; CHECK-NEXT:    [[RES:%.*]] = fmul float [[L]], [[A]]
+; CHECK-NEXT:    store float [[RES]], ptr @g, align 4
+; CHECK-NEXT:    [[ADD:%.*]] = fadd float [[A]], [[B]]
+; CHECK-NEXT:    ret float [[ADD]]
+;
+  %l = load float, ptr @g
+  %res = fmul float %l, %a
+  store float %res, ptr @g
+  %add = fadd float %a, %b
+  ret float %add
+}
+
+; Make sure the call-site fast-math flags are not added to instructions where
+; not all operands have the new fast-math flags.
+define float @caller2(float %a, float %b) {
+; CHECK-LABEL: define float @caller2(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[L_I:%.*]] = load float, ptr @g, align 4
+; CHECK-NEXT:    [[RES_I:%.*]] = fmul float [[L_I]], [[A]]
+; CHECK-NEXT:    store float [[RES_I]], ptr @g, align 4
+; CHECK-NEXT:    [[ADD_I:%.*]] = fadd nnan float [[A]], [[B]]
+; CHECK-NEXT:    ret float [[ADD_I]]
+;
+  %r = call nnan float @add_with_unrelated_fp_math(float %a, float %b)
+  ret float %r
+}
+
+define float @add_with_nnan(float %a, float %b) {
+; CHECK-LABEL: define float @add_with_nnan(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = fadd nnan float [[A]], [[B]]
+; CHECK-NEXT:    ret float [[ADD]]
+;
+  %add = fadd nnan float %a, %b
+  ret float %add
+}
+
+; Make sure the fast-math flags on the original instruction are kept and the
+; call-site flags are added.
+define float @caller3(float %a, float %b) {
+; CHECK-LABEL: define float @caller3(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[ADD_I:%.*]] = fadd nnan ninf float [[A]], [[B]]
+; CHECK-NEXT:    ret float [[ADD_I]]
+;
+  %r = call ninf float @add_with_nnan(float %a, float %b)
+  ret float %r
+}
+
+; Make sure the fast-math flags don't get accidentally propagated to
+; instructions in the caller, reachable via the passed arguments.
+define float @caller4(float %a, float %b) {
+; CHECK-LABEL: define float @caller4(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[ADD_I:%.*]] = fadd ninf float [[A]], [[B]]
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv float [[A]], [[B]]
+; CHECK-NEXT:    [[ADD:%.*]] = fadd float [[ADD_I]], [[DIV]]
+; CHECK-NEXT:    ret float [[ADD]]
+;
+  %r = call ninf float @add(float %a, float %b)
+  %div = fdiv float %a, %b
+  %add = fadd float %r, %div
+  ret float %add
+}