Skip to content

Commit 71838b9

Browse files
author
KristofferC
committed
Revert "Remove llvm-muladd pass and move it's functionality to to llvm-simdloop (#55802)"
This reverts commit 69ed5fd.
1 parent 69ed5fd commit 71838b9

File tree

10 files changed

+201
-96
lines changed

10 files changed

+201
-96
lines changed

doc/src/devdocs/llvm-passes.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ This pass is used to verify Julia's invariants about LLVM IR. This includes thin
114114

115115
These passes are used to perform transformations on LLVM IR that LLVM will not perform itself, e.g. fast math flag propagation, escape analysis, and optimizations on Julia-specific internal functions. They use knowledge about Julia's semantics to perform these optimizations.
116116

117+
### CombineMulAdd
118+
119+
* Filename: `llvm-muladd.cpp`
120+
* Class Name: `CombineMulAddPass`
121+
* Opt Name: `function(CombineMulAdd)`
122+
123+
This pass serves to optimize the particular combination of a regular `fmul` with a fast `fadd` into a contract `fmul` with a fast `fadd`. This is later optimized by the backend to a [fused multiply-add](https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation#Fused_multiply%E2%80%93add) instruction, which can provide significantly faster operations at the cost of more [unpredictable semantics](https://simonbyrne.github.io/notes/fastmath/).
124+
125+
!!! note
126+
127+
This optimization only occurs when the `fmul` has a single use, which is the fast `fadd`.
128+
117129
### AllocOpt
118130

119131
* Filename: `llvm-alloc-opt.cpp`

doc/src/devdocs/llvm.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ The code for lowering Julia AST to LLVM IR or interpreting it directly is in dir
3030
| `llvm-julia-licm.cpp` | Custom LLVM pass to hoist/sink Julia-specific intrinsics |
3131
| `llvm-late-gc-lowering.cpp` | Custom LLVM pass to root GC-tracked values |
3232
| `llvm-lower-handlers.cpp` | Custom LLVM pass to lower try-catch blocks |
33+
| `llvm-muladd.cpp` | Custom LLVM pass for fast-match FMA |
3334
| `llvm-multiversioning.cpp` | Custom LLVM pass to generate sysimg code on multiple architectures |
3435
| `llvm-propagate-addrspaces.cpp` | Custom LLVM pass to canonicalize addrspaces |
3536
| `llvm-ptls.cpp` | Custom LLVM pass to lower TLS operations |

src/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ RT_LLVMLINK :=
5252
CG_LLVMLINK :=
5353

5454
ifeq ($(JULIACODEGEN),LLVM)
55-
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop \
55+
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop llvm-muladd \
5656
llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering llvm-ptls \
5757
llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \
5858
llvm-multiversioning llvm-alloc-opt llvm-alloc-helpers cgmemmgr llvm-remove-addrspaces \

src/llvm-muladd.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
#include "llvm-version.h"
4+
#include "passes.h"
5+
6+
#include <llvm-c/Core.h>
7+
#include <llvm-c/Types.h>
8+
9+
#include <llvm/ADT/Statistic.h>
10+
#include <llvm/Analysis/OptimizationRemarkEmitter.h>
11+
#include <llvm/IR/Value.h>
12+
#include <llvm/IR/PassManager.h>
13+
#include <llvm/IR/Function.h>
14+
#include <llvm/IR/Instructions.h>
15+
#include <llvm/IR/IntrinsicInst.h>
16+
#include <llvm/IR/Module.h>
17+
#include <llvm/IR/Operator.h>
18+
#include <llvm/IR/IRBuilder.h>
19+
#include <llvm/IR/Verifier.h>
20+
#include <llvm/Pass.h>
21+
#include <llvm/Support/Debug.h>
22+
23+
#include "julia.h"
24+
#include "julia_assert.h"
25+
26+
#define DEBUG_TYPE "combine-muladd"
27+
#undef DEBUG
28+
29+
using namespace llvm;
30+
STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");
31+
32+
#ifndef __clang_gcanalyzer__
33+
#define REMARK(remark) ORE.emit(remark)
34+
#else
35+
#define REMARK(remark) (void) 0;
36+
#endif
37+
38+
/**
39+
* Combine
40+
* ```
41+
* %v0 = fmul ... %a, %b
42+
* %v = fadd contract ... %v0, %c
43+
* ```
44+
* to
45+
* `%v = call contract @llvm.fmuladd.<...>(... %a, ... %b, ... %c)`
46+
* when `%v0` has no other use
47+
*/
48+
49+
// Return true if we changed the mulOp
50+
static bool checkCombine(Value *maybeMul, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT
51+
{
52+
auto mulOp = dyn_cast<Instruction>(maybeMul);
53+
if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
54+
return false;
55+
if (!mulOp->hasOneUse()) {
56+
LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n");
57+
REMARK([&](){
58+
return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp)
59+
<< "fmul had multiple uses " << ore::NV("fmul", mulOp);
60+
});
61+
return false;
62+
}
63+
// On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
64+
auto fmf = mulOp->getFastMathFlags();
65+
if (!fmf.allowContract()) {
66+
LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n");
67+
REMARK([&](){
68+
return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp)
69+
<< "marked for fma " << ore::NV("fmul", mulOp);
70+
});
71+
++TotalContracted;
72+
fmf.setAllowContract(true);
73+
mulOp->copyFastMathFlags(fmf);
74+
return true;
75+
}
76+
return false;
77+
}
78+
79+
static bool combineMulAdd(Function &F) JL_NOTSAFEPOINT
80+
{
81+
OptimizationRemarkEmitter ORE(&F);
82+
bool modified = false;
83+
for (auto &BB: F) {
84+
for (auto it = BB.begin(); it != BB.end();) {
85+
auto &I = *it;
86+
it++;
87+
switch (I.getOpcode()) {
88+
case Instruction::FAdd: {
89+
if (!I.hasAllowContract())
90+
continue;
91+
modified |= checkCombine(I.getOperand(0), ORE) || checkCombine(I.getOperand(1), ORE);
92+
break;
93+
}
94+
case Instruction::FSub: {
95+
if (!I.hasAllowContract())
96+
continue;
97+
modified |= checkCombine(I.getOperand(0), ORE) || checkCombine(I.getOperand(1), ORE);
98+
break;
99+
}
100+
default:
101+
break;
102+
}
103+
}
104+
}
105+
#ifdef JL_VERIFY_PASSES
106+
assert(!verifyLLVMIR(F));
107+
#endif
108+
return modified;
109+
}
110+
111+
PreservedAnalyses CombineMulAddPass::run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT
112+
{
113+
if (combineMulAdd(F)) {
114+
return PreservedAnalyses::allInSet<CFGAnalyses>();
115+
}
116+
return PreservedAnalyses::all();
117+
}

