diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h index c3a02fca5ade8..6ba43f6688c25 100644 --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -4553,8 +4553,8 @@ WRAPPER_CLASS(OmpReductionInitializerClause, Expr); struct OpenMPDeclareReductionConstruct { TUPLE_CLASS_BOILERPLATE(OpenMPDeclareReductionConstruct); CharBlock source; - std::tuple, - OmpReductionCombiner, std::optional> + std::tuple, + std::optional> t; }; diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index 2b6c77c08cc58..b39b8737b70c0 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -170,8 +170,8 @@ TYPE_PARSER(sourced( // TYPE_PARSER(construct(nonemptyList(Parser{}))) TYPE_PARSER( // - construct(Parser{}) || - construct(Parser{})) + construct(Parser{}) || + construct(Parser{})) TYPE_PARSER(construct( // Parser{}, @@ -1148,9 +1148,7 @@ TYPE_PARSER(construct( // 2.16 Declare Reduction Construct TYPE_PARSER(sourced(construct( verbatim("DECLARE REDUCTION"_tok), - "(" >> Parser{} / ":", - nonemptyList(Parser{}) / ":", - Parser{} / ")", + "(" >> indirect(Parser{}) / ")", maybe(Parser{})))) // declare-target with list diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index cd91fbe4ea5eb..3d00979d7b7a6 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2690,11 +2690,10 @@ class UnparseVisitor { BeginOpenMP(); Word("!$OMP DECLARE REDUCTION "); Put("("); - Walk(std::get(x.t)), Put(" : "); - Walk(std::get>(x.t), ","), Put(" : "); - Walk(std::get(x.t)); + Walk(std::get>(x.t)); Put(")"); Walk(std::get>(x.t)); + Put("\n"); EndOpenMP(); } diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index fd2893998205c..9d7b60cdecbd0 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -3180,6 +3180,10 @@ bool OmpStructureChecker::CheckReductionOperator( const SourceName &realName{name->symbol->GetUltimate().name()}; valid = llvm::is_contained({"max", "min", "iand", "ior", "ieor"}, realName); + if (!valid) { + auto *misc{name->symbol->detailsIf()}; + valid = misc && misc->kind() == MiscDetails::Kind::ConstructName; + } } if (!valid) { context_.Say(source, diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 91a1b3061e1f9..94b653c152c5b 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -446,6 +446,9 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor { bool Pre(const parser::OpenMPDeclareMapperConstruct &); void Post(const parser::OpenMPDeclareMapperConstruct &) { PopContext(); } + bool Pre(const parser::OpenMPDeclareReductionConstruct &); + void Post(const parser::OpenMPDeclareReductionConstruct &) { PopContext(); } + bool Pre(const parser::OpenMPThreadprivate &); void Post(const parser::OpenMPThreadprivate &) { PopContext(); } @@ -1976,6 +1979,12 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPDeclareMapperConstruct &x) { return true; } +bool OmpAttributeVisitor::Pre( + const parser::OpenMPDeclareReductionConstruct &x) { + PushContext(x.source, llvm::omp::Directive::OMPD_declare_reduction); + return true; +} + bool OmpAttributeVisitor::Pre(const parser::OpenMPThreadprivate &x) { PushContext(x.source, llvm::omp::Directive::OMPD_threadprivate); const auto &list{std::get(x.t)}; diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index e64abe6b50e78..ff793658f1e06 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1482,6 +1482,15 @@ class OmpVisitor : public virtual DeclarationVisitor { return false; } + bool Pre(const parser::OpenMPDeclareReductionConstruct &x) { + AddOmpSourceRange(x.source); + parser::OmpClauseList emptyList{std::list{}}; + ProcessReductionSpecifier( + std::get>(x.t).value(), + emptyList); + Walk(std::get>(x.t)); + return false; + } bool Pre(const parser::OmpMapClause &); void Post(const parser::OmpBeginLoopDirective &) { @@ -1732,11 +1741,19 @@ void OmpVisitor::ProcessMapperSpecifier(const parser::OmpMapperSpecifier &spec, void OmpVisitor::ProcessReductionSpecifier( const parser::OmpReductionSpecifier &spec, const parser::OmpClauseList &clauses) { + BeginDeclTypeSpec(); + const auto &id{std::get(spec.t)}; + if (auto procDes{std::get_if(&id.u)}) { + if (auto *name{std::get_if(&procDes->u)}) { + name->symbol = + &MakeSymbol(*name, MiscDetails{MiscDetails::Kind::ConstructName}); + } + } + EndDeclTypeSpec(); // Creating a new scope in case the combiner expression (or clauses) use // reerved identifiers, like "omp_in". This is a temporary solution until // we deal with these in a more thorough way. PushScope(Scope::Kind::OtherConstruct, nullptr); - Walk(std::get(spec.t)); Walk(std::get(spec.t)); Walk(std::get>(spec.t)); Walk(clauses); diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 index 7a7d28db8d6f5..db50c9ac8ee9d 100644 --- a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 +++ b/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 @@ -1,10 +1,10 @@ ! This test checks lowering of OpenMP declare reduction Directive. -// RUN: not flang -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s +! RUN: not flang -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s subroutine declare_red() integer :: my_var - // CHECK: not yet implemented: OpenMPDeclareReductionConstruct + !CHECK: not yet implemented: OpenMPDeclareReductionConstruct !$omp declare reduction (my_red : integer : omp_out = omp_in) initializer (omp_priv = 0) my_var = 0 end subroutine declare_red diff --git a/flang/test/Parser/OpenMP/declare-reduction-unparse.f90 b/flang/test/Parser/OpenMP/declare-reduction-unparse.f90 new file mode 100644 index 0000000000000..a2a3ef9f630ab --- /dev/null +++ b/flang/test/Parser/OpenMP/declare-reduction-unparse.f90 @@ -0,0 +1,21 @@ +! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck --ignore-case %s +! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck --check-prefix="PARSE-TREE" %s +!CHECK-LABEL: program main +program main + integer :: my_var + !CHECK: !$OMP DECLARE REDUCTION (my_add_red:INTEGER: omp_out=omp_out+omp_in + !CHECK-NEXT: ) INITIALIZER(OMP_PRIV = 0_4) + + !$omp declare reduction (my_add_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv=0) + my_var = 0 + !$omp parallel reduction (my_add_red : my_var) num_threads(4) + my_var = omp_get_thread_num() + 1 + !$omp end parallel + print *, "sum of thread numbers is ", my_var +end program main + +!PARSE-TREE: OpenMPDeclareReductionConstruct +!PARSE-TREE: OmpReductionIdentifier -> ProcedureDesignator -> Name = 'my_add_red' +!PARSE-TREE: DeclarationTypeSpec -> IntrinsicTypeSpec -> IntegerTypeSpec +!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out=omp_out+omp_in' +!PARSE-TREE: OmpReductionInitializerClause -> Expr = '0_4' diff --git a/flang/test/Semantics/OpenMP/declarative-directive01.f90 b/flang/test/Semantics/OpenMP/declarative-directive01.f90 index 17dc50b70e542..e8bf605565fad 100644 --- a/flang/test/Semantics/OpenMP/declarative-directive01.f90 +++ b/flang/test/Semantics/OpenMP/declarative-directive01.f90 @@ -2,9 +2,6 @@ ! Check OpenMP declarative directives -!TODO: all internal errors -! enable declare-reduction example after name resolution - ! 2.4 requires subroutine requires_1(a) @@ -88,15 +85,14 @@ end module m2 ! 2.16 declare-reduction -! subroutine declare_red_1() -! use omp_lib -! integer :: my_var -! !$omp declare reduction (my_add_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv=0) -! my_var = 0 -! !$omp parallel reduction (my_add_red : my_var) num_threads(4) -! my_var = omp_get_thread_num() + 1 -! !$omp end parallel -! print *, "sum of thread numbers is ", my_var -! end subroutine declare_red_1 +subroutine declare_red_1() + integer :: my_var + !$omp declare reduction (my_add_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv=0) + my_var = 0 + !$omp parallel reduction (my_add_red : my_var) num_threads(4) + my_var = 1 + !$omp end parallel + print *, "sum of thread numbers is ", my_var +end subroutine declare_red_1 end diff --git a/flang/test/Semantics/OpenMP/declare-reduction.f90 b/flang/test/Semantics/OpenMP/declare-reduction.f90 new file mode 100644 index 0000000000000..8fee79dfc0b7b --- /dev/null +++ b/flang/test/Semantics/OpenMP/declare-reduction.f90 @@ -0,0 +1,11 @@ +! RUN: %flang_fc1 -fdebug-dump-symbols -fopenmp -fopenmp-version=50 %s | FileCheck %s + +program main +!CHECK-LABEL: MainProgram scope: main + + !$omp declare reduction (my_add_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv=0) + +!CHECK: my_add_red: Misc ConstructName + +end program main +