Skip to content

[flang][cuda] Add option to disable warp function in semantic #143640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Support/Fortran-features.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
InaccessibleDeferredOverride)
InaccessibleDeferredOverride, CudaWarpMatchFunction)

// Portability and suspicious usage warnings
ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
Expand Down
125 changes: 82 additions & 43 deletions flang/lib/Semantics/check-cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "flang/Semantics/expression.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "llvm/ADT/StringSet.h"

// Once labeled DO constructs have been canonicalized and their parse subtrees
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
Expand Down Expand Up @@ -61,14 +62,19 @@ bool CanonicalizeCUDA(parser::Program &program) {

using MaybeMsg = std::optional<parser::MessageFormattedText>;

static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
"match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
"match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
"match_any_syncjd"};

// Traverses an evaluate::Expr<> in search of unsupported operations
// on the device.

struct DeviceExprChecker
: public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
using Result = MaybeMsg;
using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
DeviceExprChecker() : Base(*this) {}
explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
using Base::operator();
Result operator()(const evaluate::ProcedureDesignator &x) const {
if (const Symbol * sym{x.GetInterfaceSymbol()}) {
Expand All @@ -78,10 +84,17 @@ struct DeviceExprChecker
if (auto attrs{subp->cudaSubprogramAttrs()}) {
if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
*attrs == common::CUDASubprogramAttrs::Device) {
if (warpFunctions_.contains(sym->name().ToString()) &&
!context_.languageFeatures().IsEnabled(
Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
return parser::MessageFormattedText(
"warp match function disabled"_err_en_US);
}
return {};
}
}
}

const Symbol &ultimate{sym->GetUltimate()};
const Scope &scope{ultimate.owner()};
const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
Expand All @@ -94,9 +107,12 @@ struct DeviceExprChecker
// TODO(CUDA): Check for unsupported intrinsics here
return {};
}

return parser::MessageFormattedText(
"'%s' may not be called in device code"_err_en_US, x.GetName());
}

SemanticsContext &context_;
};

struct FindHostArray
Expand Down Expand Up @@ -133,9 +149,10 @@ struct FindHostArray
}
};

template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
template <typename A>
static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
return DeviceExprChecker{}(expr->typedExpr);
return DeviceExprChecker{context}(expr->typedExpr);
}
return {};
}
Expand All @@ -144,104 +161,124 @@ template <typename A>
static void CheckUnwrappedExpr(
SemanticsContext &context, SourceName at, const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
context.Say(at, std::move(*msg));
}
}
}

template <bool CUF_KERNEL> struct ActionStmtChecker {
template <typename A> static MaybeMsg WhyNotOk(const A &x) {
template <typename A>
static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
if constexpr (ConstraintTrait<A>) {
return WhyNotOk(x.thing);
return WhyNotOk(context, x.thing);
} else if constexpr (WrapperTrait<A>) {
return WhyNotOk(x.v);
return WhyNotOk(context, x.v);
} else if constexpr (UnionTrait<A>) {
return WhyNotOk(x.u);
return WhyNotOk(context, x.u);
} else if constexpr (TupleTrait<A>) {
return WhyNotOk(x.t);
return WhyNotOk(context, x.t);
} else {
return parser::MessageFormattedText{
"Statement may not appear in device code"_err_en_US};
}
}
template <typename A>
static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
return WhyNotOk(x.value());
static MaybeMsg WhyNotOk(
SemanticsContext &context, const common::Indirection<A> &x) {
return WhyNotOk(context, x.value());
}
template <typename... As>
static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const std::variant<As...> &x) {
return common::visit(
[&context](const auto &x) { return WhyNotOk(context, x); }, x);
}
template <std::size_t J = 0, typename... As>
static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const std::tuple<As...> &x) {
if constexpr (J == sizeof...(As)) {
return {};
} else if (auto msg{WhyNotOk(std::get<J>(x))}) {
} else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
return msg;
} else {
return WhyNotOk<(J + 1)>(x);
return WhyNotOk<(J + 1)>(context, x);
}
}
template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
template <typename A>
static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
for (const auto &y : x) {
if (MaybeMsg result{WhyNotOk(y)}) {
if (MaybeMsg result{WhyNotOk(context, y)}) {
return result;
}
}
return {};
}
template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
template <typename A>
static MaybeMsg WhyNotOk(
SemanticsContext &context, const std::optional<A> &x) {
if (x) {
return WhyNotOk(*x);
return WhyNotOk(context, *x);
} else {
return {};
}
}
template <typename A>
static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
return WhyNotOk(x.statement);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
return WhyNotOk(context, x.statement);
}
template <typename A>
static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
return WhyNotOk(x.statement);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::Statement<A> &x) {
return WhyNotOk(context, x.statement);
}
static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::AllocateStmt &) {
return {}; // AllocateObjects are checked elsewhere
}
static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::AllocateCoarraySpec &) {
return parser::MessageFormattedText(
"A coarray may not be allocated on the device"_err_en_US);
}
static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::DeallocateStmt &) {
return {}; // AllocateObjects are checked elsewhere
}
static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
return DeviceExprChecker{}(x.typedAssignment);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::AssignmentStmt &x) {
return DeviceExprChecker{context}(x.typedAssignment);
}
static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
return DeviceExprChecker{}(x.typedCall);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::CallStmt &x) {
return DeviceExprChecker{context}(x.typedCall);
}
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::ContinueStmt &) {
return {};
}
static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
if (auto result{
CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
if (auto result{CheckUnwrappedExpr(
context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
return result;
}
return WhyNotOk(
return WhyNotOk(context,
std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
.statement);
}
static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::NullifyStmt &x) {
for (const auto &y : x.v) {
if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
return result;
}
}
return {};
}
static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
return DeviceExprChecker{}(x.typedAssignment);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
return DeviceExprChecker{context}(x.typedAssignment);
}
};

Expand Down Expand Up @@ -435,12 +472,14 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
ErrorIfHostSymbol(assign->lhs, source);
ErrorIfHostSymbol(assign->rhs, source);
}
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
context_, x)}) {
context_.Say(source, std::move(*msg));
}
},
[&](const auto &x) {
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
context_, x)}) {
context_.Say(source, std::move(*msg));
}
},
Expand Down Expand Up @@ -504,7 +543,7 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
Check(DEREF(parser::Unwrap<parser::Expr>(x)));
}
void Check(const parser::Expr &expr) {
if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
context_.Say(expr.source, std::move(*msg));
}
}
Expand Down
8 changes: 8 additions & 0 deletions flang/test/Semantics/cuf22.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does not bbc do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bbc command will fail since I want to check the error. Adding the not makes the test pass because the return value of the command is checked.


attributes(device) subroutine testMatch()
integer :: a, ipred, mask, v32
a = match_all_sync(mask, v32, ipred)
end subroutine

! CHECK: warp match function disabled
10 changes: 10 additions & 0 deletions flang/tools/bbc/bbc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
llvm::cl::desc("enable CUDA Fortran"),
llvm::cl::init(false));

static llvm::cl::opt<bool>
disableCUDAWarpFunction("fcuda-disable-warp-function",
llvm::cl::desc("Disable CUDA Warp Function"),
llvm::cl::init(false));

static llvm::cl::opt<std::string>
enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
llvm::cl::init(""));
Expand Down Expand Up @@ -600,6 +605,11 @@ int main(int argc, char **argv) {
options.features.Enable(Fortran::common::LanguageFeature::CUDA);
}

if (disableCUDAWarpFunction) {
options.features.Enable(
Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);
}

if (enableGPUMode == "managed") {
options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
} else if (enableGPUMode == "unified") {
Expand Down