src/llvm-simdloop.cpp

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ STATISTIC(ReductionChainLength, "Total sum of instructions folded from reduction
4141
STATISTIC(MaxChainLength, "Max length of reduction chain");
4242
STATISTIC(AddChains, "Addition reduction chains");
4343
STATISTIC(MulChains, "Multiply reduction chains");
44-
STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");
4544

4645
#ifndef __clang_gcanalyzer__
4746
#define REMARK(remark) ORE.emit(remark)
@@ -50,49 +49,6 @@ STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");
5049
#endif
5150
namespace {
5251

53-
/**
54-
* Combine
55-
* ```
56-
* %v0 = fmul ... %a, %b
57-
* %v = fadd contract ... %v0, %c
58-
* ```
59-
* to
60-
* %v0 = fmul contract ... %a, %b
61-
* %v = fadd contract ... %v0, %c
62-
* when `%v0` has no other use
63-
*/
64-
65-
static bool checkCombine(Value *maybeMul, Loop &L, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT
66-
{
67-
auto mulOp = dyn_cast<Instruction>(maybeMul);
68-
if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
69-
return false;
70-
if (!L.contains(mulOp))
71-
return false;
72-
if (!mulOp->hasOneUse()) {
73-
LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n");
74-
REMARK([&](){
75-
return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp)
76-
<< "fmul had multiple uses " << ore::NV("fmul", mulOp);
77-
});
78-
return false;
79-
}
80-
// On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
81-
auto fmf = mulOp->getFastMathFlags();
82-
if (!fmf.allowContract()) {
83-
LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n");
84-
REMARK([&](){
85-
return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp)
86-
<< "marked for fma " << ore::NV("fmul", mulOp);
87-
});
88-
++TotalContracted;
89-
fmf.setAllowContract(true);
90-
mulOp->copyFastMathFlags(fmf);
91-
return true;
92-
}
93-
return false;
94-
}
95-
9652
static unsigned getReduceOpcode(Instruction *J, Instruction *operand) JL_NOTSAFEPOINT
9753
{
9854
switch (J->getOpcode()) {
@@ -194,28 +150,6 @@ static void enableUnsafeAlgebraIfReduction(PHINode *Phi, Loop &L, OptimizationRe
194150
});
195151
(*K)->setHasAllowReassoc(true);
196152
(*K)->setHasAllowContract(true);
197-
switch ((*K)->getOpcode()) {
198-
case Instruction::FAdd: {
199-
if (!(*K)->hasAllowContract())
200-
continue;
201-
// (*K)->getOperand(0)->print(dbgs());
202-
// (*K)->getOperand(1)->print(dbgs());
203-
checkCombine((*K)->getOperand(0), L, ORE);
204-
checkCombine((*K)->getOperand(1), L, ORE);
205-
break;
206-
}
207-
case Instruction::FSub: {
208-
if (!(*K)->hasAllowContract())
209-
continue;
210-
// (*K)->getOperand(0)->print(dbgs());
211-
// (*K)->getOperand(1)->print(dbgs());
212-
checkCombine((*K)->getOperand(0), L, ORE);
213-
checkCombine((*K)->getOperand(1), L, ORE);
214-
break;
215-
}
216-
default:
217-
break;
218-
}
219153
if (SE)
220154
SE->forgetValue(*K);
221155
++length;

src/passes.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@ struct DemoteFloat16Pass : PassInfoMixin<DemoteFloat16Pass> {
1515
static bool isRequired() { return true; }
1616
};
1717

18-
struct LateLowerGCPass : PassInfoMixin<LateLowerGCPass> {
18+
struct CombineMulAddPass : PassInfoMixin<CombineMulAddPass> {
1919
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
20-
static bool isRequired() { return true; }
2120
};
2221

23-
struct CombineMulAddPass : PassInfoMixin<CombineMulAddPass> {
24-
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT {
25-
// no-op
26-
return PreservedAnalyses::all();
27-
}
22+
struct LateLowerGCPass : PassInfoMixin<LateLowerGCPass> {
23+
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
24+
static bool isRequired() { return true; }
2825
};
2926

3027
struct AllocOptPass : PassInfoMixin<AllocOptPass> {

src/pipeline.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ static void buildCleanupPipeline(ModulePassManager &MPM, PassBuilder *PB, Optimi
577577
if (options.cleanup) {
578578
if (O.getSpeedupLevel() >= 2) {
579579
FunctionPassManager FPM;
580+
JULIA_PASS(FPM.addPass(CombineMulAddPass()));
580581
FPM.addPass(DivRemPairsPass());
581582
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
582583
}

test/llvmpasses/julia-simdloop.ll

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,26 +61,6 @@ loopdone:
6161
ret double %nextv
6262
}
6363

64-
; CHECK-LABEL: @simd_test_sub4(
65-
define double @simd_test_sub4(double *%a) {
66-
top:
67-
br label %loop
68-
loop:
69-
%i = phi i64 [0, %top], [%nexti, %loop]
70-
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
71-
%aptr = getelementptr double, double *%a, i64 %i
72-
%aval = load double, double *%aptr
73-
%nextv2 = fmul double %aval, %aval
74-
; CHECK: fmul contract double %aval, %aval
75-
%nextv = fsub double %v, %nextv2
76-
; CHECK: fsub reassoc contract double %v, %nextv2
77-
%nexti = add i64 %i, 1
78-
%done = icmp sgt i64 %nexti, 500
79-
br i1 %done, label %loopdone, label %loop, !llvm.loop !0
80-
loopdone:
81-
ret double %nextv
82-
}
83-
8464
; Tests if we correctly pass through other metadata
8565
; CHECK-LABEL: @disabled(
8666
define i32 @disabled(i32* noalias nocapture %a, i32* noalias nocapture readonly %b, i32 %N) {
@@ -104,7 +84,6 @@ for.end: ; preds = %for.body
10484
ret i32 %1
10585
}
10686

107-
10887
!0 = distinct !{!0, !"julia.simdloop"}
10988
!1 = distinct !{!1, !"julia.simdloop", !"julia.ivdep"}
11089
!2 = distinct !{!2, !"julia.simdloop", !"julia.ivdep", !3}

test/llvmpasses/muladd.ll

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
; This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
; RUN: opt -enable-new-pm=1 --opaque-pointers=0 --load-pass-plugin=libjulia-codegen%shlibext -passes='CombineMulAdd' -S %s | FileCheck %s
4+
5+
; RUN: opt -enable-new-pm=1 --opaque-pointers=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='CombineMulAdd' -S %s | FileCheck %s
6+
7+
8+
; CHECK-LABEL: @fast_muladd1
9+
define double @fast_muladd1(double %a, double %b, double %c) {
10+
top:
11+
; CHECK: {{contract|fmuladd}}
12+
%v1 = fmul double %a, %b
13+
%v2 = fadd fast double %v1, %c
14+
; CHECK: ret double
15+
ret double %v2
16+
}
17+
18+
; CHECK-LABEL: @fast_mulsub1
19+
define double @fast_mulsub1(double %a, double %b, double %c) {
20+
top:
21+
; CHECK: {{contract|fmuladd}}
22+
%v1 = fmul double %a, %b
23+
%v2 = fsub fast double %v1, %c
24+
; CHECK: ret double
25+
ret double %v2
26+
}
27+
28+
; CHECK-LABEL: @fast_mulsub_vec1
29+
define <2 x double> @fast_mulsub_vec1(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
30+
top:
31+
; CHECK: {{contract|fmuladd}}
32+
%v1 = fmul <2 x double> %a, %b
33+
%v2 = fsub fast <2 x double> %c, %v1
34+
; CHECK: ret <2 x double>
35+
ret <2 x double> %v2
36+
}
37+
38+
; COM: Should not mark fmul as contract when multiple uses of fmul exist
39+
; CHECK-LABEL: @slow_muladd1
40+
define double @slow_muladd1(double %a, double %b, double %c) {
41+
top:
42+
; CHECK: %v1 = fmul double %a, %b
43+
%v1 = fmul double %a, %b
44+
; CHECK: %v2 = fadd fast double %v1, %c
45+
%v2 = fadd fast double %v1, %c
46+
; CHECK: %v3 = fadd fast double %v1, %b
47+
%v3 = fadd fast double %v1, %b
48+
; CHECK: %v4 = fadd fast double %v3, %v2
49+
%v4 = fadd fast double %v3, %v2
50+
; CHECK: ret double %v4
51+
ret double %v4
52+
}
53+
54+
; COM: Should not mark fadd->fadd fast as contract
55+
; CHECK-LABEL: @slow_addadd1
56+
define double @slow_addadd1(double %a, double %b, double %c) {
57+
top:
58+
; CHECK: %v1 = fadd double %a, %b
59+
%v1 = fadd double %a, %b
60+
; CHECK: %v2 = fadd fast double %v1, %c
61+
%v2 = fadd fast double %v1, %c
62+
; CHECK: ret double %v2
63+
ret double %v2
64+
}

test/llvmpasses/parsing.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
; COM: NewPM-only test, tests for ability to parse Julia passes
22

3-
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='module(CPUFeatures,RemoveNI,JuliaMultiVersioning,RemoveJuliaAddrspaces,LowerPTLSPass,function(DemoteFloat16,LateLowerGCFrame,FinalLowerGC,AllocOpt,PropagateJuliaAddrspaces,LowerExcHandlers,GCInvariantVerifier,loop(LowerSIMDLoop,JuliaLICM),GCInvariantVerifier<strong>,GCInvariantVerifier<no-strong>),LowerPTLSPass<imaging>,LowerPTLSPass<no-imaging>,JuliaMultiVersioning<external>,JuliaMultiVersioning<no-external>)' -S %s -o /dev/null
3+
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='module(CPUFeatures,RemoveNI,JuliaMultiVersioning,RemoveJuliaAddrspaces,LowerPTLSPass,function(DemoteFloat16,CombineMulAdd,LateLowerGCFrame,FinalLowerGC,AllocOpt,PropagateJuliaAddrspaces,LowerExcHandlers,GCInvariantVerifier,loop(LowerSIMDLoop,JuliaLICM),GCInvariantVerifier<strong>,GCInvariantVerifier<no-strong>),LowerPTLSPass<imaging>,LowerPTLSPass<no-imaging>,JuliaMultiVersioning<external>,JuliaMultiVersioning<no-external>)' -S %s -o /dev/null
44
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia<level=3;llvm_only>" -S %s -o /dev/null
55
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia<level=3;no_llvm_only>" -S %s -o /dev/null
66
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia<level=3;no_enable_vector_pipeline>" -S %s -o /dev/null

0 commit comments

Comments
 (0)