Skip to content

Commit 609fb0e

Browse files
dhruvachak authored and ronlieb committed
[clang] [XteamReduction] Match analysis with CodeGen.
The analysis phase determines whether CodeGen will be able to generate an XteamReduction version of the kernel. But currently, the analysis does not fully match the CodeGen, so there are scenarios where the analysis will pass but CodeGen will fail. This patch fixes that problem.
Change-Id: Ic4c6315db8511dc2acb5d7eb7bff45c7d1cf100a
1 parent 03220e4 commit 609fb0e

File tree

4 files changed

+155
-75
lines changed

4 files changed

+155
-75
lines changed

clang/lib/CodeGen/CGStmt.cpp

Lines changed: 4 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -496,27 +496,13 @@ bool CodeGenFunction::EmitXteamRedStmt(const Stmt *S) {
496496
if (!isa<BinaryOperator>(S))
497497
return false;
498498

499-
auto getRedVarDecl =
500-
[](const Expr *E,
501-
const CodeGenModule::XteamRedVarMap &RVM) -> const VarDecl * {
502-
if (!isa<DeclRefExpr>(E))
503-
return nullptr;
504-
const ValueDecl *ValDecl = cast<DeclRefExpr>(E)->getDecl();
505-
if (!isa<VarDecl>(ValDecl))
506-
return nullptr;
507-
const VarDecl *VD = cast<VarDecl>(ValDecl);
508-
if (RVM.find(VD) == RVM.end())
509-
return nullptr;
510-
return VD;
511-
};
512-
513499
const BinaryOperator *RedBO = cast<BinaryOperator>(S);
514500
const CodeGenModule::XteamRedVarMap &RedVarMap =
515501
CGM.getXteamRedVarMap(CGM.getCurrentXteamRedStmt());
516502

517503
// Is a reduction variable the lhs?
518504
const VarDecl *RedVarDecl =
519-
getRedVarDecl(RedBO->getLHS()->IgnoreImpCasts(), RedVarMap);
505+
CGM.getXteamRedVarDecl(RedBO->getLHS()->IgnoreImpCasts(), RedVarMap);
520506
if (RedVarDecl == nullptr)
521507
return false;
522508

@@ -525,16 +511,6 @@ bool CodeGenFunction::EmitXteamRedStmt(const Stmt *S) {
525511
(RedBO->getOpcode() == BO_AddAssign || RedBO->getOpcode() == BO_Assign) &&
526512
"Unexpected operator during Xteam CodeGen");
527513

528-
auto isRedVarExpr = [](const Expr *E, const VarDecl *RedVarDecl) {
529-
if (!isa<DeclRefExpr>(E))
530-
return false;
531-
const ValueDecl *ValDecl = cast<DeclRefExpr>(E)->getDecl();
532-
if (!isa<VarDecl>(ValDecl))
533-
return false;
534-
const VarDecl *VD = cast<VarDecl>(ValDecl);
535-
return VD == RedVarDecl;
536-
};
537-
538514
// Extract the rhs for the reduction.
539515
const Expr *RedRHSExpr = nullptr;
540516
auto OpcRedBO = RedBO->getOpcode();
@@ -549,9 +525,10 @@ bool CodeGenFunction::EmitXteamRedStmt(const Stmt *S) {
549525
assert(OpcL2BO == BO_Add && "Unexpected operator");
550526
// If the redvar is lhs, use the rhs in the generated reduction statement
551527
// and vice-versa.
552-
if (isRedVarExpr(L2BO->getLHS()->IgnoreImpCasts(), RedVarDecl))
528+
if (CGM.isXteamRedVarExpr(L2BO->getLHS()->IgnoreImpCasts(), RedVarDecl))
553529
RedRHSExpr = L2BO->getRHS();
554-
else if (isRedVarExpr(L2BO->getRHS()->IgnoreImpCasts(), RedVarDecl))
530+
else if (CGM.isXteamRedVarExpr(L2BO->getRHS()->IgnoreImpCasts(),
531+
RedVarDecl))
555532
RedRHSExpr = L2BO->getLHS();
556533
else
557534
llvm_unreachable("Unhandled add expression during xteam reduction");

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 102 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7876,8 +7876,9 @@ class NoLoopStepChecker final : public ConstStmtVisitor<NoLoopStepChecker> {
78767876
/// red_var
78777877
class XteamRedExprChecker final : public ConstStmtVisitor<XteamRedExprChecker> {
78787878
public:
7879-
XteamRedExprChecker(const CodeGenModule::XteamRedVarMap &RVM)
7880-
: RedMap(RVM), IsSupported(true) {}
7879+
XteamRedExprChecker(CodeGenModule &CGM,
7880+
const CodeGenModule::XteamRedVarMap &RVM)
7881+
: CGM(CGM), RedMap(RVM), IsSupported(true) {}
78817882
XteamRedExprChecker() = delete;
78827883

78837884
bool isSupported() const { return IsSupported; }
@@ -7886,18 +7887,6 @@ class XteamRedExprChecker final : public ConstStmtVisitor<XteamRedExprChecker> {
78867887
if (!S)
78877888
return;
78887889

7889-
auto isExprXteamRedVar = [this](const Expr *E) {
7890-
if (!isa<DeclRefExpr>(E))
7891-
return false;
7892-
auto *Decl = cast<DeclRefExpr>(E)->getDecl();
7893-
if (!isa<VarDecl>(Decl))
7894-
return false;
7895-
auto *VD = cast<VarDecl>(Decl);
7896-
if (RedMap.find(VD) != RedMap.end())
7897-
return true;
7898-
return false;
7899-
};
7900-
79017890
if (isa<BinaryOperator>(S)) {
79027891
const BinaryOperator *BinOpExpr = cast<BinaryOperator>(S);
79037892
// Even though we filtered out everything except the sum reduction
@@ -7908,51 +7897,77 @@ class XteamRedExprChecker final : public ConstStmtVisitor<XteamRedExprChecker> {
79087897
// We punt on anything more complex.
79097898

79107899
const Expr *LHS = BinOpExpr->getLHS()->IgnoreImpCasts();
7911-
if (isExprXteamRedVar(LHS)) {
7912-
auto BinOpExprOp = BinOpExpr->getOpcode();
7913-
if (BinOpExprOp != BO_Assign && BinOpExprOp != BO_AddAssign &&
7914-
BinOpExprOp != BO_Add) {
7915-
IsSupported = false;
7916-
return;
7917-
}
7918-
// We only need to further examine the assignment case.
7919-
// If += or +, Codegen will extract the rhs.
7920-
if (BinOpExpr->getOpcode() == BO_Assign) {
7900+
auto BinOpExprOp = BinOpExpr->getOpcode();
7901+
// Get the reduction variable, if any, from the LHS.
7902+
const VarDecl *RedVarDecl = CGM.getXteamRedVarDecl(LHS, RedMap);
7903+
if (RedVarDecl != nullptr) {
7904+
if (BinOpExprOp == BO_Assign || BinOpExprOp == BO_AddAssign) {
79217905
const Expr *RHS = BinOpExpr->getRHS()->IgnoreImpCasts();
7922-
if (!isa<BinaryOperator>(RHS)) {
7923-
IsSupported = false;
7924-
return;
7925-
}
7926-
const BinaryOperator *BinOpRHS = cast<BinaryOperator>(RHS);
7927-
if (BinOpRHS->getOpcode() != BO_Add) {
7928-
IsSupported = false;
7929-
return;
7930-
}
7931-
const Expr *LHSBinOpRHS = BinOpRHS->getLHS()->IgnoreImpCasts();
7932-
const Expr *RHSBinOpRHS = BinOpRHS->getRHS()->IgnoreImpCasts();
7933-
if (!isExprXteamRedVar(LHSBinOpRHS) &&
7934-
!isExprXteamRedVar(RHSBinOpRHS)) {
7935-
IsSupported = false;
7936-
return;
7906+
// If operator +=, reject if RHS accesses any reduction variable.
7907+
if (BinOpExprOp == BO_AddAssign) {
7908+
ValidateChildren(RHS);
7909+
if (!IsSupported)
7910+
return;
7911+
} else { // BinOpExprOp == BO_Assign
7912+
if (isa<BinaryOperator>(RHS)) {
7913+
const BinaryOperator *BinOpRHS = cast<BinaryOperator>(RHS);
7914+
if (BinOpRHS->getOpcode() == BO_Add) {
7915+
const Expr *LHSBinOpRHS = BinOpRHS->getLHS()->IgnoreImpCasts();
7916+
const Expr *RHSBinOpRHS = BinOpRHS->getRHS()->IgnoreImpCasts();
7917+
// If LHS is the reduction variable, the RHS must not access any
7918+
// reduction variable. Similarly, vice-versa for RHS.
7919+
if (CGM.isXteamRedVarExpr(LHSBinOpRHS, RedVarDecl))
7920+
ValidateChildren(RHSBinOpRHS);
7921+
else if (CGM.isXteamRedVarExpr(RHSBinOpRHS, RedVarDecl))
7922+
ValidateChildren(LHSBinOpRHS);
7923+
else // Neither LHS nor RHS is the reduction variable.
7924+
IsSupported = false;
7925+
if (!IsSupported)
7926+
return;
7927+
} else { // Not an add binary operator.
7928+
IsSupported = false;
7929+
return;
7930+
}
7931+
} else { // RHS is not a binary operator for assignment.
7932+
IsSupported = false;
7933+
return;
7934+
}
79377935
}
7936+
} else { // Binary operator is neither +=, nor =.
7937+
IsSupported = false;
7938+
return;
79387939
}
7940+
} else { // LHS of binary operator does not access any reduction variable.
7941+
// Ensure that RHS does not access any reduction variable either.
7942+
ValidateChildren(S);
7943+
if (!IsSupported)
7944+
return;
79397945
}
7940-
} else if (isa<UnaryOperator>(S)) {
7941-
const Expr *UnaryOpExpr =
7942-
cast<UnaryOperator>(S)->getSubExpr()->IgnoreImpCasts();
7943-
// Xteam reduction does not handle unary operators currently.
7944-
if (isExprXteamRedVar(UnaryOpExpr)) {
7946+
} else if (isa<DeclRefExpr>(S)) {
7947+
// Not a binary operator, so not supported at this point. So ensure no
7948+
// reduction variable is accessed.
7949+
if (CGM.hasXteamRedVar(cast<DeclRefExpr>(S), RedMap)) {
79457950
IsSupported = false;
79467951
return;
79477952
}
7953+
} else {
7954+
// Recursively check the children.
7955+
ValidateChildren(S);
7956+
if (!IsSupported)
7957+
return;
79487958
}
7949-
7950-
for (const Stmt *Child : S->children())
7951-
if (Child)
7959+
}
7960+
void ValidateChildren(const Stmt *S) {
7961+
for (auto Child : S->children())
7962+
if (Child) {
79527963
Visit(Child);
7964+
if (!IsSupported)
7965+
return;
7966+
}
79537967
}
79547968

79557969
private:
7970+
CodeGenModule &CGM;
79567971
/// Map of reduction variables for this directive.
79577972
const CodeGenModule::XteamRedVarMap &RedMap;
79587973
/// Set to false if codegen does not support the reduction expression.
@@ -8438,7 +8453,7 @@ CodeGenModule::getXteamRedForStmtStatus(const OMPExecutableDirective &D,
84388453
// the directive
84398454
const ForStmt *FStmt = getSingleForStmt(OMPStmt);
84408455
assert(FStmt != nullptr && "Unexpected missing For Stmt");
8441-
XteamRedExprChecker Chk(RVM);
8456+
XteamRedExprChecker Chk(*this, RVM);
84428457
Chk.Visit(FStmt);
84438458
if (!Chk.isSupported())
84448459
return std::make_pair(NxUnsupportedRedExpr, HasNestedGenericCall);
@@ -8620,6 +8635,45 @@ CodeGenModule::collectXteamRedVars(const OptKernelNestDirectives &NestDirs) {
86208635
return std::make_pair(NxSuccess, std::make_pair(VarMap, VarVec));
86218636
}
86228637

8638+
bool CodeGenModule::hasXteamRedVar(const Expr *E,
8639+
const XteamRedVarMap &RedMap) const {
8640+
assert(E && "Unexpected null expression");
8641+
if (!isa<DeclRefExpr>(E))
8642+
return false;
8643+
auto *Decl = cast<DeclRefExpr>(E)->getDecl();
8644+
if (!isa<VarDecl>(Decl))
8645+
return false;
8646+
auto *VD = cast<VarDecl>(Decl);
8647+
if (RedMap.find(VD) != RedMap.end())
8648+
return true;
8649+
return false;
8650+
}
8651+
8652+
const VarDecl *
8653+
CodeGenModule::getXteamRedVarDecl(const Expr *E,
8654+
const XteamRedVarMap &RedMap) const {
8655+
if (!isa<DeclRefExpr>(E))
8656+
return nullptr;
8657+
const ValueDecl *ValDecl = cast<DeclRefExpr>(E)->getDecl();
8658+
if (!isa<VarDecl>(ValDecl))
8659+
return nullptr;
8660+
const VarDecl *VD = cast<VarDecl>(ValDecl);
8661+
if (RedMap.find(VD) == RedMap.end())
8662+
return nullptr;
8663+
return VD;
8664+
}
8665+
8666+
bool CodeGenModule::isXteamRedVarExpr(const Expr *E,
8667+
const VarDecl *RedVarDecl) const {
8668+
if (!isa<DeclRefExpr>(E))
8669+
return false;
8670+
const ValueDecl *ValDecl = cast<DeclRefExpr>(E)->getDecl();
8671+
if (!isa<VarDecl>(ValDecl))
8672+
return false;
8673+
const VarDecl *VD = cast<VarDecl>(ValDecl);
8674+
return VD == RedVarDecl;
8675+
}
8676+
86238677
const OMPExecutableDirective *
86248678
getNestedDirective(const OMPExecutableDirective &D) {
86258679
const Stmt *AssocStmt = D.getAssociatedStmt();

clang/lib/CodeGen/CodeGenModule.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,6 +1915,19 @@ class CodeGenModule : public CodeGenTypeCache {
19151915
void setCurrentXteamRedStmt(const Stmt *S) { CurrentXteamRedStmt = S; }
19161916
const Stmt *getCurrentXteamRedStmt() { return CurrentXteamRedStmt; }
19171917

1918+
/// Return true if the provided expression accesses a variable in the provided
1919+
/// map, otherwise return false.
1920+
bool hasXteamRedVar(const Expr *E, const XteamRedVarMap &RedMap) const;
1921+
1922+
/// If present in the provided map, return the reduction variable accessed by
1923+
/// the provided expression, otherwise return nullptr.
1924+
const VarDecl *getXteamRedVarDecl(const Expr *E,
1925+
const XteamRedVarMap &RedMap) const;
1926+
1927+
/// Return true if the provided expression accesses the provided variable,
1928+
/// otherwise return false.
1929+
bool isXteamRedVarExpr(const Expr *E, const VarDecl *VD) const;
1930+
19181931
/// Move some lazily-emitted states to the NewBuilder. This is especially
19191932
/// essential for the incremental parsing environment like Clang Interpreter,
19201933
/// because we'll lose all important information after each repl.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// clang-format off
2+
// This test verifies correctness of Xteam Reduction for sum reduction using increment.
3+
//
4+
// RUN: %libomptarget-compile-generic -fopenmp-target-fast
5+
// RUN: env LIBOMPTARGET_KERNEL_TRACE=1 %libomptarget-run-generic 2>&1 | %fcheck-generic
6+
7+
// UNSUPPORTED: nvptx64-nvidia-cuda
8+
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
9+
// UNSUPPORTED: aarch64-unknown-linux-gnu
10+
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
11+
// UNSUPPORTED: x86_64-pc-linux-gnu
12+
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
13+
14+
// clang-format on
15+
16+
#include <omp.h>
17+
#include <stdio.h>
18+
19+
int main() {
20+
int N = 10;
21+
int sum = 0;
22+
23+
#pragma omp target teams distribute parallel for reduction(+ : sum)
24+
for (int j = 0; j < N; j = j + 1)
25+
sum++;
26+
27+
printf("sum = %d\n", sum);
28+
int rc = sum != 10;
29+
30+
if (!rc)
31+
printf("Success\n");
32+
33+
return rc;
34+
}
35+
36+
/// CHECK: DEVID:[[S:[ ]*]][[DEVID:[0-9]+]] SGN:2

0 commit comments

Comments
 (0)