diff --git a/flang/include/flang/Support/Fortran-features.h b/flang/include/flang/Support/Fortran-features.h index 3f6d825e2b66c..ea0845b7d605f 100644 --- a/flang/include/flang/Support/Fortran-features.h +++ b/flang/include/flang/Support/Fortran-features.h @@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines, SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank, IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor, ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy, - InaccessibleDeferredOverride) + InaccessibleDeferredOverride, CudaWarpMatchFunction) // Portability and suspicious usage warnings ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable, diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp index c024640af1220..8decfb0149829 100644 --- a/flang/lib/Semantics/check-cuda.cpp +++ b/flang/lib/Semantics/check-cuda.cpp @@ -17,6 +17,7 @@ #include "flang/Semantics/expression.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" +#include "llvm/ADT/StringSet.h" // Once labeled DO constructs have been canonicalized and their parse subtrees // transformed into parser::DoConstructs, scan the parser::Blocks of the program @@ -61,6 +62,11 @@ bool CanonicalizeCUDA(parser::Program &program) { using MaybeMsg = std::optional; +static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj", + "match_all_syncjx", "match_all_syncjf", "match_all_syncjd", + "match_any_syncjj", "match_any_syncjx", "match_any_syncjf", + "match_any_syncjd"}; + // Traverses an evaluate::Expr<> in search of unsupported operations // on the device. @@ -68,7 +74,7 @@ struct DeviceExprChecker : public evaluate::AnyTraverse { using Result = MaybeMsg; using Base = evaluate::AnyTraverse; - DeviceExprChecker() : Base(*this) {} + explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {} using Base::operator(); Result operator()(const evaluate::ProcedureDesignator &x) const { if (const Symbol * sym{x.GetInterfaceSymbol()}) { @@ -78,10 +84,17 @@ struct DeviceExprChecker if (auto attrs{subp->cudaSubprogramAttrs()}) { if (*attrs == common::CUDASubprogramAttrs::HostDevice || *attrs == common::CUDASubprogramAttrs::Device) { + if (warpFunctions_.contains(sym->name().ToString()) && + !context_.languageFeatures().IsEnabled( + Fortran::common::LanguageFeature::CudaWarpMatchFunction)) { + return parser::MessageFormattedText( + "warp match function disabled"_err_en_US); + } return {}; } } } + const Symbol &ultimate{sym->GetUltimate()}; const Scope &scope{ultimate.owner()}; const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr}; @@ -94,9 +107,12 @@ struct DeviceExprChecker // TODO(CUDA): Check for unsupported intrinsics here return {}; } + return parser::MessageFormattedText( "'%s' may not be called in device code"_err_en_US, x.GetName()); } + + SemanticsContext &context_; }; struct FindHostArray @@ -133,9 +149,10 @@ struct FindHostArray } }; -template static MaybeMsg CheckUnwrappedExpr(const A &x) { +template +static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) { if (const auto *expr{parser::Unwrap(x)}) { - return DeviceExprChecker{}(expr->typedExpr); + return DeviceExprChecker{context}(expr->typedExpr); } return {}; } @@ -144,104 +161,124 @@ template static void CheckUnwrappedExpr( SemanticsContext &context, SourceName at, const A &x) { if (const auto *expr{parser::Unwrap(x)}) { - if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) { + if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) { context.Say(at, std::move(*msg)); } } } template struct ActionStmtChecker { - template static MaybeMsg WhyNotOk(const A &x) { + template + static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) { if constexpr (ConstraintTrait) { - return WhyNotOk(x.thing); + return WhyNotOk(context, x.thing); } else if constexpr (WrapperTrait) { - return WhyNotOk(x.v); + return WhyNotOk(context, x.v); } else if constexpr (UnionTrait) { - return WhyNotOk(x.u); + return WhyNotOk(context, x.u); } else if constexpr (TupleTrait) { - return WhyNotOk(x.t); + return WhyNotOk(context, x.t); } else { return parser::MessageFormattedText{ "Statement may not appear in device code"_err_en_US}; } } template - static MaybeMsg WhyNotOk(const common::Indirection &x) { - return WhyNotOk(x.value()); + static MaybeMsg WhyNotOk( + SemanticsContext &context, const common::Indirection &x) { + return WhyNotOk(context, x.value()); } template - static MaybeMsg WhyNotOk(const std::variant &x) { - return common::visit([](const auto &x) { return WhyNotOk(x); }, x); + static MaybeMsg WhyNotOk( + SemanticsContext &context, const std::variant &x) { + return common::visit( + [&context](const auto &x) { return WhyNotOk(context, x); }, x); } template - static MaybeMsg WhyNotOk(const std::tuple &x) { + static MaybeMsg WhyNotOk( + SemanticsContext &context, const std::tuple &x) { if constexpr (J == sizeof...(As)) { return {}; - } else if (auto msg{WhyNotOk(std::get(x))}) { + } else if (auto msg{WhyNotOk(context, std::get(x))}) { return msg; } else { - return WhyNotOk<(J + 1)>(x); + return WhyNotOk<(J + 1)>(context, x); } } - template static MaybeMsg WhyNotOk(const std::list &x) { + template + static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list &x) { for (const auto &y : x) { - if (MaybeMsg result{WhyNotOk(y)}) { + if (MaybeMsg result{WhyNotOk(context, y)}) { return result; } } return {}; } - template static MaybeMsg WhyNotOk(const std::optional &x) { + template + static MaybeMsg WhyNotOk( + SemanticsContext &context, const std::optional &x) { if (x) { - return WhyNotOk(*x); + return WhyNotOk(context, *x); } else { return {}; } } template - static MaybeMsg WhyNotOk(const parser::UnlabeledStatement &x) { - return WhyNotOk(x.statement); + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::UnlabeledStatement &x) { + return WhyNotOk(context, x.statement); } template - static MaybeMsg WhyNotOk(const parser::Statement &x) { - return WhyNotOk(x.statement); + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::Statement &x) { + return WhyNotOk(context, x.statement); } - static MaybeMsg WhyNotOk(const parser::AllocateStmt &) { + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::AllocateStmt &) { return {}; // AllocateObjects are checked elsewhere } - static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) { + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::AllocateCoarraySpec &) { return parser::MessageFormattedText( "A coarray may not be allocated on the device"_err_en_US); } - static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) { + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::DeallocateStmt &) { return {}; // AllocateObjects are checked elsewhere } - static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) { - return DeviceExprChecker{}(x.typedAssignment); + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::AssignmentStmt &x) { + return DeviceExprChecker{context}(x.typedAssignment); } - static MaybeMsg WhyNotOk(const parser::CallStmt &x) { - return DeviceExprChecker{}(x.typedCall); + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::CallStmt &x) { + return DeviceExprChecker{context}(x.typedCall); + } + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::ContinueStmt &) { + return {}; } - static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; } - static MaybeMsg WhyNotOk(const parser::IfStmt &x) { - if (auto result{ - CheckUnwrappedExpr(std::get(x.t))}) { + static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) { + if (auto result{CheckUnwrappedExpr( + context, std::get(x.t))}) { return result; } - return WhyNotOk( + return WhyNotOk(context, std::get>(x.t) .statement); } - static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) { + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::NullifyStmt &x) { for (const auto &y : x.v) { - if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) { + if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) { return result; } } return {}; } - static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) { - return DeviceExprChecker{}(x.typedAssignment); + static MaybeMsg WhyNotOk( + SemanticsContext &context, const parser::PointerAssignmentStmt &x) { + return DeviceExprChecker{context}(x.typedAssignment); } }; @@ -435,12 +472,14 @@ template class DeviceContextChecker { ErrorIfHostSymbol(assign->lhs, source); ErrorIfHostSymbol(assign->rhs, source); } - if (auto msg{ActionStmtChecker::WhyNotOk(x)}) { + if (auto msg{ActionStmtChecker::WhyNotOk( + context_, x)}) { context_.Say(source, std::move(*msg)); } }, [&](const auto &x) { - if (auto msg{ActionStmtChecker::WhyNotOk(x)}) { + if (auto msg{ActionStmtChecker::WhyNotOk( + context_, x)}) { context_.Say(source, std::move(*msg)); } }, @@ -504,7 +543,7 @@ template class DeviceContextChecker { Check(DEREF(parser::Unwrap(x))); } void Check(const parser::Expr &expr) { - if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) { + if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) { context_.Say(expr.source, std::move(*msg)); } } diff --git a/flang/test/Semantics/cuf22.cuf b/flang/test/Semantics/cuf22.cuf new file mode 100644 index 0000000000000..36e0f0b2502df --- /dev/null +++ b/flang/test/Semantics/cuf22.cuf @@ -0,0 +1,8 @@ +! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s + +attributes(device) subroutine testMatch() + integer :: a, ipred, mask, v32 + a = match_all_sync(mask, v32, ipred) +end subroutine + +! CHECK: warp match function disabled diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp index c544008a24d56..c80872108ac8f 100644 --- a/flang/tools/bbc/bbc.cpp +++ b/flang/tools/bbc/bbc.cpp @@ -223,6 +223,11 @@ static llvm::cl::opt enableCUDA("fcuda", llvm::cl::desc("enable CUDA Fortran"), llvm::cl::init(false)); +static llvm::cl::opt + disableCUDAWarpFunction("fcuda-disable-warp-function", + llvm::cl::desc("Disable CUDA Warp Function"), + llvm::cl::init(false)); + static llvm::cl::opt enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"), llvm::cl::init("")); @@ -600,6 +605,11 @@ int main(int argc, char **argv) { options.features.Enable(Fortran::common::LanguageFeature::CUDA); } + if (disableCUDAWarpFunction) { + options.features.Enable( + Fortran::common::LanguageFeature::CudaWarpMatchFunction, false); + } + if (enableGPUMode == "managed") { options.features.Enable(Fortran::common::LanguageFeature::CudaManaged); } else if (enableGPUMode == "unified") {