From 5020e498440b0016adef7e99806aa55c4837b441 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 5 Jun 2024 13:08:22 -0500 Subject: [PATCH 01/33] Add getters for multi dim loop variables in LoopLikeOpInterface --- .../mlir/Dialect/Affine/IR/AffineOps.td | 4 +- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 37 ++-------- .../mlir/Interfaces/LoopLikeInterface.td | 65 +++++++++++------ mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 20 +++--- mlir/lib/Dialect/SCF/IR/SCF.cpp | 70 +++++++------------ .../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 8 +++ 6 files changed, 97 insertions(+), 107 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 3640055ea8da8..bb2c29b5733b8 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for", [AttrSizedOperandSegments, AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable, RecursiveMemoryEffects, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 0b063aa772bab..3b28ca8b21d0f 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ def ForOp : SCF_Op<"for", [AutomaticAllocationScope, DeclareOpInterfaceMethods, AllTypesMatch<["lowerBound", "upperBound", "step"]>, @@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [ AttrSizedOperandSegments, AutomaticAllocationScope, DeclareOpInterfaceMethods, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, @@ -510,24 +510,6 @@ def ForallOp : SCF_Op<"forall", [ ]; let extraClassDeclaration = [{ - // Get lower bounds as OpFoldResult. 
- SmallVector getMixedLowerBound() { - Builder b(getOperation()->getContext()); - return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b); - } - - // Get upper bounds as OpFoldResult. - SmallVector getMixedUpperBound() { - Builder b(getOperation()->getContext()); - return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b); - } - - // Get steps as OpFoldResult. - SmallVector getMixedStep() { - Builder b(getOperation()->getContext()); - return getMixedValues(getStaticStep(), getDynamicStep(), b); - } - /// Get lower bounds as values. SmallVector getLowerBound(OpBuilder &b) { return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound()); @@ -584,10 +566,6 @@ def ForallOp : SCF_Op<"forall", [ getNumDynamicControlOperands() + getRank()); } - ::mlir::ValueRange getInductionVars() { - return getBody()->getArguments().take_front(getRank()); - } - ::mlir::Value getInductionVar(int64_t idx) { return getInductionVars()[idx]; } @@ -765,8 +743,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, RecursiveMemoryEffects, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::ReduceOp">, @@ -846,9 +824,6 @@ def ParallelOp : SCF_Op<"parallel", ]; let extraClassDeclaration = [{ - ValueRange getInductionVars() { - return getBody()->getArguments(); - } unsigned getNumLoops() { return getStep().size(); } unsigned getNumReductions() { return getInitVals().size(); } }]; diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index f0dc6e60eba58..813779c852027 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -93,51 +93,47 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { }] >, InterfaceMethod<[{ - If there is a single induction variable return it, otherwise return - std::nullopt. + Return all induction variables. 
}], - /*retTy=*/"::std::optional<::mlir::Value>", - /*methodName=*/"getSingleInductionVar", + /*retTy=*/"::mlir::ValueRange", + /*methodName=*/"getInductionVars", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return std::nullopt; + return {}; }] >, InterfaceMethod<[{ - Return the single lower bound value or attribute if it exists, otherwise - return std::nullopt. + Return all lower bounds. }], - /*retTy=*/"::std::optional<::mlir::OpFoldResult>", - /*methodName=*/"getSingleLowerBound", + /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>", + /*methodName=*/"getMixedLowerBound", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return std::nullopt; + return {}; }] >, InterfaceMethod<[{ - Return the single step value or attribute if it exists, otherwise - return std::nullopt. + Return all steps. }], - /*retTy=*/"::std::optional<::mlir::OpFoldResult>", - /*methodName=*/"getSingleStep", + /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>", + /*methodName=*/"getMixedStep", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return std::nullopt; + return {}; }] >, InterfaceMethod<[{ - Return the single upper bound value or attribute if it exists, otherwise - return std::nullopt. + Return all upper bounds. }], - /*retTy=*/"::std::optional<::mlir::OpFoldResult>", - /*methodName=*/"getSingleUpperBound", + /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>", + /*methodName=*/"getMixedUpperBound", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return std::nullopt; + return {}; }] >, InterfaceMethod<[{ @@ -235,6 +231,35 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { }]; let extraSharedClassDeclaration = [{ + /// If there is a single induction variable return it, otherwise return + /// std::nullopt. 
+ ::std::optional<::mlir::Value> getSingleInductionVar() { + if (this->getInductionVars().size() == 1) + return this->getInductionVars()[0]; + return std::nullopt; + } + /// Return the single lower bound value or attribute if it exists, otherwise + /// return std::nullopt. + ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() { + if (this->getMixedLowerBound().size() == 1) + return this->getMixedLowerBound()[0]; + return std::nullopt; + } + /// Return the single step value or attribute if it exists, otherwise + /// return std::nullopt. + ::std::optional<::mlir::OpFoldResult> getSingleStep() { + if (this->getMixedStep().size() == 1) + return this->getMixedStep()[0]; + return std::nullopt; + } + /// Return the single upper bound value or attribute if it exists, otherwise + /// return std::nullopt. + ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() { + if (this->getMixedUpperBound().size() == 1) + return this->getMixedUpperBound()[0]; + return std::nullopt; + } + /// Append the specified additional "init" operands: replace this loop with /// a new loop that has the additional init operands. The loop body of this /// loop is moved over to the new loop. 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 2e31487bd55a0..746a9c919560c 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2454,27 +2454,25 @@ bool AffineForOp::matchingBoundOperandList() { SmallVector AffineForOp::getLoopRegions() { return {&getRegion()}; } -std::optional AffineForOp::getSingleInductionVar() { - return getInductionVar(); -} +ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; } -std::optional AffineForOp::getSingleLowerBound() { +SmallVector AffineForOp::getMixedLowerBound() { if (!hasConstantLowerBound()) - return std::nullopt; + return {}; OpBuilder b(getContext()); - return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound())); + return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))}; } -std::optional AffineForOp::getSingleStep() { +SmallVector AffineForOp::getMixedStep() { OpBuilder b(getContext()); - return OpFoldResult(b.getI64IntegerAttr(getStepAsInt())); + return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))}; } -std::optional AffineForOp::getSingleUpperBound() { +SmallVector AffineForOp::getMixedUpperBound() { if (!hasConstantUpperBound()) - return std::nullopt; + return {}; OpBuilder b(getContext()); - return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound())); + return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))}; } FailureOr AffineForOp::replaceWithAdditionalYields( diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 107fd0690f193..e275ff1849c10 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -378,20 +378,18 @@ LogicalResult ForOp::verifyRegions() { return success(); } -std::optional ForOp::getSingleInductionVar() { - return getInductionVar(); -} +ValueRange ForOp::getInductionVars() { return {getInductionVar()}; } -std::optional ForOp::getSingleLowerBound() { - return 
OpFoldResult(getLowerBound()); +SmallVector ForOp::getMixedLowerBound() { + return {OpFoldResult(getLowerBound())}; } -std::optional ForOp::getSingleStep() { - return OpFoldResult(getStep()); +SmallVector ForOp::getMixedStep() { + return {OpFoldResult(getStep())}; } -std::optional ForOp::getSingleUpperBound() { - return OpFoldResult(getUpperBound()); +SmallVector ForOp::getMixedUpperBound() { + return {OpFoldResult(getUpperBound())}; } std::optional ForOp::getLoopResults() { return getResults(); } @@ -1428,28 +1426,26 @@ SmallVector ForallOp::getCombiningOps(BlockArgument bbArg) { return storeOps; } -std::optional ForallOp::getSingleInductionVar() { - if (getRank() != 1) - return std::nullopt; - return getInductionVar(0); +ValueRange ForallOp::getInductionVars() { + return getBody()->getArguments().take_front(getRank()); } -std::optional ForallOp::getSingleLowerBound() { - if (getRank() != 1) - return std::nullopt; - return getMixedLowerBound()[0]; +// Get lower bounds as OpFoldResult. +SmallVector ForallOp::getMixedLowerBound() { + Builder b(getOperation()->getContext()); + return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b); } -std::optional ForallOp::getSingleUpperBound() { - if (getRank() != 1) - return std::nullopt; - return getMixedUpperBound()[0]; +// Get upper bounds as OpFoldResult. +SmallVector ForallOp::getMixedUpperBound() { + Builder b(getOperation()->getContext()); + return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b); } -std::optional ForallOp::getSingleStep() { - if (getRank() != 1) - return std::nullopt; - return getMixedStep()[0]; +// Get steps as OpFoldResult. 
+SmallVector ForallOp::getMixedStep() { + Builder b(getOperation()->getContext()); + return getMixedValues(getStaticStep(), getDynamicStep(), b); } ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) { @@ -3008,29 +3004,17 @@ void ParallelOp::print(OpAsmPrinter &p) { SmallVector ParallelOp::getLoopRegions() { return {&getRegion()}; } -std::optional ParallelOp::getSingleInductionVar() { - if (getNumLoops() != 1) - return std::nullopt; - return getBody()->getArgument(0); -} +ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); } -std::optional ParallelOp::getSingleLowerBound() { - if (getNumLoops() != 1) - return std::nullopt; - return getLowerBound()[0]; +SmallVector ParallelOp::getMixedLowerBound() { + return getLowerBound(); } -std::optional ParallelOp::getSingleUpperBound() { - if (getNumLoops() != 1) - return std::nullopt; - return getUpperBound()[0]; +SmallVector ParallelOp::getMixedUpperBound() { + return getUpperBound(); } -std::optional ParallelOp::getSingleStep() { - if (getNumLoops() != 1) - return std::nullopt; - return getStep()[0]; -} +SmallVector ParallelOp::getMixedStep() { return getStep(); } ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { auto ivArg = llvm::dyn_cast(val); diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp index 6bc0fd6113b9b..d8cdb213070da 100644 --- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -36,6 +36,10 @@ class SCFLoopLikeTest : public ::testing::Test { std::optional maybeIndVar = loopLikeOp.getSingleInductionVar(); EXPECT_TRUE(maybeIndVar.has_value()); + EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u); + EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u); + EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u); + EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u); } void checkMultidimensional(LoopLikeOpInterface loopLikeOp) { @@ -48,6 +52,10 @@ 
class SCFLoopLikeTest : public ::testing::Test { std::optional maybeIndVar = loopLikeOp.getSingleInductionVar(); EXPECT_FALSE(maybeIndVar.has_value()); + EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u); + EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u); + EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u); + EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u); } MLIRContext context; From 50852d570440e0041c8b2b38925c4af05fac0636 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 4 Jun 2024 14:44:53 -0500 Subject: [PATCH 02/33] Refactor LoopFuseSiblingOp and support parallel fusion --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 16 ++ .../SCF/TransformOps/SCFTransformOps.cpp | 53 +++-- .../SCF/Transforms/ParallelLoopFusion.cpp | 204 +---------------- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 208 ++++++++++++++++++ .../SCF/transform-loop-fuse-sibling.mlir | 53 +++++ 5 files changed, 304 insertions(+), 230 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index bc09cc7f7fa5e..2944d8ffac022 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -156,6 +156,12 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef sizes); void getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, scf::ForOp root); +/// Prepends operations of firstPloop's body into secondPloop's body. +/// Updates secondPloop with new loop. +void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop, + OpBuilder builder, + llvm::function_ref mayAlias); + /// Given two scf.forall loops, `target` and `source`, fuses `target` into /// `source`. Assumes that the given loops are siblings and are independent of /// each other. 
@@ -177,6 +183,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter); +/// Given two scf.parallel loops, `target` and `source`, fuses `target` into +/// `source`. Assumes that the given loops are siblings and are independent of +/// each other. +/// +/// This function does not perform any legality checks and simply fuses the +/// loops. The caller is responsible for ensuring that the loops are legal to +/// fuse. +scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target, + scf::ParallelOp source, + RewriterBase &rewriter); } // namespace mlir #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_ diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 69f83d8bd70da..1c53e89d69040 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -442,39 +442,32 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target, return DiagnosedSilenceableFailure::success(); } -/// Check if `target` scf.forall can be fused into `source` scf.forall. +/// Check if `target` scf loop can be fused into `source` scf loop. +/// Applies for scf.for, scf.forall, and scf.parallel. /// /// This simply checks if both loops have the same bounds, steps and mapping. /// No attempt is made at checking that the side effects of `target` and /// `source` are independent of each other. 
-static bool isForallWithIdenticalConfiguration(Operation *target, - Operation *source) { - auto targetOp = dyn_cast(target); - auto sourceOp = dyn_cast(source); - if (!targetOp || !sourceOp) - return false; - - return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && - targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && - targetOp.getMixedStep() == sourceOp.getMixedStep() && - targetOp.getMapping() == sourceOp.getMapping(); -} - -/// Check if `target` scf.for can be fused into `source` scf.for. -/// -/// This simply checks if both loops have the same bounds and steps. No attempt -/// is made at checking that the side effects of `target` and `source` are -/// independent of each other. -static bool isForWithIdenticalConfiguration(Operation *target, - Operation *source) { - auto targetOp = dyn_cast(target); - auto sourceOp = dyn_cast(source); +template +static bool isLoopWithIdenticalConfiguration(Operation *target, + Operation *source) { + static_assert(llvm::is_one_of::value, + "applies to only `forall`, `for` and `parallel`"); + auto targetOp = dyn_cast(target); + auto sourceOp = dyn_cast(source); if (!targetOp || !sourceOp) return false; - return targetOp.getLowerBound() == sourceOp.getLowerBound() && - targetOp.getUpperBound() == sourceOp.getUpperBound() && - targetOp.getStep() == sourceOp.getStep(); + if constexpr (std::is_same_v) + return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && + targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && + targetOp.getMixedStep() == sourceOp.getMixedStep() && + targetOp.getMapping() == sourceOp.getMapping(); + else + return targetOp.getLowerBound() == sourceOp.getLowerBound() && + targetOp.getUpperBound() == sourceOp.getUpperBound() && + targetOp.getStep() == sourceOp.getStep(); } DiagnosedSilenceableFailure @@ -502,12 +495,16 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, Operation *fusedLoop; /// TODO: Support fusion for loop-like 
ops besides scf.for and scf.forall. - if (isForWithIdenticalConfiguration(target, source)) { + if (isLoopWithIdenticalConfiguration(target, source)) { fusedLoop = fuseIndependentSiblingForLoops( cast(target), cast(source), rewriter); - } else if (isForallWithIdenticalConfiguration(target, source)) { + } else if (isLoopWithIdenticalConfiguration(target, source)) { fusedLoop = fuseIndependentSiblingForallLoops( cast(target), cast(source), rewriter); + } else if (isLoopWithIdenticalConfiguration(target, + source)) { + fusedLoop = fuseIndependentSiblingParallelLoops( + cast(target), cast(source), rewriter); } else return emitSilenceableFailure(target->getLoc()) << "operations cannot be fused"; diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 5934d85373b03..abac91cfaf7d9 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" @@ -30,207 +31,6 @@ namespace mlir { using namespace mlir; using namespace mlir::scf; -/// Verify there are no nested ParallelOps. -static bool hasNestedParallelOp(ParallelOp ploop) { - auto walkResult = - ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); - return walkResult.wasInterrupted(); -} - -/// Verify equal iteration spaces. -static bool equalIterationSpaces(ParallelOp firstPloop, - ParallelOp secondPloop) { - if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) - return false; - - auto matchOperands = [&](const OperandRange &lhs, - const OperandRange &rhs) -> bool { - // TODO: Extend this to support aliases and equal constants. 
- return std::equal(lhs.begin(), lhs.end(), rhs.begin()); - }; - return matchOperands(firstPloop.getLowerBound(), - secondPloop.getLowerBound()) && - matchOperands(firstPloop.getUpperBound(), - secondPloop.getUpperBound()) && - matchOperands(firstPloop.getStep(), secondPloop.getStep()); -} - -/// Checks if the parallel loops have mixed access to the same buffers. Returns -/// `true` if the first parallel loop writes to the same indices that the second -/// loop reads. -static bool haveNoReadsAfterWriteExceptSameIndex( - ParallelOp firstPloop, ParallelOp secondPloop, - const IRMapping &firstToSecondPloopIndices, - llvm::function_ref mayAlias) { - DenseMap> bufferStores; - SmallVector bufferStoresVec; - firstPloop.getBody()->walk([&](memref::StoreOp store) { - bufferStores[store.getMemRef()].push_back(store.getIndices()); - bufferStoresVec.emplace_back(store.getMemRef()); - }); - auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { - Value loadMem = load.getMemRef(); - // Stop if the memref is defined in secondPloop body. Careful alias analysis - // is needed. - auto *memrefDef = loadMem.getDefiningOp(); - if (memrefDef && memrefDef->getBlock() == load->getBlock()) - return WalkResult::interrupt(); - - for (Value store : bufferStoresVec) - if (store != loadMem && mayAlias(store, loadMem)) - return WalkResult::interrupt(); - - auto write = bufferStores.find(loadMem); - if (write == bufferStores.end()) - return WalkResult::advance(); - - // Check that at last one store was retrieved - if (!write->second.size()) - return WalkResult::interrupt(); - - auto storeIndices = write->second.front(); - - // Multiple writes to the same memref are allowed only on the same indices - for (const auto &othStoreIndices : write->second) { - if (othStoreIndices != storeIndices) - return WalkResult::interrupt(); - } - - // Check that the load indices of secondPloop coincide with store indices of - // firstPloop for the same memrefs. 
- auto loadIndices = load.getIndices(); - if (storeIndices.size() != loadIndices.size()) - return WalkResult::interrupt(); - for (int i = 0, e = storeIndices.size(); i < e; ++i) { - if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != - loadIndices[i]) { - auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); - auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); - if (storeIndexDefOp && loadIndexDefOp) { - if (!isMemoryEffectFree(storeIndexDefOp)) - return WalkResult::interrupt(); - if (!isMemoryEffectFree(loadIndexDefOp)) - return WalkResult::interrupt(); - if (!OperationEquivalence::isEquivalentTo( - storeIndexDefOp, loadIndexDefOp, - [&](Value storeIndex, Value loadIndex) { - if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != - firstToSecondPloopIndices.lookupOrDefault(loadIndex)) - return failure(); - else - return success(); - }, - /*markEquivalent=*/nullptr, - OperationEquivalence::Flags::IgnoreLocations)) { - return WalkResult::interrupt(); - } - } else - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); - return !walkResult.wasInterrupted(); -} - -/// Analyzes dependencies in the most primitive way by checking simple read and -/// write patterns. 
-static LogicalResult -verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, - const IRMapping &firstToSecondPloopIndices, - llvm::function_ref mayAlias) { - if (!haveNoReadsAfterWriteExceptSameIndex( - firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) - return failure(); - - IRMapping secondToFirstPloopIndices; - secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), - firstPloop.getBody()->getArguments()); - return success(haveNoReadsAfterWriteExceptSameIndex( - secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); -} - -static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, - const IRMapping &firstToSecondPloopIndices, - llvm::function_ref mayAlias) { - return !hasNestedParallelOp(firstPloop) && - !hasNestedParallelOp(secondPloop) && - equalIterationSpaces(firstPloop, secondPloop) && - succeeded(verifyDependencies(firstPloop, secondPloop, - firstToSecondPloopIndices, mayAlias)); -} - -/// Prepends operations of firstPloop's body into secondPloop's body. -/// Updates secondPloop with new loop. -static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, - OpBuilder builder, - llvm::function_ref mayAlias) { - Block *block1 = firstPloop.getBody(); - Block *block2 = secondPloop.getBody(); - IRMapping firstToSecondPloopIndices; - firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); - - if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, - mayAlias)) - return; - - DominanceInfo dom; - // We are fusing first loop into second, make sure there are no users of the - // first loop results between loops. 
- for (Operation *user : firstPloop->getUsers()) - if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) - return; - - ValueRange inits1 = firstPloop.getInitVals(); - ValueRange inits2 = secondPloop.getInitVals(); - - SmallVector newInitVars(inits1.begin(), inits1.end()); - newInitVars.append(inits2.begin(), inits2.end()); - - IRRewriter b(builder); - b.setInsertionPoint(secondPloop); - auto newSecondPloop = b.create( - secondPloop.getLoc(), secondPloop.getLowerBound(), - secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); - - Block *newBlock = newSecondPloop.getBody(); - auto term1 = cast(block1->getTerminator()); - auto term2 = cast(block2->getTerminator()); - - b.inlineBlockBefore(block2, newBlock, newBlock->begin(), - newBlock->getArguments()); - b.inlineBlockBefore(block1, newBlock, newBlock->begin(), - newBlock->getArguments()); - - ValueRange results = newSecondPloop.getResults(); - if (!results.empty()) { - b.setInsertionPointToEnd(newBlock); - - ValueRange reduceArgs1 = term1.getOperands(); - ValueRange reduceArgs2 = term2.getOperands(); - SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); - newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - - auto newReduceOp = b.create(term2.getLoc(), newReduceArgs); - - for (auto &&[i, reg] : llvm::enumerate(llvm::concat( - term1.getReductions(), term2.getReductions()))) { - Block &oldRedBlock = reg.front(); - Block &newRedBlock = newReduceOp.getReductions()[i].front(); - b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), - newRedBlock.getArguments()); - } - - firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); - secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); - } - term1->erase(); - term2->erase(); - firstPloop.erase(); - secondPloop.erase(); - secondPloop = newSecondPloop; -} - void mlir::scf::naivelyFuseParallelOps( Region ®ion, llvm::function_ref mayAlias) { OpBuilder b(region); @@ -259,7 +59,7 @@ void 
mlir::scf::naivelyFuseParallelOps( } for (MutableArrayRef ploops : ploopChains) { for (int i = 0, e = ploops.size(); i + 1 < e; ++i) - fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); + mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); } } } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 6658cca03eba7..d85339f32dbe3 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" @@ -1070,6 +1071,206 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, return tileLoops; } +/// Checks if the parallel loops have mixed access to the same buffers. Returns +/// `true` if the first parallel loop writes to the same indices that the second +/// loop reads. +static bool haveNoReadsAfterWriteExceptSameIndex( + scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, + const IRMapping &firstToSecondPloopIndices, + llvm::function_ref mayAlias) { + DenseMap> bufferStores; + SmallVector bufferStoresVec; + firstPloop.getBody()->walk([&](memref::StoreOp store) { + bufferStores[store.getMemRef()].push_back(store.getIndices()); + bufferStoresVec.emplace_back(store.getMemRef()); + }); + auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { + Value loadMem = load.getMemRef(); + // Stop if the memref is defined in secondPloop body. Careful alias analysis + // is needed. 
+ auto *memrefDef = loadMem.getDefiningOp(); + if (memrefDef && memrefDef->getBlock() == load->getBlock()) + return WalkResult::interrupt(); + + for (Value store : bufferStoresVec) + if (store != loadMem && mayAlias(store, loadMem)) + return WalkResult::interrupt(); + + auto write = bufferStores.find(loadMem); + if (write == bufferStores.end()) + return WalkResult::advance(); + + // Check that at last one store was retrieved + if (!write->second.size()) + return WalkResult::interrupt(); + + auto storeIndices = write->second.front(); + + // Multiple writes to the same memref are allowed only on the same indices + for (const auto &othStoreIndices : write->second) { + if (othStoreIndices != storeIndices) + return WalkResult::interrupt(); + } + + // Check that the load indices of secondPloop coincide with store indices of + // firstPloop for the same memrefs. + auto loadIndices = load.getIndices(); + if (storeIndices.size() != loadIndices.size()) + return WalkResult::interrupt(); + for (int i = 0, e = storeIndices.size(); i < e; ++i) { + if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != + loadIndices[i]) { + auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); + auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); + if (storeIndexDefOp && loadIndexDefOp) { + if (!isMemoryEffectFree(storeIndexDefOp)) + return WalkResult::interrupt(); + if (!isMemoryEffectFree(loadIndexDefOp)) + return WalkResult::interrupt(); + if (!OperationEquivalence::isEquivalentTo( + storeIndexDefOp, loadIndexDefOp, + [&](Value storeIndex, Value loadIndex) { + if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != + firstToSecondPloopIndices.lookupOrDefault(loadIndex)) + return failure(); + else + return success(); + }, + /*markEquivalent=*/nullptr, + OperationEquivalence::Flags::IgnoreLocations)) { + return WalkResult::interrupt(); + } + } else + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + return !walkResult.wasInterrupted(); +} + 
+/// Analyzes dependencies in the most primitive way by checking simple read and +/// write patterns. +static LogicalResult +verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, + const IRMapping &firstToSecondPloopIndices, + llvm::function_ref mayAlias) { + if (!haveNoReadsAfterWriteExceptSameIndex( + firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) + return failure(); + + IRMapping secondToFirstPloopIndices; + secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), + firstPloop.getBody()->getArguments()); + return success(haveNoReadsAfterWriteExceptSameIndex( + secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); +} + +/// Verify equal iteration spaces. +static bool equalIterationSpaces(scf::ParallelOp firstPloop, + scf::ParallelOp secondPloop) { + if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) + return false; + + auto matchOperands = [&](const OperandRange &lhs, + const OperandRange &rhs) -> bool { + // TODO: Extend this to support aliases and equal constants. + return std::equal(lhs.begin(), lhs.end(), rhs.begin()); + }; + return matchOperands(firstPloop.getLowerBound(), + secondPloop.getLowerBound()) && + matchOperands(firstPloop.getUpperBound(), + secondPloop.getUpperBound()) && + matchOperands(firstPloop.getStep(), secondPloop.getStep()); +} + +/// Verify there are no nested ParallelOps. 
+static bool hasNestedParallelOp(scf::ParallelOp ploop) { + auto walkResult = ploop.getBody()->walk( + [](scf::ParallelOp) { return WalkResult::interrupt(); }); + return walkResult.wasInterrupted(); +} + +static bool isFusionLegal(scf::ParallelOp firstPloop, + scf::ParallelOp secondPloop, + const IRMapping &firstToSecondPloopIndices, + llvm::function_ref mayAlias) { + return !hasNestedParallelOp(firstPloop) && + !hasNestedParallelOp(secondPloop) && + equalIterationSpaces(firstPloop, secondPloop) && + succeeded(verifyDependencies(firstPloop, secondPloop, + firstToSecondPloopIndices, mayAlias)); +} + +void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop, + OpBuilder builder, + llvm::function_ref mayAlias) { + Block *block1 = firstPloop.getBody(); + Block *block2 = secondPloop.getBody(); + IRMapping firstToSecondPloopIndices; + firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); + + if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, + mayAlias)) + return; + + DominanceInfo dom; + // We are fusing first loop into second, make sure there are no users of the + // first loop results between loops. 
+ for (Operation *user : firstPloop->getUsers()) + if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) + return; + + ValueRange inits1 = firstPloop.getInitVals(); + ValueRange inits2 = secondPloop.getInitVals(); + + SmallVector newInitVars(inits1.begin(), inits1.end()); + newInitVars.append(inits2.begin(), inits2.end()); + + IRRewriter b(builder); + b.setInsertionPoint(secondPloop); + auto newSecondPloop = b.create( + secondPloop.getLoc(), secondPloop.getLowerBound(), + secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); + + Block *newBlock = newSecondPloop.getBody(); + auto term1 = cast(block1->getTerminator()); + auto term2 = cast(block2->getTerminator()); + + b.inlineBlockBefore(block2, newBlock, newBlock->begin(), + newBlock->getArguments()); + b.inlineBlockBefore(block1, newBlock, newBlock->begin(), + newBlock->getArguments()); + + ValueRange results = newSecondPloop.getResults(); + if (!results.empty()) { + b.setInsertionPointToEnd(newBlock); + + ValueRange reduceArgs1 = term1.getOperands(); + ValueRange reduceArgs2 = term2.getOperands(); + SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); + newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); + + auto newReduceOp = b.create(term2.getLoc(), newReduceArgs); + + for (auto &&[i, reg] : llvm::enumerate(llvm::concat( + term1.getReductions(), term2.getReductions()))) { + Block &oldRedBlock = reg.front(); + Block &newRedBlock = newReduceOp.getReductions()[i].front(); + b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), + newRedBlock.getArguments()); + } + + firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); + secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); + } + term1->erase(); + term2->erase(); + firstPloop.erase(); + secondPloop.erase(); + secondPloop = newSecondPloop; +} + scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { @@ 
-1171,3 +1372,10 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, return fusedLoop; } + +scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( + scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) { + auto mayAlias = [&](Value val1, Value val2) -> bool { return false; }; + mlir::fuseIfLegal(target, source, rewriter, mayAlias); + return source; +} diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 0f51b1cdbe0cf..46c6be36c3271 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -47,6 +47,59 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @fuse_two_parallel +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. 
+ %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 +// CHECK: [[SUM:%.*]] = memref.alloc() + %sum = memref.alloc() : memref<2x2xf32> +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] +// CHECK: scf.reduce +// CHECK: } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } +// CHECK: memref.dealloc [[SUM]] + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + 
transform.yield + } +} + +// ----- + // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index From b73238a9472b0682f250e37848ad504d21a57059 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 5 Jun 2024 10:47:00 -0500 Subject: [PATCH 03/33] add checkFusionStructuralLegality --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 7 ++++++ mlir/lib/Dialect/SCF/Utils/Utils.cpp | 26 +++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 2944d8ffac022..834857f177cdf 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -156,6 +156,13 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef sizes); void getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, scf::ForOp root); +//===----------------------------------------------------------------------===// +// Fusion related helpers +//===----------------------------------------------------------------------===// + +template +bool checkFusionStructuralLegality(Operation *target, Operation *source); + /// Prepends operations of firstPloop's body into secondPloop's body. /// Updates secondPloop with new loop. 
void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop, diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index d85339f32dbe3..c490983335470 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1184,6 +1184,10 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop, matchOperands(firstPloop.getStep(), secondPloop.getStep()); } +//===----------------------------------------------------------------------===// +// Fusion related helpers +//===----------------------------------------------------------------------===// + /// Verify there are no nested ParallelOps. static bool hasNestedParallelOp(scf::ParallelOp ploop) { auto walkResult = ploop.getBody()->walk( @@ -1191,6 +1195,28 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) { return walkResult.wasInterrupted(); } +template +static bool checkFusionStructuralLegality(Operation *target, + Operation *source) { + static_assert(llvm::is_one_of::value, + "applies to only `forall`, `for` and `parallel`"); + auto targetOp = dyn_cast(target); + auto sourceOp = dyn_cast(source); + if (!targetOp || !sourceOp) + return false; + + if constexpr (std::is_same_v) + return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && + targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && + targetOp.getMixedStep() == sourceOp.getMixedStep() && + targetOp.getMapping() == sourceOp.getMapping(); + else + return targetOp.getLowerBound() == sourceOp.getLowerBound() && + targetOp.getUpperBound() == sourceOp.getUpperBound() && + targetOp.getStep() == sourceOp.getStep(); +} + static bool isFusionLegal(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, From f5bbd131bb7713ae47a58587b2d9acf82dc3b12f Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 5 Jun 2024 14:19:56 -0500 Subject: [PATCH 04/33] replace isLoopWithIdenticalConfiguration with 
checkFusionStructuralLegality --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 4 +- .../SCF/TransformOps/SCFTransformOps.cpp | 50 ++++++------------- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 32 +++++------- 3 files changed, 29 insertions(+), 57 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 834857f177cdf..ab9d154aa480d 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -160,8 +160,8 @@ void getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, // Fusion related helpers //===----------------------------------------------------------------------===// -template -bool checkFusionStructuralLegality(Operation *target, Operation *source); +bool checkFusionStructuralLegality(LoopLikeOpInterface &target, + LoopLikeOpInterface &source); /// Prepends operations of firstPloop's body into secondPloop's body. /// Updates secondPloop with new loop. diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 1c53e89d69040..9f541b94af474 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -442,34 +442,6 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target, return DiagnosedSilenceableFailure::success(); } -/// Check if `target` scf loop can be fused into `source` scf loop. -/// Applies for scf.for, scf.forall, and scf.parallel. -/// -/// This simply checks if both loops have the same bounds, steps and mapping. -/// No attempt is made at checking that the side effects of `target` and -/// `source` are independent of each other. 
-template -static bool isLoopWithIdenticalConfiguration(Operation *target, - Operation *source) { - static_assert(llvm::is_one_of::value, - "applies to only `forall`, `for` and `parallel`"); - auto targetOp = dyn_cast(target); - auto sourceOp = dyn_cast(source); - if (!targetOp || !sourceOp) - return false; - - if constexpr (std::is_same_v) - return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && - targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && - targetOp.getMixedStep() == sourceOp.getMixedStep() && - targetOp.getMapping() == sourceOp.getMapping(); - else - return targetOp.getLowerBound() == sourceOp.getLowerBound() && - targetOp.getUpperBound() == sourceOp.getUpperBound() && - targetOp.getStep() == sourceOp.getStep(); -} - DiagnosedSilenceableFailure transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -485,29 +457,37 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, << "source handle (got " << llvm::range_size(sourceOps) << ")"; } - Operation *target = *targetOps.begin(); - Operation *source = *sourceOps.begin(); + LoopLikeOpInterface target = + dyn_cast(*targetOps.begin()); + LoopLikeOpInterface source = + dyn_cast(*sourceOps.begin()); + if (!target || !source) + return emitSilenceableFailure(target->getLoc()) + << "target or source is not a loop op"; // Check if the target and source are siblings. DiagnosedSilenceableFailure diag = isOpSibling(target, source); if (!diag.succeeded()) return diag; + if (!mlir::checkFusionStructuralLegality(target, source)) + return emitSilenceableFailure(target->getLoc()) + << "operations cannot be fused"; + Operation *fusedLoop; /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. 
- if (isLoopWithIdenticalConfiguration(target, source)) { + if (isa(target) && isa(source)) { fusedLoop = fuseIndependentSiblingForLoops( cast(target), cast(source), rewriter); - } else if (isLoopWithIdenticalConfiguration(target, source)) { + } else if (isa(target) && isa(source)) { fusedLoop = fuseIndependentSiblingForallLoops( cast(target), cast(source), rewriter); - } else if (isLoopWithIdenticalConfiguration(target, - source)) { + } else if (isa(target) && isa(source)) { fusedLoop = fuseIndependentSiblingParallelLoops( cast(target), cast(source), rewriter); } else return emitSilenceableFailure(target->getLoc()) - << "operations cannot be fused"; + << "unsupported loop type for fusion"; assert(fusedLoop && "failed to fuse operations"); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index c490983335470..ce20730459c2a 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1195,26 +1195,18 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) { return walkResult.wasInterrupted(); } -template -static bool checkFusionStructuralLegality(Operation *target, - Operation *source) { - static_assert(llvm::is_one_of::value, - "applies to only `forall`, `for` and `parallel`"); - auto targetOp = dyn_cast(target); - auto sourceOp = dyn_cast(source); - if (!targetOp || !sourceOp) - return false; - - if constexpr (std::is_same_v) - return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && - targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && - targetOp.getMixedStep() == sourceOp.getMixedStep() && - targetOp.getMapping() == sourceOp.getMapping(); - else - return targetOp.getLowerBound() == sourceOp.getLowerBound() && - targetOp.getUpperBound() == sourceOp.getUpperBound() && - targetOp.getStep() == sourceOp.getStep(); +bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, + LoopLikeOpInterface &source) { + auto iterSpaceEq = + 
target.getMixedLowerBound() == source.getMixedLowerBound() && + target.getMixedUpperBound() == source.getMixedUpperBound() && + target.getMixedStep() == source.getMixedStep(); + auto forAllTarget = dyn_cast(*target); + auto forAllSource = dyn_cast(*source); + if (forAllTarget && forAllSource) + return iterSpaceEq && + forAllTarget.getMapping() == forAllSource.getMapping(); + return iterSpaceEq; } static bool isFusionLegal(scf::ParallelOp firstPloop, From 7d995815064cb25e47ec8e400de3692fbe5fdfba Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 5 Jun 2024 14:37:57 -0500 Subject: [PATCH 05/33] address review comment --- .../mlir/Interfaces/LoopLikeInterface.td | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index 813779c852027..5cf3eba0bd9ed 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -234,29 +234,33 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { /// If there is a single induction variable return it, otherwise return /// std::nullopt. ::std::optional<::mlir::Value> getSingleInductionVar() { - if (this->getInductionVars().size() == 1) - return this->getInductionVars()[0]; + auto inductionVars = this->getInductionVars(); + if (inductionVars.size() == 1) + return inductionVars[0]; return std::nullopt; } /// Return the single lower bound value or attribute if it exists, otherwise /// return std::nullopt. ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() { - if (this->getMixedLowerBound().size() == 1) - return this->getMixedLowerBound()[0]; + auto lowerBounds = this->getMixedLowerBound(); + if (lowerBounds.size() == 1) + return lowerBounds[0]; return std::nullopt; } /// Return the single step value or attribute if it exists, otherwise /// return std::nullopt. 
::std::optional<::mlir::OpFoldResult> getSingleStep() { - if (this->getMixedStep().size() == 1) - return this->getMixedStep()[0]; + auto steps = this->getMixedStep(); + if (steps.size() == 1) + return steps[0]; return std::nullopt; } /// Return the single upper bound value or attribute if it exists, otherwise /// return std::nullopt. ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() { - if (this->getMixedUpperBound().size() == 1) - return this->getMixedUpperBound()[0]; + auto upperBounds = this->getMixedUpperBound(); + if (upperBounds.size() == 1) + return upperBounds[0]; return std::nullopt; } From a5fa3b3c4903c344847ee544cd9812b6f0c70571 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 5 Jun 2024 20:56:02 -0500 Subject: [PATCH 06/33] Make return types optional and change names --- .../mlir/Dialect/Affine/IR/AffineOps.td | 10 +-- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 29 +++++++-- .../mlir/Interfaces/LoopLikeInterface.td | 42 ++++++------- .../AffineToStandard/AffineToStandard.cpp | 4 +- .../SCFToControlFlow/SCFToControlFlow.cpp | 9 +-- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 27 ++++---- mlir/lib/Dialect/Affine/Utils/Utils.cpp | 2 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 26 ++++---- .../Dialect/SCF/Transforms/ForallToFor.cpp | 9 +-- .../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 62 +++++++++++++------ 10 files changed, 131 insertions(+), 89 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index bb2c29b5733b8..4c032e66f7a83 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for", [AttrSizedOperandSegments, AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable, RecursiveMemoryEffects, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -671,7 +671,7 @@ def AffineParallelOp : Affine_Op<"parallel", I32ElementsAttr:$lowerBoundsGroups, 
AffineMapAttr:$upperBoundsMap, I32ElementsAttr:$upperBoundsGroups, - I64SmallVectorArrayAttr:$steps, + I64SmallVectorArrayAttr:$step, Variadic:$mapOperands); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -682,7 +682,7 @@ def AffineParallelOp : Affine_Op<"parallel", OpBuilder<(ins "TypeRange":$resultTypes, "ArrayRef":$reductions, "ArrayRef":$lbMaps, "ValueRange":$lbArgs, "ArrayRef":$ubMaps, "ValueRange":$ubArgs, - "ArrayRef":$steps)> + "ArrayRef":$step)> ]; let extraClassDeclaration = [{ @@ -727,7 +727,7 @@ def AffineParallelOp : Affine_Op<"parallel", static StringRef getUpperBoundsGroupsAttrStrName() { return "upperBoundsGroups"; } - static StringRef getStepsAttrStrName() { return "steps"; } + static StringRef getStepsAttrStrName() { return "step"; } /// Returns `true` if the loop bounds have min/max expressions. bool hasMinMaxBounds() { diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 3b28ca8b21d0f..66b478f141b32 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ def ForOp : SCF_Op<"for", [AutomaticAllocationScope, DeclareOpInterfaceMethods, AllTypesMatch<["lowerBound", "upperBound", "step"]>, @@ -302,7 +302,7 @@ def ForallOp : SCF_Op<"forall", [ AutomaticAllocationScope, DeclareOpInterfaceMethods, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, @@ -510,6 +510,27 @@ def ForallOp : SCF_Op<"forall", [ ]; let extraClassDeclaration = [{ + // Get lower bounds as OpFoldResult. + SmallVector getMixedLowerBound() { + auto maybeLowerBounds = getLowerBounds(); + assert(maybeLowerBounds.has_value() && "expected values"); + return *maybeLowerBounds; + } + + // Get upper bounds as OpFoldResult. 
+ SmallVector getMixedUpperBound() { + auto maybeUpperBounds = getUpperBounds(); + assert(maybeUpperBounds.has_value() && "expected values"); + return *maybeUpperBounds; + } + + // Get steps as OpFoldResult. + SmallVector getMixedStep() { + auto maybeSteps = getSteps(); + assert(maybeSteps.has_value() && "expected values"); + return *maybeSteps; + } + /// Get lower bounds as values. SmallVector getLowerBound(OpBuilder &b) { return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound()); @@ -744,7 +765,7 @@ def ParallelOp : SCF_Op<"parallel", [AutomaticAllocationScope, AttrSizedOperandSegments, DeclareOpInterfaceMethods, + "getLowerBounds", "getUpperBounds", "getSteps"]>, RecursiveMemoryEffects, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::ReduceOp">, diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index 5cf3eba0bd9ed..cc79d026c8d4e 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -106,34 +106,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { InterfaceMethod<[{ Return all lower bounds. }], - /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>", - /*methodName=*/"getMixedLowerBound", + /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", + /*methodName=*/"getLowerBounds", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return {}; + return std::nullopt; }] >, InterfaceMethod<[{ Return all steps. }], - /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>", - /*methodName=*/"getMixedStep", + /*retTy=*/"std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", + /*methodName=*/"getSteps", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return {}; + return std::nullopt; }] >, InterfaceMethod<[{ Return all upper bounds. 
}], - /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>", - /*methodName=*/"getMixedUpperBound", + /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", + /*methodName=*/"getUpperBounds", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return {}; + return std::nullopt; }] >, InterfaceMethod<[{ @@ -242,26 +242,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { /// Return the single lower bound value or attribute if it exists, otherwise /// return std::nullopt. ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() { - auto lowerBounds = this->getMixedLowerBound(); - if (lowerBounds.size() == 1) - return lowerBounds[0]; - return std::nullopt; + auto lowerBounds = this->getLowerBounds(); + if (lowerBounds.has_value() && (*lowerBounds).size() == 1) + return (*lowerBounds)[0]; + return std::nullopt; } /// Return the single step value or attribute if it exists, otherwise /// return std::nullopt. ::std::optional<::mlir::OpFoldResult> getSingleStep() { - auto steps = this->getMixedStep(); - if (steps.size() == 1) - return steps[0]; - return std::nullopt; + auto steps = this->getSteps(); + if (steps.has_value() && (*steps).size() == 1) + return (*steps)[0]; + return std::nullopt; } /// Return the single upper bound value or attribute if it exists, otherwise /// return std::nullopt. 
::std::optional<::mlir::OpFoldResult> getSingleUpperBound() { - auto upperBounds = this->getMixedUpperBound(); - if (upperBounds.size() == 1) - return upperBounds[0]; - return std::nullopt; + auto upperBounds = this->getUpperBounds(); + if (upperBounds.has_value() && (*upperBounds).size() == 1) + return (*upperBounds)[0]; + return std::nullopt; } /// Append the specified additional "init" operands: replace this loop with diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 10ccd5c97783b..20487b32e3fe0 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -196,8 +196,8 @@ class AffineParallelLowering : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds"); upperBoundTuple.push_back(upper); } - steps.reserve(op.getSteps().size()); - for (int64_t step : op.getSteps()) + steps.reserve(op.getStep().size()); + for (int64_t step : op.getStep()) steps.push_back(rewriter.create(loc, step)); // Get the terminator op. diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 9eb8a289d7d65..48e1d88c1c75e 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -695,12 +695,9 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp, "only fully bufferized scf.forall ops can be lowered to scf.parallel"); // Convert mixed bounds and steps to SSA values. 
- SmallVector lbs = getValueOrCreateConstantIndexOp( - rewriter, loc, forallOp.getMixedLowerBound()); - SmallVector ubs = getValueOrCreateConstantIndexOp( - rewriter, loc, forallOp.getMixedUpperBound()); - SmallVector steps = - getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep()); + SmallVector lbs = forallOp.getLowerBound(rewriter); + SmallVector ubs = forallOp.getUpperBound(rewriter); + SmallVector steps = forallOp.getStep(rewriter); // Create empty scf.parallel op. auto parallelOp = rewriter.create(loc, lbs, ubs, steps); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 746a9c919560c..d3f034a0660ba 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2456,23 +2456,26 @@ SmallVector AffineForOp::getLoopRegions() { return {&getRegion()}; } ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; } -SmallVector AffineForOp::getMixedLowerBound() { +std::optional> AffineForOp::getLowerBounds() { if (!hasConstantLowerBound()) - return {}; + return std::nullopt; OpBuilder b(getContext()); - return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))}; + return SmallVector{ + OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))}; } -SmallVector AffineForOp::getMixedStep() { +std::optional> AffineForOp::getSteps() { OpBuilder b(getContext()); - return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))}; + return SmallVector{ + OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))}; } -SmallVector AffineForOp::getMixedUpperBound() { +std::optional> AffineForOp::getUpperBounds() { if (!hasConstantUpperBound()) return {}; OpBuilder b(getContext()); - return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))}; + return SmallVector{ + OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))}; } FailureOr AffineForOp::replaceWithAdditionalYields( @@ -3753,7 +3756,7 @@ SmallVector AffineParallelOp::getLoopRegions() { 
return {&getRegion()}; } -unsigned AffineParallelOp::getNumDims() { return getSteps().size(); } +unsigned AffineParallelOp::getNumDims() { return getStep().size(); } AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() { return getOperands().take_front(getLowerBoundsMap().getNumInputs()); @@ -3838,7 +3841,7 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) { } void AffineParallelOp::setSteps(ArrayRef newSteps) { - setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); + setStepAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); } // check whether resultType match op or not in affine.parallel @@ -3888,14 +3891,14 @@ LogicalResult AffineParallelOp::verify() { auto numDims = getNumDims(); if (getLowerBoundsGroups().getNumElements() != numDims || getUpperBoundsGroups().getNumElements() != numDims || - getSteps().size() != numDims || getBody()->getNumArguments() != numDims) { + getStep().size() != numDims || getBody()->getNumArguments() != numDims) { return emitOpError() << "the number of region arguments (" << getBody()->getNumArguments() << ") and the number of map groups for lower (" << getLowerBoundsGroups().getNumElements() << ") and upper bound (" << getUpperBoundsGroups().getNumElements() - << "), and the number of steps (" << getSteps().size() + << "), and the number of steps (" << getStep().size() << ") must all match"; } @@ -4013,7 +4016,7 @@ void AffineParallelOp::print(OpAsmPrinter &p) { printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(), getUpperBoundsOperands(), "min"); p << ')'; - SmallVector steps = getSteps(); + SmallVector steps = getStep(); bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; }); if (!elideSteps) { p << " step ("; diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index f46381403bc52..a652ee4a488d1 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ 
-494,7 +494,7 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) { return; AffineMap lbMap = op.getLowerBoundsMap(); - SmallVector steps = op.getSteps(); + SmallVector steps = op.getStep(); // No need to do any work if the parallel op is already normalized. bool isAlreadyNormalized = llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index e275ff1849c10..281d73afee4a8 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -380,16 +380,16 @@ LogicalResult ForOp::verifyRegions() { ValueRange ForOp::getInductionVars() { return {getInductionVar()}; } -SmallVector ForOp::getMixedLowerBound() { - return {OpFoldResult(getLowerBound())}; +std::optional> ForOp::getLowerBounds() { + return SmallVector{OpFoldResult(getLowerBound())}; } -SmallVector ForOp::getMixedStep() { - return {OpFoldResult(getStep())}; +std::optional> ForOp::getSteps() { + return SmallVector{OpFoldResult(getStep())}; } -SmallVector ForOp::getMixedUpperBound() { - return {OpFoldResult(getUpperBound())}; +std::optional> ForOp::getUpperBounds() { + return SmallVector{OpFoldResult(getUpperBound())}; } std::optional ForOp::getLoopResults() { return getResults(); } @@ -1431,19 +1431,19 @@ ValueRange ForallOp::getInductionVars() { } // Get lower bounds as OpFoldResult. -SmallVector ForallOp::getMixedLowerBound() { +std::optional> ForallOp::getLowerBounds() { Builder b(getOperation()->getContext()); return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b); } // Get upper bounds as OpFoldResult. -SmallVector ForallOp::getMixedUpperBound() { +std::optional> ForallOp::getUpperBounds() { Builder b(getOperation()->getContext()); return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b); } // Get steps as OpFoldResult. 
-SmallVector ForallOp::getMixedStep() { +std::optional> ForallOp::getSteps() { Builder b(getOperation()->getContext()); return getMixedValues(getStaticStep(), getDynamicStep(), b); } @@ -3006,15 +3006,17 @@ SmallVector ParallelOp::getLoopRegions() { return {&getRegion()}; } ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); } -SmallVector ParallelOp::getMixedLowerBound() { +std::optional> ParallelOp::getLowerBounds() { return getLowerBound(); } -SmallVector ParallelOp::getMixedUpperBound() { +std::optional> ParallelOp::getUpperBounds() { return getUpperBound(); } -SmallVector ParallelOp::getMixedStep() { return getStep(); } +std::optional> ParallelOp::getSteps() { + return getStep(); +} ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { auto ivArg = llvm::dyn_cast(val); diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp index 198cb2e6cc69e..5da1b76e929be 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp @@ -34,12 +34,9 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp, rewriter.setInsertionPoint(forallOp); Location loc = forallOp.getLoc(); - SmallVector lbs = getValueOrCreateConstantIndexOp( - rewriter, loc, forallOp.getMixedLowerBound()); - SmallVector ubs = getValueOrCreateConstantIndexOp( - rewriter, loc, forallOp.getMixedUpperBound()); - SmallVector steps = - getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep()); + SmallVector lbs = forallOp.getLowerBound(rewriter); + SmallVector ubs = forallOp.getUpperBound(rewriter); + SmallVector steps = forallOp.getStep(rewriter); LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps); SmallVector ivs = llvm::map_to_vector( diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp index d8cdb213070da..07504a99fecd3 100644 --- 
a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -27,35 +27,57 @@ class SCFLoopLikeTest : public ::testing::Test { } void checkUnidimensional(LoopLikeOpInterface loopLikeOp) { - std::optional maybeLb = loopLikeOp.getSingleLowerBound(); + std::optional maybeSingleLb = + loopLikeOp.getSingleLowerBound(); + EXPECT_TRUE(maybeSingleLb.has_value()); + std::optional maybeSingleUb = + loopLikeOp.getSingleUpperBound(); + EXPECT_TRUE(maybeSingleUb.has_value()); + std::optional maybeSingleStep = loopLikeOp.getSingleStep(); + EXPECT_TRUE(maybeSingleStep.has_value()); + std::optional maybeSingleIndVar = + loopLikeOp.getSingleInductionVar(); + EXPECT_TRUE(maybeSingleIndVar.has_value()); + + std::optional> maybeLb = + loopLikeOp.getLowerBounds(); EXPECT_TRUE(maybeLb.has_value()); - std::optional maybeUb = loopLikeOp.getSingleUpperBound(); + EXPECT_EQ((*maybeLb).size(), 1u); + std::optional> maybeUb = + loopLikeOp.getUpperBounds(); EXPECT_TRUE(maybeUb.has_value()); - std::optional maybeStep = loopLikeOp.getSingleStep(); + EXPECT_EQ((*maybeUb).size(), 1u); + std::optional> maybeStep = loopLikeOp.getSteps(); EXPECT_TRUE(maybeStep.has_value()); - std::optional maybeIndVar = - loopLikeOp.getSingleInductionVar(); - EXPECT_TRUE(maybeIndVar.has_value()); + EXPECT_EQ((*maybeStep).size(), 1u); EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u); - EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u); - EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u); - EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u); } void checkMultidimensional(LoopLikeOpInterface loopLikeOp) { - std::optional maybeLb = loopLikeOp.getSingleLowerBound(); - EXPECT_FALSE(maybeLb.has_value()); - std::optional maybeUb = loopLikeOp.getSingleUpperBound(); - EXPECT_FALSE(maybeUb.has_value()); - std::optional maybeStep = loopLikeOp.getSingleStep(); - EXPECT_FALSE(maybeStep.has_value()); - std::optional maybeIndVar = + std::optional maybeSingleLb = + 
loopLikeOp.getSingleLowerBound(); + EXPECT_FALSE(maybeSingleLb.has_value()); + std::optional maybeSingleUb = + loopLikeOp.getSingleUpperBound(); + EXPECT_FALSE(maybeSingleUb.has_value()); + std::optional maybeSingleStep = loopLikeOp.getSingleStep(); + EXPECT_FALSE(maybeSingleStep.has_value()); + std::optional maybeSingleIndVar = loopLikeOp.getSingleInductionVar(); - EXPECT_FALSE(maybeIndVar.has_value()); + EXPECT_FALSE(maybeSingleIndVar.has_value()); + + std::optional> maybeLb = + loopLikeOp.getLowerBounds(); + EXPECT_TRUE(maybeLb.has_value()); + EXPECT_EQ((*maybeLb).size(), 2u); + std::optional> maybeUb = + loopLikeOp.getUpperBounds(); + EXPECT_TRUE(maybeUb.has_value()); + EXPECT_EQ((*maybeUb).size(), 2u); + std::optional> maybeStep = loopLikeOp.getSteps(); + EXPECT_TRUE(maybeStep.has_value()); + EXPECT_EQ((*maybeStep).size(), 2u); EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u); - EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u); - EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u); - EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u); } MLIRContext context; From 1babe681d7858a4992303c62e22684cb73d82472 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 6 Jun 2024 11:31:11 -0500 Subject: [PATCH 07/33] change return type of getInductionVars to SmallVector --- mlir/include/mlir/Interfaces/LoopLikeInterface.td | 2 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 4 +++- mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 3 +-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 6 +++--- mlir/lib/Dialect/SCF/IR/SCF.cpp | 10 ++++++---- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index cc79d026c8d4e..bace8f8384d44 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -95,7 +95,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { InterfaceMethod<[{ Return all induction 
variables. }], - /*retTy=*/"::mlir::ValueRange", + /*retTy=*/"::llvm::SmallVector<::mlir::Value>", /*methodName=*/"getInductionVars", /*args=*/(ins), /*methodBody=*/"", diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index d3f034a0660ba..5467c60242664 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2454,7 +2454,9 @@ bool AffineForOp::matchingBoundOperandList() { SmallVector AffineForOp::getLoopRegions() { return {&getRegion()}; } -ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; } +SmallVector AffineForOp::getInductionVars() { + return {getInductionVar()}; +} std::optional> AffineForOp::getLowerBounds() { if (!hasConstantLowerBound()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index b0a4de2da1e86..8b0e04fb61b1b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -184,8 +184,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter, for (Operation *loopOp : loopOps) { llvm::TypeSwitch(loopOp) .Case([&](scf::ParallelOp parallelOp) { - allIvs.append(parallelOp.getInductionVars().begin(), - parallelOp.getInductionVars().end()); + allIvs.append(parallelOp.getInductionVars()); }) .Case([&](scf::ForOp forOp) { allIvs.push_back(forOp.getInductionVar()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index fd314ef9f8134..4eacaa8d1e327 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -243,7 +243,7 @@ static void calculateTileOffsetsAndSizes( OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(forallOp.getBody(0)); - ValueRange threadIds = forallOp.getInductionVars(); + auto threadIds = forallOp.getInductionVars(); SmallVector nonZeroNumThreads = 
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); @@ -746,7 +746,7 @@ FailureOr linalg::tileReductionUsingForall( b.getIndexAttr(0)); SmallVector sizes = tiledSizes; sizes[reductionDim] = b.getIndexAttr(1); - outOffsets[reductionDim] = forallOp.getInductionVars().front(); + outOffsets[reductionDim] = forallOp.getInductionVars()[0]; // TODO: use SubsetExtractOpInterface once it is available. tiledDpsInitOperands.push_back(b.create( loc, cast(initOperand.getType()), @@ -814,7 +814,7 @@ FailureOr linalg::tileReductionUsingForall( int64_t sizeIdx = 0; for (int64_t i = 0, e = numThreads.size(); i < e; ++i) { if (i == reductionDim) { - resultOffsetsRank.push_back(forallOp.getInductionVars().front()); + resultOffsetsRank.push_back(forallOp.getInductionVars()[0]); resultSizesRank.push_back(b.getIndexAttr(1)); continue; } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 281d73afee4a8..0ce10ebdad3e2 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -378,7 +378,7 @@ LogicalResult ForOp::verifyRegions() { return success(); } -ValueRange ForOp::getInductionVars() { return {getInductionVar()}; } +SmallVector ForOp::getInductionVars() { return {getInductionVar()}; } std::optional> ForOp::getLowerBounds() { return SmallVector{OpFoldResult(getLowerBound())}; @@ -1426,8 +1426,8 @@ SmallVector ForallOp::getCombiningOps(BlockArgument bbArg) { return storeOps; } -ValueRange ForallOp::getInductionVars() { - return getBody()->getArguments().take_front(getRank()); +SmallVector ForallOp::getInductionVars() { + return SmallVector(getBody()->getArguments().take_front(getRank())); } // Get lower bounds as OpFoldResult. 
@@ -3004,7 +3004,9 @@ void ParallelOp::print(OpAsmPrinter &p) { SmallVector ParallelOp::getLoopRegions() { return {&getRegion()}; } -ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); } +SmallVector ParallelOp::getInductionVars() { + return SmallVector(getBody()->getArguments()); +} std::optional> ParallelOp::getLowerBounds() { return getLowerBound(); From 009fd15ab8abefd56afe6424e27f912a4166329d Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 6 Jun 2024 14:02:52 -0500 Subject: [PATCH 08/33] address maks's comments --- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 4eacaa8d1e327..a0a0e11a6903d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -243,7 +243,7 @@ static void calculateTileOffsetsAndSizes( OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(forallOp.getBody(0)); - auto threadIds = forallOp.getInductionVars(); + SmallVector threadIds = forallOp.getInductionVars(); SmallVector nonZeroNumThreads = llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 0ce10ebdad3e2..a930f8c71454c 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1427,7 +1427,7 @@ SmallVector ForallOp::getCombiningOps(BlockArgument bbArg) { } SmallVector ForallOp::getInductionVars() { - return SmallVector(getBody()->getArguments().take_front(getRank())); + return SmallVector{getBody()->getArguments().take_front(getRank())}; } // Get lower bounds as OpFoldResult. 
@@ -3005,7 +3005,7 @@ void ParallelOp::print(OpAsmPrinter &p) { SmallVector ParallelOp::getLoopRegions() { return {&getRegion()}; } SmallVector ParallelOp::getInductionVars() { - return SmallVector(getBody()->getArguments()); + return SmallVector{getBody()->getArguments()}; } std::optional> ParallelOp::getLowerBounds() { From d34ad95aba669b5700976f0d2ed4d68b4902e9be Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 6 Jun 2024 15:26:06 -0500 Subject: [PATCH 09/33] change interface method names again and revert steps operand change --- .../mlir/Dialect/Affine/IR/AffineOps.td | 10 +++---- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 27 ++++++++++++------- .../mlir/Interfaces/LoopLikeInterface.td | 16 +++++------ .../AffineToStandard/AffineToStandard.cpp | 4 +-- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 18 ++++++------- mlir/lib/Dialect/Affine/Utils/Utils.cpp | 2 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 24 ++++++++--------- .../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 18 +++++++------ 8 files changed, 64 insertions(+), 55 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 4c032e66f7a83..dbec741cf1b1f 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for", [AttrSizedOperandSegments, AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable, RecursiveMemoryEffects, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -671,7 +671,7 @@ def AffineParallelOp : Affine_Op<"parallel", I32ElementsAttr:$lowerBoundsGroups, AffineMapAttr:$upperBoundsMap, I32ElementsAttr:$upperBoundsGroups, - I64SmallVectorArrayAttr:$step, + I64SmallVectorArrayAttr:$steps, Variadic:$mapOperands); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -682,7 +682,7 @@ def AffineParallelOp : Affine_Op<"parallel", OpBuilder<(ins "TypeRange":$resultTypes, 
"ArrayRef":$reductions, "ArrayRef":$lbMaps, "ValueRange":$lbArgs, "ArrayRef":$ubMaps, "ValueRange":$ubArgs, - "ArrayRef":$step)> + "ArrayRef":$steps)> ]; let extraClassDeclaration = [{ @@ -727,7 +727,7 @@ def AffineParallelOp : Affine_Op<"parallel", static StringRef getUpperBoundsGroupsAttrStrName() { return "upperBoundsGroups"; } - static StringRef getStepsAttrStrName() { return "step"; } + static StringRef getStepsAttrStrName() { return "steps"; } /// Returns `true` if the loop bounds have min/max expressions. bool hasMinMaxBounds() { diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 66b478f141b32..3704b15972278 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ def ForOp : SCF_Op<"for", [AutomaticAllocationScope, DeclareOpInterfaceMethods, AllTypesMatch<["lowerBound", "upperBound", "step"]>, @@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [ AttrSizedOperandSegments, AutomaticAllocationScope, DeclareOpInterfaceMethods, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, @@ -510,23 +510,26 @@ def ForallOp : SCF_Op<"forall", [ ]; let extraClassDeclaration = [{ + SmallVector getInductionVars() { + return getLoopInductionVars(); + } // Get lower bounds as OpFoldResult. SmallVector getMixedLowerBound() { - auto maybeLowerBounds = getLowerBounds(); + auto maybeLowerBounds = getLoopLowerBounds(); assert(maybeLowerBounds.has_value() && "expected values"); return *maybeLowerBounds; } // Get upper bounds as OpFoldResult. SmallVector getMixedUpperBound() { - auto maybeUpperBounds = getUpperBounds(); + auto maybeUpperBounds = getLoopUpperBounds(); assert(maybeUpperBounds.has_value() && "expected values"); return *maybeUpperBounds; } // Get steps as OpFoldResult. 
SmallVector getMixedStep() { - auto maybeSteps = getSteps(); + auto maybeSteps = getLoopSteps(); assert(maybeSteps.has_value() && "expected values"); return *maybeSteps; } @@ -588,7 +591,7 @@ def ForallOp : SCF_Op<"forall", [ } ::mlir::Value getInductionVar(int64_t idx) { - return getInductionVars()[idx]; + return getLoopInductionVars()[idx]; } ::mlir::Block::BlockArgListType getRegionOutArgs() { @@ -764,8 +767,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, RecursiveMemoryEffects, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::ReduceOp">, @@ -845,6 +848,10 @@ def ParallelOp : SCF_Op<"parallel", ]; let extraClassDeclaration = [{ + // Get induction variables. + SmallVector getInductionVars() { + return getLoopInductionVars(); + } unsigned getNumLoops() { return getStep().size(); } unsigned getNumReductions() { return getInitVals().size(); } }]; diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index bace8f8384d44..5312ace4db68e 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -96,7 +96,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { Return all induction variables. }], /*retTy=*/"::llvm::SmallVector<::mlir::Value>", - /*methodName=*/"getInductionVars", + /*methodName=*/"getLoopInductionVars", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -107,7 +107,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { Return all lower bounds. }], /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", - /*methodName=*/"getLowerBounds", + /*methodName=*/"getLoopLowerBounds", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -118,7 +118,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { Return all steps. 
}], /*retTy=*/"std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", - /*methodName=*/"getSteps", + /*methodName=*/"getLoopSteps", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -129,7 +129,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { Return all upper bounds. }], /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", - /*methodName=*/"getUpperBounds", + /*methodName=*/"getLoopUpperBounds", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -234,7 +234,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { /// If there is a single induction variable return it, otherwise return /// std::nullopt. ::std::optional<::mlir::Value> getSingleInductionVar() { - auto inductionVars = this->getInductionVars(); + auto inductionVars = this->getLoopInductionVars(); if (inductionVars.size() == 1) return inductionVars[0]; return std::nullopt; @@ -242,7 +242,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { /// Return the single lower bound value or attribute if it exists, otherwise /// return std::nullopt. ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() { - auto lowerBounds = this->getLowerBounds(); + auto lowerBounds = this->getLoopLowerBounds(); if (lowerBounds.has_value() && (*lowerBounds).size() == 1) return (*lowerBounds)[0]; return std::nullopt; @@ -250,7 +250,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { /// Return the single step value or attribute if it exists, otherwise /// return std::nullopt. ::std::optional<::mlir::OpFoldResult> getSingleStep() { - auto steps = this->getSteps(); + auto steps = this->getLoopSteps(); if (steps.has_value() && (*steps).size() == 1) return (*steps)[0]; return std::nullopt; @@ -258,7 +258,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { /// Return the single upper bound value or attribute if it exists, otherwise /// return std::nullopt. 
::std::optional<::mlir::OpFoldResult> getSingleUpperBound() { - auto upperBounds = this->getUpperBounds(); + auto upperBounds = this->getLoopUpperBounds(); if (upperBounds.has_value() && (*upperBounds).size() == 1) return (*upperBounds)[0]; return std::nullopt; diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 20487b32e3fe0..10ccd5c97783b 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -196,8 +196,8 @@ class AffineParallelLowering : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds"); upperBoundTuple.push_back(upper); } - steps.reserve(op.getStep().size()); - for (int64_t step : op.getStep()) + steps.reserve(op.getSteps().size()); + for (int64_t step : op.getSteps()) steps.push_back(rewriter.create(loc, step)); // Get the terminator op. diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 5467c60242664..d5cb04743dfb9 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2454,11 +2454,11 @@ bool AffineForOp::matchingBoundOperandList() { SmallVector AffineForOp::getLoopRegions() { return {&getRegion()}; } -SmallVector AffineForOp::getInductionVars() { +SmallVector AffineForOp::getLoopInductionVars() { return {getInductionVar()}; } -std::optional> AffineForOp::getLowerBounds() { +std::optional> AffineForOp::getLoopLowerBounds() { if (!hasConstantLowerBound()) return std::nullopt; OpBuilder b(getContext()); @@ -2466,13 +2466,13 @@ std::optional> AffineForOp::getLowerBounds() { OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))}; } -std::optional> AffineForOp::getSteps() { +std::optional> AffineForOp::getLoopSteps() { OpBuilder b(getContext()); return SmallVector{ OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))}; } -std::optional> 
AffineForOp::getUpperBounds() { +std::optional> AffineForOp::getLoopUpperBounds() { if (!hasConstantUpperBound()) return {}; OpBuilder b(getContext()); @@ -3758,7 +3758,7 @@ SmallVector AffineParallelOp::getLoopRegions() { return {&getRegion()}; } -unsigned AffineParallelOp::getNumDims() { return getStep().size(); } +unsigned AffineParallelOp::getNumDims() { return getSteps().size(); } AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() { return getOperands().take_front(getLowerBoundsMap().getNumInputs()); @@ -3843,7 +3843,7 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) { } void AffineParallelOp::setSteps(ArrayRef newSteps) { - setStepAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); + setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); } // check whether resultType match op or not in affine.parallel @@ -3893,14 +3893,14 @@ LogicalResult AffineParallelOp::verify() { auto numDims = getNumDims(); if (getLowerBoundsGroups().getNumElements() != numDims || getUpperBoundsGroups().getNumElements() != numDims || - getStep().size() != numDims || getBody()->getNumArguments() != numDims) { + getSteps().size() != numDims || getBody()->getNumArguments() != numDims) { return emitOpError() << "the number of region arguments (" << getBody()->getNumArguments() << ") and the number of map groups for lower (" << getLowerBoundsGroups().getNumElements() << ") and upper bound (" << getUpperBoundsGroups().getNumElements() - << "), and the number of steps (" << getStep().size() + << "), and the number of steps (" << getSteps().size() << ") must all match"; } @@ -4018,7 +4018,7 @@ void AffineParallelOp::print(OpAsmPrinter &p) { printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(), getUpperBoundsOperands(), "min"); p << ')'; - SmallVector steps = getStep(); + SmallVector steps = getSteps(); bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; }); if (!elideSteps) { p << " step ("; diff 
--git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index a652ee4a488d1..f46381403bc52 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -494,7 +494,7 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) { return; AffineMap lbMap = op.getLowerBoundsMap(); - SmallVector steps = op.getStep(); + SmallVector steps = op.getSteps(); // No need to do any work if the parallel op is already normalized. bool isAlreadyNormalized = llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index a930f8c71454c..e921177f73215 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -378,17 +378,17 @@ LogicalResult ForOp::verifyRegions() { return success(); } -SmallVector ForOp::getInductionVars() { return {getInductionVar()}; } +SmallVector ForOp::getLoopInductionVars() { return {getInductionVar()}; } -std::optional> ForOp::getLowerBounds() { +std::optional> ForOp::getLoopLowerBounds() { return SmallVector{OpFoldResult(getLowerBound())}; } -std::optional> ForOp::getSteps() { +std::optional> ForOp::getLoopSteps() { return SmallVector{OpFoldResult(getStep())}; } -std::optional> ForOp::getUpperBounds() { +std::optional> ForOp::getLoopUpperBounds() { return SmallVector{OpFoldResult(getUpperBound())}; } @@ -1426,24 +1426,24 @@ SmallVector ForallOp::getCombiningOps(BlockArgument bbArg) { return storeOps; } -SmallVector ForallOp::getInductionVars() { +SmallVector ForallOp::getLoopInductionVars() { return SmallVector{getBody()->getArguments().take_front(getRank())}; } // Get lower bounds as OpFoldResult. -std::optional> ForallOp::getLowerBounds() { +std::optional> ForallOp::getLoopLowerBounds() { Builder b(getOperation()->getContext()); return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b); } // Get upper bounds as OpFoldResult. 
-std::optional> ForallOp::getUpperBounds() { +std::optional> ForallOp::getLoopUpperBounds() { Builder b(getOperation()->getContext()); return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b); } // Get steps as OpFoldResult. -std::optional> ForallOp::getSteps() { +std::optional> ForallOp::getLoopSteps() { Builder b(getOperation()->getContext()); return getMixedValues(getStaticStep(), getDynamicStep(), b); } @@ -3004,19 +3004,19 @@ void ParallelOp::print(OpAsmPrinter &p) { SmallVector ParallelOp::getLoopRegions() { return {&getRegion()}; } -SmallVector ParallelOp::getInductionVars() { +SmallVector ParallelOp::getLoopInductionVars() { return SmallVector{getBody()->getArguments()}; } -std::optional> ParallelOp::getLowerBounds() { +std::optional> ParallelOp::getLoopLowerBounds() { return getLowerBound(); } -std::optional> ParallelOp::getUpperBounds() { +std::optional> ParallelOp::getLoopUpperBounds() { return getUpperBound(); } -std::optional> ParallelOp::getSteps() { +std::optional> ParallelOp::getLoopSteps() { return getStep(); } diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp index 07504a99fecd3..75cd2bfb01de0 100644 --- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -40,17 +40,18 @@ class SCFLoopLikeTest : public ::testing::Test { EXPECT_TRUE(maybeSingleIndVar.has_value()); std::optional> maybeLb = - loopLikeOp.getLowerBounds(); + loopLikeOp.getLoopLowerBounds(); EXPECT_TRUE(maybeLb.has_value()); EXPECT_EQ((*maybeLb).size(), 1u); std::optional> maybeUb = - loopLikeOp.getUpperBounds(); + loopLikeOp.getLoopUpperBounds(); EXPECT_TRUE(maybeUb.has_value()); EXPECT_EQ((*maybeUb).size(), 1u); - std::optional> maybeStep = loopLikeOp.getSteps(); + std::optional> maybeStep = + loopLikeOp.getLoopSteps(); EXPECT_TRUE(maybeStep.has_value()); EXPECT_EQ((*maybeStep).size(), 1u); - EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u); + 
EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 1u); } void checkMultidimensional(LoopLikeOpInterface loopLikeOp) { @@ -67,17 +68,18 @@ class SCFLoopLikeTest : public ::testing::Test { EXPECT_FALSE(maybeSingleIndVar.has_value()); std::optional> maybeLb = - loopLikeOp.getLowerBounds(); + loopLikeOp.getLoopLowerBounds(); EXPECT_TRUE(maybeLb.has_value()); EXPECT_EQ((*maybeLb).size(), 2u); std::optional> maybeUb = - loopLikeOp.getUpperBounds(); + loopLikeOp.getLoopUpperBounds(); EXPECT_TRUE(maybeUb.has_value()); EXPECT_EQ((*maybeUb).size(), 2u); - std::optional> maybeStep = loopLikeOp.getSteps(); + std::optional> maybeStep = + loopLikeOp.getLoopSteps(); EXPECT_TRUE(maybeStep.has_value()); EXPECT_EQ((*maybeStep).size(), 2u); - EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u); + EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 2u); } MLIRContext context; From e0e526210a5ab6ce28fbc5fa5ee24f79cb1ee9a8 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 6 Jun 2024 16:04:47 -0500 Subject: [PATCH 10/33] return option induction vars --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 10 +++++++--- mlir/include/mlir/Interfaces/LoopLikeInterface.td | 8 ++++---- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 4 ++-- mlir/lib/Dialect/SCF/IR/SCF.cpp | 8 +++++--- mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp | 10 ++++++++-- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 3704b15972278..d425c1c2a47b4 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -511,7 +511,9 @@ def ForallOp : SCF_Op<"forall", [ let extraClassDeclaration = [{ SmallVector getInductionVars() { - return getLoopInductionVars(); + auto maybeInductionVars = getLoopInductionVars();; + assert(maybeInductionVars.has_value() && "expected values"); + return *maybeInductionVars; } // Get lower bounds as OpFoldResult. 
SmallVector getMixedLowerBound() { @@ -591,7 +593,7 @@ def ForallOp : SCF_Op<"forall", [ } ::mlir::Value getInductionVar(int64_t idx) { - return getLoopInductionVars()[idx]; + return getInductionVars()[idx]; } ::mlir::Block::BlockArgListType getRegionOutArgs() { @@ -850,7 +852,9 @@ def ParallelOp : SCF_Op<"parallel", let extraClassDeclaration = [{ // Get induction variables. SmallVector getInductionVars() { - return getLoopInductionVars(); + auto maybeInductionVars = getLoopInductionVars();; + assert(maybeInductionVars.has_value() && "expected values"); + return *maybeInductionVars; } unsigned getNumLoops() { return getStep().size(); } unsigned getNumReductions() { return getInitVals().size(); } diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index 5312ace4db68e..2e6aabda30b07 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -95,12 +95,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { InterfaceMethod<[{ Return all induction variables. }], - /*retTy=*/"::llvm::SmallVector<::mlir::Value>", + /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::Value>>", /*methodName=*/"getLoopInductionVars", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return {}; + return std::nullopt; }] >, InterfaceMethod<[{ @@ -235,8 +235,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { /// std::nullopt. 
::std::optional<::mlir::Value> getSingleInductionVar() { auto inductionVars = this->getLoopInductionVars(); - if (inductionVars.size() == 1) - return inductionVars[0]; + if (inductionVars.has_value() && (*inductionVars).size() == 1) + return (*inductionVars)[0]; return std::nullopt; } /// Return the single lower bound value or attribute if it exists, otherwise diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index d5cb04743dfb9..0a58d2fdb02f5 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2454,8 +2454,8 @@ bool AffineForOp::matchingBoundOperandList() { SmallVector AffineForOp::getLoopRegions() { return {&getRegion()}; } -SmallVector AffineForOp::getLoopInductionVars() { - return {getInductionVar()}; +std::optional> AffineForOp::getLoopInductionVars() { + return SmallVector{getInductionVar()}; } std::optional> AffineForOp::getLoopLowerBounds() { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index e921177f73215..c00579443ea29 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -378,7 +378,9 @@ LogicalResult ForOp::verifyRegions() { return success(); } -SmallVector ForOp::getLoopInductionVars() { return {getInductionVar()}; } +std::optional> ForOp::getLoopInductionVars() { + return SmallVector{getInductionVar()}; +} std::optional> ForOp::getLoopLowerBounds() { return SmallVector{OpFoldResult(getLowerBound())}; @@ -1426,7 +1428,7 @@ SmallVector ForallOp::getCombiningOps(BlockArgument bbArg) { return storeOps; } -SmallVector ForallOp::getLoopInductionVars() { +std::optional> ForallOp::getLoopInductionVars() { return SmallVector{getBody()->getArguments().take_front(getRank())}; } @@ -3004,7 +3006,7 @@ void ParallelOp::print(OpAsmPrinter &p) { SmallVector ParallelOp::getLoopRegions() { return {&getRegion()}; } -SmallVector ParallelOp::getLoopInductionVars() { +std::optional> 
ParallelOp::getLoopInductionVars() { return SmallVector{getBody()->getArguments()}; } diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp index 75cd2bfb01de0..20dbc8d362d27 100644 --- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -51,7 +51,10 @@ class SCFLoopLikeTest : public ::testing::Test { loopLikeOp.getLoopSteps(); EXPECT_TRUE(maybeStep.has_value()); EXPECT_EQ((*maybeStep).size(), 1u); - EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 1u); + std::optional> maybeInductionVars = + loopLikeOp.getLoopInductionVars(); + EXPECT_TRUE(maybeInductionVars.has_value()); + EXPECT_EQ((*maybeInductionVars).size(), 1u); } void checkMultidimensional(LoopLikeOpInterface loopLikeOp) { @@ -79,7 +82,10 @@ class SCFLoopLikeTest : public ::testing::Test { loopLikeOp.getLoopSteps(); EXPECT_TRUE(maybeStep.has_value()); EXPECT_EQ((*maybeStep).size(), 2u); - EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 2u); + std::optional> maybeInductionVars = + loopLikeOp.getLoopInductionVars(); + EXPECT_TRUE(maybeInductionVars.has_value()); + EXPECT_EQ((*maybeInductionVars).size(), 2u); } MLIRContext context; From 7115a6e08bba43fe9750f8cef5c73f6be1b373fd Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 7 Jun 2024 11:45:44 -0500 Subject: [PATCH 11/33] address review comments --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 19 ++++++------ .../mlir/Interfaces/LoopLikeInterface.td | 30 +++++++++++++------ mlir/lib/Dialect/SCF/IR/SCF.cpp | 8 ++--- .../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 16 +++++----- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index d425c1c2a47b4..f35ea962bea16 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -510,28 +510,29 @@ def ForallOp : SCF_Op<"forall", [ ]; let 
extraClassDeclaration = [{ + /// Get induction variables. SmallVector getInductionVars() { - auto maybeInductionVars = getLoopInductionVars();; + std::optional> maybeInductionVars = getLoopInductionVars(); assert(maybeInductionVars.has_value() && "expected values"); return *maybeInductionVars; } - // Get lower bounds as OpFoldResult. + /// Get lower bounds as OpFoldResult. SmallVector getMixedLowerBound() { - auto maybeLowerBounds = getLoopLowerBounds(); + std::optional> maybeLowerBounds = getLoopLowerBounds(); assert(maybeLowerBounds.has_value() && "expected values"); return *maybeLowerBounds; } - // Get upper bounds as OpFoldResult. + /// Get upper bounds as OpFoldResult. SmallVector getMixedUpperBound() { - auto maybeUpperBounds = getLoopUpperBounds(); + std::optional> maybeUpperBounds = getLoopUpperBounds(); assert(maybeUpperBounds.has_value() && "expected values"); return *maybeUpperBounds; } - // Get steps as OpFoldResult. + /// Get steps as OpFoldResult. SmallVector getMixedStep() { - auto maybeSteps = getLoopSteps(); + std::optional> maybeSteps = getLoopSteps(); assert(maybeSteps.has_value() && "expected values"); return *maybeSteps; } @@ -850,9 +851,9 @@ def ParallelOp : SCF_Op<"parallel", ]; let extraClassDeclaration = [{ - // Get induction variables. + /// Get induction variables. SmallVector getInductionVars() { - auto maybeInductionVars = getLoopInductionVars();; + std::optional> maybeInductionVars = getLoopInductionVars();; assert(maybeInductionVars.has_value() && "expected values"); return *maybeInductionVars; } diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index 2e6aabda30b07..b748d5e29114a 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -93,47 +93,59 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { }] >, InterfaceMethod<[{ - Return all induction variables. 
+ Return all induction variables, if they exist. If the op has no notion of + induction variable, then return std::nullopt. If it does have + a notion but an instance doesn't have induction variables, then + return empty vector. }], /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::Value>>", /*methodName=*/"getLoopInductionVars", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return std::nullopt; + return ::std::nullopt; }] >, InterfaceMethod<[{ - Return all lower bounds. + Return all lower bounds, if they exist. If the op has no notion of + lower bounds, then return std::nullopt. If it does have + a notion but an instance doesn't have lower bounds, then + return empty vector. }], /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", /*methodName=*/"getLoopLowerBounds", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return std::nullopt; + return ::std::nullopt; }] >, InterfaceMethod<[{ - Return all steps. + Return all steps, if they exist. If the op has no notion of + steps, then return std::nullopt. If it does have + a notion but an instance doesn't have steps, then + return empty vector. }], - /*retTy=*/"std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", + /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", /*methodName=*/"getLoopSteps", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return std::nullopt; + return ::std::nullopt; }] >, InterfaceMethod<[{ - Return all upper bounds. + Return all upper bounds, if they exist. If the op has no notion of + upper bounds, then return std::nullopt. If it does have + a notion but an instance doesn't have upper bounds, then + return empty vector.
}], /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>", /*methodName=*/"getLoopUpperBounds", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return std::nullopt; + return ::std::nullopt; }] >, InterfaceMethod<[{ diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index c00579443ea29..5e94f4dc612a7 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -379,19 +379,19 @@ LogicalResult ForOp::verifyRegions() { } std::optional> ForOp::getLoopInductionVars() { - return SmallVector{getInductionVar()}; + return SmallVector{getInductionVar()}; } std::optional> ForOp::getLoopLowerBounds() { - return SmallVector{OpFoldResult(getLowerBound())}; + return SmallVector{OpFoldResult(getLowerBound())}; } std::optional> ForOp::getLoopSteps() { - return SmallVector{OpFoldResult(getStep())}; + return SmallVector{OpFoldResult(getStep())}; } std::optional> ForOp::getLoopUpperBounds() { - return SmallVector{OpFoldResult(getUpperBound())}; + return SmallVector{OpFoldResult(getUpperBound())}; } std::optional ForOp::getLoopResults() { return getResults(); } diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp index 20dbc8d362d27..53a4af14d119a 100644 --- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -41,19 +41,19 @@ class SCFLoopLikeTest : public ::testing::Test { std::optional> maybeLb = loopLikeOp.getLoopLowerBounds(); - EXPECT_TRUE(maybeLb.has_value()); + ASSERT_TRUE(maybeLb.has_value()); EXPECT_EQ((*maybeLb).size(), 1u); std::optional> maybeUb = loopLikeOp.getLoopUpperBounds(); - EXPECT_TRUE(maybeUb.has_value()); + ASSERT_TRUE(maybeUb.has_value()); EXPECT_EQ((*maybeUb).size(), 1u); std::optional> maybeStep = loopLikeOp.getLoopSteps(); - EXPECT_TRUE(maybeStep.has_value()); + ASSERT_TRUE(maybeStep.has_value()); EXPECT_EQ((*maybeStep).size(), 1u); std::optional> 
maybeInductionVars = loopLikeOp.getLoopInductionVars(); - EXPECT_TRUE(maybeInductionVars.has_value()); + ASSERT_TRUE(maybeInductionVars.has_value()); EXPECT_EQ((*maybeInductionVars).size(), 1u); } @@ -72,19 +72,19 @@ class SCFLoopLikeTest : public ::testing::Test { std::optional> maybeLb = loopLikeOp.getLoopLowerBounds(); - EXPECT_TRUE(maybeLb.has_value()); + ASSERT_TRUE(maybeLb.has_value()); EXPECT_EQ((*maybeLb).size(), 2u); std::optional> maybeUb = loopLikeOp.getLoopUpperBounds(); - EXPECT_TRUE(maybeUb.has_value()); + ASSERT_TRUE(maybeUb.has_value()); EXPECT_EQ((*maybeUb).size(), 2u); std::optional> maybeStep = loopLikeOp.getLoopSteps(); - EXPECT_TRUE(maybeStep.has_value()); + ASSERT_TRUE(maybeStep.has_value()); EXPECT_EQ((*maybeStep).size(), 2u); std::optional> maybeInductionVars = loopLikeOp.getLoopInductionVars(); - EXPECT_TRUE(maybeInductionVars.has_value()); + ASSERT_TRUE(maybeInductionVars.has_value()); EXPECT_EQ((*maybeInductionVars).size(), 2u); } From 6336fdf28f06c7525bcf6822a386f8c4cabe3c2d Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 7 Jun 2024 15:25:40 -0500 Subject: [PATCH 12/33] update after rebase --- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index ce20730459c2a..e3660e89fb684 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1198,9 +1198,9 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) { bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, LoopLikeOpInterface &source) { auto iterSpaceEq = - target.getMixedLowerBound() == source.getMixedLowerBound() && - target.getMixedUpperBound() == source.getMixedUpperBound() && - target.getMixedStep() == source.getMixedStep(); + target.getLoopLowerBounds() == source.getLoopLowerBounds() && + target.getLoopUpperBounds() == source.getLoopUpperBounds() && + target.getLoopSteps() == 
source.getLoopSteps(); auto forAllTarget = dyn_cast(*target); auto forAllSource = dyn_cast(*source); if (forAllTarget && forAllSource) From 86406c335dd216ac91a941102c060bc680af10b1 Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 8 Jun 2024 23:31:57 -0500 Subject: [PATCH 13/33] refactor main parallel fusion logic from fuseIfLegal to util func --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 6 - .../SCF/Transforms/ParallelLoopFusion.cpp | 159 +++++++++++++++++- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 140 ++++++--------- 3 files changed, 208 insertions(+), 97 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index ab9d154aa480d..ac4434b337890 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -163,12 +163,6 @@ void getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, bool checkFusionStructuralLegality(LoopLikeOpInterface &target, LoopLikeOpInterface &source); -/// Prepends operations of firstPloop's body into secondPloop's body. -/// Updates secondPloop with new loop. -void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop, - OpBuilder builder, - llvm::function_ref mayAlias); - /// Given two scf.forall loops, `target` and `source`, fuses `target` into /// `source`. Assumes that the given loops are siblings and are independent of /// each other. diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index abac91cfaf7d9..326a8f93162b9 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -31,6 +31,163 @@ namespace mlir { using namespace mlir; using namespace mlir::scf; +/// Verify there are no nested ParallelOps. 
+static bool hasNestedParallelOp(ParallelOp ploop) { + auto walkResult = + ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); + return walkResult.wasInterrupted(); +} + +/// Verify equal iteration spaces. +static bool equalIterationSpaces(ParallelOp firstPloop, + ParallelOp secondPloop) { + if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) + return false; + + auto matchOperands = [&](const OperandRange &lhs, + const OperandRange &rhs) -> bool { + // TODO: Extend this to support aliases and equal constants. + return std::equal(lhs.begin(), lhs.end(), rhs.begin()); + }; + return matchOperands(firstPloop.getLowerBound(), + secondPloop.getLowerBound()) && + matchOperands(firstPloop.getUpperBound(), + secondPloop.getUpperBound()) && + matchOperands(firstPloop.getStep(), secondPloop.getStep()); +} + +/// Checks if the parallel loops have mixed access to the same buffers. Returns +/// `true` if the first parallel loop writes to the same indices that the second +/// loop reads. +static bool haveNoReadsAfterWriteExceptSameIndex( + ParallelOp firstPloop, ParallelOp secondPloop, + const IRMapping &firstToSecondPloopIndices, + llvm::function_ref mayAlias) { + DenseMap> bufferStores; + SmallVector bufferStoresVec; + firstPloop.getBody()->walk([&](memref::StoreOp store) { + bufferStores[store.getMemRef()].push_back(store.getIndices()); + bufferStoresVec.emplace_back(store.getMemRef()); + }); + auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { + Value loadMem = load.getMemRef(); + // Stop if the memref is defined in secondPloop body. Careful alias analysis + // is needed. 
+ auto *memrefDef = loadMem.getDefiningOp(); + if (memrefDef && memrefDef->getBlock() == load->getBlock()) + return WalkResult::interrupt(); + + for (Value store : bufferStoresVec) + if (store != loadMem && mayAlias(store, loadMem)) + return WalkResult::interrupt(); + + auto write = bufferStores.find(loadMem); + if (write == bufferStores.end()) + return WalkResult::advance(); + + // Check that at least one store was retrieved + if (!write->second.size()) + return WalkResult::interrupt(); + + auto storeIndices = write->second.front(); + + // Multiple writes to the same memref are allowed only on the same indices + for (const auto &othStoreIndices : write->second) { + if (othStoreIndices != storeIndices) + return WalkResult::interrupt(); + } + + // Check that the load indices of secondPloop coincide with store indices of + // firstPloop for the same memrefs. + auto loadIndices = load.getIndices(); + if (storeIndices.size() != loadIndices.size()) + return WalkResult::interrupt(); + for (int i = 0, e = storeIndices.size(); i < e; ++i) { + if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != + loadIndices[i]) { + auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); + auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); + if (storeIndexDefOp && loadIndexDefOp) { + if (!isMemoryEffectFree(storeIndexDefOp)) + return WalkResult::interrupt(); + if (!isMemoryEffectFree(loadIndexDefOp)) + return WalkResult::interrupt(); + if (!OperationEquivalence::isEquivalentTo( + storeIndexDefOp, loadIndexDefOp, + [&](Value storeIndex, Value loadIndex) { + if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != + firstToSecondPloopIndices.lookupOrDefault(loadIndex)) + return failure(); + else + return success(); + }, + /*markEquivalent=*/nullptr, + OperationEquivalence::Flags::IgnoreLocations)) { + return WalkResult::interrupt(); + } + } else + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + return !walkResult.wasInterrupted(); +} +
+/// Analyzes dependencies in the most primitive way by checking simple read and +/// write patterns. +static LogicalResult +verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, + const IRMapping &firstToSecondPloopIndices, + llvm::function_ref mayAlias) { + if (!haveNoReadsAfterWriteExceptSameIndex( + firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) + return failure(); + + IRMapping secondToFirstPloopIndices; + secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), + firstPloop.getBody()->getArguments()); + return success(haveNoReadsAfterWriteExceptSameIndex( + secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); +} + +static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, + const IRMapping &firstToSecondPloopIndices, + llvm::function_ref mayAlias) { + return !hasNestedParallelOp(firstPloop) && + !hasNestedParallelOp(secondPloop) && + equalIterationSpaces(firstPloop, secondPloop) && + succeeded(verifyDependencies(firstPloop, secondPloop, + firstToSecondPloopIndices, mayAlias)); +} + +/// Prepends operations of firstPloop's body into secondPloop's body. +/// Updates secondPloop with new loop. +static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, + OpBuilder builder, + llvm::function_ref mayAlias) { + Block *block1 = firstPloop.getBody(); + Block *block2 = secondPloop.getBody(); + IRMapping firstToSecondPloopIndices; + firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); + + if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, + mayAlias)) + return; + + DominanceInfo dom; + // We are fusing first loop into second, make sure there are no users of the + // first loop results between loops. 
+ for (Operation *user : firstPloop->getUsers()) + if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) + return; + + IRRewriter rewriter(builder); + secondPloop = mlir::fuseIndependentSiblingParallelLoops( + firstPloop, secondPloop, rewriter); + ; +} + void mlir::scf::naivelyFuseParallelOps( Region ®ion, llvm::function_ref mayAlias) { OpBuilder b(region); @@ -59,7 +216,7 @@ void mlir::scf::naivelyFuseParallelOps( } for (MutableArrayRef ploops : ploopChains) { for (int i = 0, e = ploops.size(); i + 1 < e; ++i) - mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); + fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); } } } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e3660e89fb684..5f58767be409d 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1188,13 +1188,6 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop, // Fusion related helpers //===----------------------------------------------------------------------===// -/// Verify there are no nested ParallelOps. 
-static bool hasNestedParallelOp(scf::ParallelOp ploop) { - auto walkResult = ploop.getBody()->walk( - [](scf::ParallelOp) { return WalkResult::interrupt(); }); - return walkResult.wasInterrupted(); -} - bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, LoopLikeOpInterface &source) { auto iterSpaceEq = @@ -1209,86 +1202,6 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, return iterSpaceEq; } -static bool isFusionLegal(scf::ParallelOp firstPloop, - scf::ParallelOp secondPloop, - const IRMapping &firstToSecondPloopIndices, - llvm::function_ref mayAlias) { - return !hasNestedParallelOp(firstPloop) && - !hasNestedParallelOp(secondPloop) && - equalIterationSpaces(firstPloop, secondPloop) && - succeeded(verifyDependencies(firstPloop, secondPloop, - firstToSecondPloopIndices, mayAlias)); -} - -void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop, - OpBuilder builder, - llvm::function_ref mayAlias) { - Block *block1 = firstPloop.getBody(); - Block *block2 = secondPloop.getBody(); - IRMapping firstToSecondPloopIndices; - firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); - - if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, - mayAlias)) - return; - - DominanceInfo dom; - // We are fusing first loop into second, make sure there are no users of the - // first loop results between loops. 
- for (Operation *user : firstPloop->getUsers()) - if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) - return; - - ValueRange inits1 = firstPloop.getInitVals(); - ValueRange inits2 = secondPloop.getInitVals(); - - SmallVector newInitVars(inits1.begin(), inits1.end()); - newInitVars.append(inits2.begin(), inits2.end()); - - IRRewriter b(builder); - b.setInsertionPoint(secondPloop); - auto newSecondPloop = b.create( - secondPloop.getLoc(), secondPloop.getLowerBound(), - secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); - - Block *newBlock = newSecondPloop.getBody(); - auto term1 = cast(block1->getTerminator()); - auto term2 = cast(block2->getTerminator()); - - b.inlineBlockBefore(block2, newBlock, newBlock->begin(), - newBlock->getArguments()); - b.inlineBlockBefore(block1, newBlock, newBlock->begin(), - newBlock->getArguments()); - - ValueRange results = newSecondPloop.getResults(); - if (!results.empty()) { - b.setInsertionPointToEnd(newBlock); - - ValueRange reduceArgs1 = term1.getOperands(); - ValueRange reduceArgs2 = term2.getOperands(); - SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); - newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - - auto newReduceOp = b.create(term2.getLoc(), newReduceArgs); - - for (auto &&[i, reg] : llvm::enumerate(llvm::concat( - term1.getReductions(), term2.getReductions()))) { - Block &oldRedBlock = reg.front(); - Block &newRedBlock = newReduceOp.getReductions()[i].front(); - b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), - newRedBlock.getArguments()); - } - - firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); - secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); - } - term1->erase(); - term2->erase(); - firstPloop.erase(); - secondPloop.erase(); - secondPloop = newSecondPloop; -} - scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { @@ 
-1393,7 +1306,54 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) { - auto mayAlias = [&](Value val1, Value val2) -> bool { return false; }; - mlir::fuseIfLegal(target, source, rewriter, mayAlias); - return source; + Block *block1 = target.getBody(); + Block *block2 = source.getBody(); + auto term1 = cast(block1->getTerminator()); + auto term2 = cast(block2->getTerminator()); + + ValueRange inits1 = target.getInitVals(); + ValueRange inits2 = source.getInitVals(); + + SmallVector newInitVars(inits1.begin(), inits1.end()); + newInitVars.append(inits2.begin(), inits2.end()); + + rewriter.setInsertionPoint(source); + auto fusedLoop = rewriter.create( + source.getLoc(), source.getLowerBound(), source.getUpperBound(), + source.getStep(), newInitVars); + Block *newBlock = fusedLoop.getBody(); + rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(), + newBlock->getArguments()); + rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(), + newBlock->getArguments()); + + ValueRange results = fusedLoop.getResults(); + if (!results.empty()) { + rewriter.setInsertionPointToEnd(newBlock); + + ValueRange reduceArgs1 = term1.getOperands(); + ValueRange reduceArgs2 = term2.getOperands(); + SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); + newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); + + auto newReduceOp = + rewriter.create(term2.getLoc(), newReduceArgs); + + for (auto &&[i, reg] : llvm::enumerate(llvm::concat( + term1.getReductions(), term2.getReductions()))) { + Block &oldRedBlock = reg.front(); + Block &newRedBlock = newReduceOp.getReductions()[i].front(); + rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock, + newRedBlock.begin(), + newRedBlock.getArguments()); + } + target.replaceAllUsesWith(results.take_front(inits1.size())); + 
source.replaceAllUsesWith(results.take_back(inits2.size())); + } + term1->erase(); + term2->erase(); + target.erase(); + source.erase(); + + return fusedLoop; } From 694d589dc535892f3dda9d27c2a43052fc0b445e Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 8 Jun 2024 23:35:30 -0500 Subject: [PATCH 14/33] remove unused functions --- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 113 --------------------------- 1 file changed, 113 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 5f58767be409d..e6cb88c427da8 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1071,119 +1071,6 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, return tileLoops; } -/// Checks if the parallel loops have mixed access to the same buffers. Returns -/// `true` if the first parallel loop writes to the same indices that the second -/// loop reads. -static bool haveNoReadsAfterWriteExceptSameIndex( - scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, - const IRMapping &firstToSecondPloopIndices, - llvm::function_ref mayAlias) { - DenseMap> bufferStores; - SmallVector bufferStoresVec; - firstPloop.getBody()->walk([&](memref::StoreOp store) { - bufferStores[store.getMemRef()].push_back(store.getIndices()); - bufferStoresVec.emplace_back(store.getMemRef()); - }); - auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { - Value loadMem = load.getMemRef(); - // Stop if the memref is defined in secondPloop body. Careful alias analysis - // is needed. 
- auto *memrefDef = loadMem.getDefiningOp(); - if (memrefDef && memrefDef->getBlock() == load->getBlock()) - return WalkResult::interrupt(); - - for (Value store : bufferStoresVec) - if (store != loadMem && mayAlias(store, loadMem)) - return WalkResult::interrupt(); - - auto write = bufferStores.find(loadMem); - if (write == bufferStores.end()) - return WalkResult::advance(); - - // Check that at last one store was retrieved - if (!write->second.size()) - return WalkResult::interrupt(); - - auto storeIndices = write->second.front(); - - // Multiple writes to the same memref are allowed only on the same indices - for (const auto &othStoreIndices : write->second) { - if (othStoreIndices != storeIndices) - return WalkResult::interrupt(); - } - - // Check that the load indices of secondPloop coincide with store indices of - // firstPloop for the same memrefs. - auto loadIndices = load.getIndices(); - if (storeIndices.size() != loadIndices.size()) - return WalkResult::interrupt(); - for (int i = 0, e = storeIndices.size(); i < e; ++i) { - if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != - loadIndices[i]) { - auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); - auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); - if (storeIndexDefOp && loadIndexDefOp) { - if (!isMemoryEffectFree(storeIndexDefOp)) - return WalkResult::interrupt(); - if (!isMemoryEffectFree(loadIndexDefOp)) - return WalkResult::interrupt(); - if (!OperationEquivalence::isEquivalentTo( - storeIndexDefOp, loadIndexDefOp, - [&](Value storeIndex, Value loadIndex) { - if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) != - firstToSecondPloopIndices.lookupOrDefault(loadIndex)) - return failure(); - else - return success(); - }, - /*markEquivalent=*/nullptr, - OperationEquivalence::Flags::IgnoreLocations)) { - return WalkResult::interrupt(); - } - } else - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); - return !walkResult.wasInterrupted(); -} - 
-/// Analyzes dependencies in the most primitive way by checking simple read and -/// write patterns. -static LogicalResult -verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, - const IRMapping &firstToSecondPloopIndices, - llvm::function_ref mayAlias) { - if (!haveNoReadsAfterWriteExceptSameIndex( - firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) - return failure(); - - IRMapping secondToFirstPloopIndices; - secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), - firstPloop.getBody()->getArguments()); - return success(haveNoReadsAfterWriteExceptSameIndex( - secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); -} - -/// Verify equal iteration spaces. -static bool equalIterationSpaces(scf::ParallelOp firstPloop, - scf::ParallelOp secondPloop) { - if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) - return false; - - auto matchOperands = [&](const OperandRange &lhs, - const OperandRange &rhs) -> bool { - // TODO: Extend this to support aliases and equal constants. 
- return std::equal(lhs.begin(), lhs.end(), rhs.begin()); - }; - return matchOperands(firstPloop.getLowerBound(), - secondPloop.getLowerBound()) && - matchOperands(firstPloop.getUpperBound(), - secondPloop.getUpperBound()) && - matchOperands(firstPloop.getStep(), secondPloop.getStep()); -} - //===----------------------------------------------------------------------===// // Fusion related helpers //===----------------------------------------------------------------------===// From 67cb64f1a773795bfb2d4e9f0c981dd502572676 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 9 Jun 2024 14:45:57 -0500 Subject: [PATCH 15/33] refactor fuseIndependentSiblingForLoops to reuse replaceWithAdditionalYields --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 4 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 54 ++++++++ mlir/lib/Dialect/SCF/Utils/Utils.cpp | 151 +++++++++++++++------ 3 files changed, 168 insertions(+), 41 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index f35ea962bea16..e7b9665f797fa 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [ AttrSizedOperandSegments, AutomaticAllocationScope, DeclareOpInterfaceMethods, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 5e94f4dc612a7..6ad181e2f3d77 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -616,8 +616,62 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, regions.push_back(RegionSuccessor(getResults())); } +std::optional ForallOp::getLoopResults() { return getResults(); } + SmallVector ForallOp::getLoopRegions() { return {&getRegion()}; } +FailureOr ForallOp::replaceWithAdditionalYields( + RewriterBase &rewriter, ValueRange newInitOperands, + bool replaceInitOperandUsesInLoop, + const NewYieldValuesFn 
&newYieldValuesFn) { + // Create a new loop before the existing one, with the extra operands. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(getOperation()); + auto inits = llvm::to_vector(getOutputs()); + inits.append(newInitOperands.begin(), newInitOperands.end()); + scf::ForallOp newLoop = rewriter.create( + getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), + inits, getMapping()); + + // Generate the new yield values and append them to the scf.yield operation. + auto yieldOp = cast(getTerminator()); + ArrayRef newIterArgs = + newLoop.getBody()->getArguments().take_back(newInitOperands.size()); + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(yieldOp); + SmallVector newYieldedValues = + newYieldValuesFn(rewriter, getLoc(), newIterArgs); + assert(newInitOperands.size() == newYieldedValues.size() && + "expected as many new yield values as new iter operands"); + // rewriter.modifyOpInPlace(yieldOp, [&]() { + // yieldOp.getResultsMutable().append(newYieldedValues); + // }); + } + + // Move the loop body to the new op. + rewriter.mergeBlocks(getBody(), newLoop.getBody(), + newLoop.getBody()->getArguments().take_front( + getBody()->getNumArguments())); + + if (replaceInitOperandUsesInLoop) { + // Replace all uses of `newInitOperands` with the corresponding basic block + // arguments. + for (auto it : llvm::zip(newInitOperands, newIterArgs)) { + rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it), + [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } + } + + // Replace the old loop. + rewriter.replaceOp(getOperation(), + newLoop->getResults().take_front(getNumResults())); + return cast(newLoop.getOperation()); +} + /// Promotes the loop body of a forallOp to its containing block if it can be /// determined that the loop has a single iteration. 
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e6cb88c427da8..a61428208c405 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1089,6 +1089,92 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, return iterSpaceEq; } +template +void fuseTerminator(RewriterBase &rewriter, LoopTy target, LoopTy source, + LoopTy &fused, IRMapping &mapping) {} + +template <> +void fuseTerminator(RewriterBase &rewriter, scf::ForallOp target, + scf::ForallOp source, scf::ForallOp &fused, + IRMapping &mapping) { + // Fuse the old terminator in_parallel ops into the new one. + scf::InParallelOp targetTerm = target.getTerminator(); + scf::InParallelOp sourceTerm = source.getTerminator(); + scf::InParallelOp fusedTerm = fused.getTerminator(); + rewriter.setInsertionPointToStart(fusedTerm.getBody()); + for (Operation &op : targetTerm.getYieldingOps()) + rewriter.clone(op, mapping); + for (Operation &op : sourceTerm.getYieldingOps()) + rewriter.clone(op, mapping); +} + +template <> +void fuseTerminator(RewriterBase &rewriter, scf::ForOp target, + scf::ForOp source, scf::ForOp &fused, IRMapping &mapping) { + // Build fused yield results by appropriately mapping original yield operands. 
+ SmallVector yieldResults; + for (Value operand : target.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + for (Value operand : source.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + if (!yieldResults.empty()) + rewriter.create(source.getLoc(), yieldResults); +} + +template +LoopTy createFused(LoopLikeOpInterface target, LoopLikeOpInterface source, + RewriterBase &rewriter) { + auto targetResults = target.getLoopResults(); + auto sourceResults = source.getLoopResults(); + int64_t numTargetOuts = (*targetResults).size(); + int64_t numSourceOuts = (*sourceResults).size(); + printf("numTargetOuts %ld\n", numTargetOuts); + + // Create fused shared_outs. + SmallVector fusedOuts; + llvm::append_range(fusedOuts, *targetResults); + llvm::append_range(fusedOuts, *sourceResults); + + // Create a new scf.forall op after the source loop. + rewriter.setInsertionPointAfter(source); + // LoopTy fusedLoop = builder.create( + // source.getLoc(), source.getLoopLowerBounds(), + // source.getLoopUpperBounds(), source.getLoopSteps(), fusedOuts, + // source->getAttrs()); + LoopTy fusedLoop = rewriter.cloneWithoutRegions(cast(source)); + + // Map control operands. + IRMapping mapping; + mapping.map(*target.getLoopInductionVars(), + *fusedLoop.getLoopInductionVars()); + mapping.map(*source.getLoopInductionVars(), + *fusedLoop.getLoopInductionVars()); + + // Map shared outs. + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); + + // Append everything except the terminator into the fused operation. 
+ rewriter.setInsertionPointToStart( + &fusedLoop.getLoopRegions().front()->front()); + for (Operation &op : + target.getLoopRegions().front()->front().without_terminator()) + rewriter.clone(op, mapping); + for (Operation &op : + source.getLoopRegions().front()->front().without_terminator()) + rewriter.clone(op, mapping); + + fuseTerminator(rewriter, cast(target), cast(source), + cast(fusedLoop), mapping); + + rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + + return fusedLoop; +} + scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { @@ -1144,50 +1230,37 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { - unsigned numTargetOuts = target.getNumResults(); - unsigned numSourceOuts = source.getNumResults(); - - // Create fused init_args, with target's init_args before source's init_args. - SmallVector fusedInitArgs; - llvm::append_range(fusedInitArgs, target.getInitArgs()); - llvm::append_range(fusedInitArgs, source.getInitArgs()); - - // Create a new scf.for op after the source loop (with scf.yield terminator - // (without arguments) only in case its init_args is empty). 
- rewriter.setInsertionPointAfter(source); - scf::ForOp fusedLoop = rewriter.create( - source.getLoc(), source.getLowerBound(), source.getUpperBound(), - source.getStep(), fusedInitArgs); - + auto targetIterArgs = target.getRegionIterArgs(); + auto targetInductionVar = target.getInductionVar(); + SmallVector targetYieldOperands(source.getYieldedValues()); + auto sourceIterArgs = source.getRegionIterArgs(); + auto sourceInductionVar = source.getInductionVar(); + SmallVector sourceYieldOperands(source.getYieldedValues()); + scf::ForOp fusedLoop = cast(*target.replaceWithAdditionalYields( + rewriter, source.getInitArgs(), /*replaceInitOperandUsesInLoop=*/false, + [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { + return sourceYieldOperands; + })); // Map original induction variables and operands to those of the fused loop. IRMapping mapping; - mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); - mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - + mapping.map(targetInductionVar, fusedLoop.getInductionVar()); + mapping.map(targetIterArgs, + fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); + mapping.map(targetYieldOperands, + fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); + mapping.map(sourceInductionVar, fusedLoop.getInductionVar()); + mapping.map(sourceIterArgs, + fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); + mapping.map(sourceYieldOperands, + fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); // Merge target's body into the new (fused) for loop and then source's body. 
- rewriter.setInsertionPointToStart(fusedLoop.getBody()); - for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, mapping); + rewriter.setInsertionPoint(fusedLoop.getBody()->getTerminator()); for (Operation &op : source.getBody()->without_terminator()) rewriter.clone(op, mapping); - - // Build fused yield results by appropriately mapping original yield operands. - SmallVector yieldResults; - for (Value operand : target.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - for (Value operand : source.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - if (!yieldResults.empty()) - rewriter.create(source.getLoc(), yieldResults); - - // Replace old loops by substituting their uses by results of the fused loop. - rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); - + auto newTerm = rewriter.clone(*fusedLoop.getBody()->getTerminator(), mapping); + rewriter.replaceOp(fusedLoop.getBody()->getTerminator(), newTerm); + rewriter.replaceOp(source, + fusedLoop.getResults().take_back(source.getNumResults())); return fusedLoop; } From cc8599f69a90f9b460bbff950505004c214ac72e Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 9 Jun 2024 17:20:46 -0500 Subject: [PATCH 16/33] refactor fuseIndependentSiblingForallLoops to reuse replaceWithAdditionalYields --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 12 +--- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 69 +++++++++---------- .../SCF/transform-loop-fuse-sibling.mlir | 3 +- 3 files changed, 36 insertions(+), 48 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 6ad181e2f3d77..6850d632f10d0 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -637,17 +637,7 @@ FailureOr ForallOp::replaceWithAdditionalYields( auto yieldOp = 
cast(getTerminator()); ArrayRef newIterArgs = newLoop.getBody()->getArguments().take_back(newInitOperands.size()); - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yieldOp); - SmallVector newYieldedValues = - newYieldValuesFn(rewriter, getLoc(), newIterArgs); - assert(newInitOperands.size() == newYieldedValues.size() && - "expected as many new yield values as new iter operands"); - // rewriter.modifyOpInPlace(yieldOp, [&]() { - // yieldOp.getResultsMutable().append(newYieldedValues); - // }); - } + newLoop.getTerminator().erase(); // Move the loop body to the new op. rewriter.mergeBlocks(getBody(), newLoop.getBody(), diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index a61428208c405..a822b4199fe9d 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1128,7 +1128,6 @@ LoopTy createFused(LoopLikeOpInterface target, LoopLikeOpInterface source, auto sourceResults = source.getLoopResults(); int64_t numTargetOuts = (*targetResults).size(); int64_t numSourceOuts = (*sourceResults).size(); - printf("numTargetOuts %ld\n", numTargetOuts); // Create fused shared_outs. SmallVector fusedOuts; @@ -1178,51 +1177,49 @@ LoopTy createFused(LoopLikeOpInterface target, LoopLikeOpInterface source, scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { - unsigned numTargetOuts = target.getNumResults(); - unsigned numSourceOuts = source.getNumResults(); - - // Create fused shared_outs. - SmallVector fusedOuts; - llvm::append_range(fusedOuts, target.getOutputs()); - llvm::append_range(fusedOuts, source.getOutputs()); - - // Create a new scf.forall op after the source loop. 
- rewriter.setInsertionPointAfter(source); - scf::ForallOp fusedLoop = rewriter.create( - source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), - source.getMixedStep(), fusedOuts, source.getMapping()); - + auto targetIterArgs = target.getRegionIterArgs(); + auto targetInductionVar = target.getInductionVars(); + SmallVector targetYieldOperands(target.getYieldedValues()); + auto sourceIterArgs = source.getRegionIterArgs(); + auto sourceInductionVar = source.getInductionVars(); + scf::InParallelOp sourceTerm = source.getTerminator(); + auto sourceYieldOps = sourceTerm.getYieldingOps(); + auto sourceBody = source.getBody(); + SmallVector sourceYieldOperands(llvm::map_range( + sourceTerm.getDests(), [](auto arg) { return cast(arg); })); + scf::ForallOp fusedLoop = + cast(*target.replaceWithAdditionalYields( + rewriter, source.getOutputs(), /*replaceInitOperandUsesInLoop=*/false, + [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { + for (Operation &op : sourceYieldOps) + b.clone(op); + return sourceYieldOperands; + })); // Map control operands. IRMapping mapping; - mapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); - mapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); - - // Map shared outs. 
- mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - + mapping.map(targetInductionVar, fusedLoop.getInductionVars()); + mapping.map(targetIterArgs, + fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); + mapping.map(targetYieldOperands, + fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); + mapping.map(sourceInductionVar, fusedLoop.getInductionVars()); + mapping.map(sourceIterArgs, + fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); + mapping.map(sourceYieldOperands, + fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); // Append everything except the terminator into the fused operation. - rewriter.setInsertionPointToStart(fusedLoop.getBody()); - for (Operation &op : target.getBody()->without_terminator()) - rewriter.clone(op, mapping); - for (Operation &op : source.getBody()->without_terminator()) + rewriter.setInsertionPoint(fusedLoop.getBody()->getTerminator()); + for (Operation &op : sourceBody->without_terminator()) rewriter.clone(op, mapping); // Fuse the old terminator in_parallel ops into the new one. - scf::InParallelOp targetTerm = target.getTerminator(); - scf::InParallelOp sourceTerm = source.getTerminator(); scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); - rewriter.setInsertionPointToStart(fusedTerm.getBody()); - for (Operation &op : targetTerm.getYieldingOps()) - rewriter.clone(op, mapping); + rewriter.setInsertionPointToEnd(fusedTerm.getBody()); for (Operation &op : sourceTerm.getYieldingOps()) rewriter.clone(op, mapping); - // Replace old loops by substituting their uses by results of the fused loop. 
- rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + rewriter.replaceOp(source, + fusedLoop.getResults().take_back(source.getNumResults())); return fusedLoop; } diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 46c6be36c3271..47bfe0baa7651 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -189,7 +189,8 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +// CHECK-LABEL: func.func @matmul_fuse_2nd_forall_into_1st +// CHECK-SAME: [[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { %zero = arith.constant 0.0 : f32 %out_alloc = tensor.empty() : tensor<128x128xf32> From 48b1af9cb4392b8ccad748e17ce40fa997db6a59 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 9 Jun 2024 19:12:40 -0500 Subject: [PATCH 17/33] wip --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 4 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 + mlir/lib/Dialect/SCF/Utils/Utils.cpp | 170 ++++++--------------- 3 files changed, 50 insertions(+), 128 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index e7b9665f797fa..b9345f6ecdbb2 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -302,8 +302,8 @@ def ForallOp : SCF_Op<"forall", [ AutomaticAllocationScope, DeclareOpInterfaceMethods, + "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps", "getYieldedValuesMutable", + "replaceWithAdditionalYields", "promoteIfSingleIteration", 
"yieldTiledValuesAndReplace"]>, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, DeclareOpInterfaceMethods, diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 6850d632f10d0..b4a16e519a15a 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1472,6 +1472,10 @@ SmallVector ForallOp::getCombiningOps(BlockArgument bbArg) { return storeOps; } +std::optional> ForallOp::getYieldedValuesMutable() { + return getOutputsMutable(); +} + std::optional> ForallOp::getLoopInductionVars() { return SmallVector{getBody()->getArguments().take_front(getRank())}; } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index a822b4199fe9d..fb2d1d11fb6ae 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1090,134 +1090,76 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, } template -void fuseTerminator(RewriterBase &rewriter, LoopTy target, LoopTy source, - LoopTy &fused, IRMapping &mapping) {} +void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused, + IRMapping &mapping) {} template <> -void fuseTerminator(RewriterBase &rewriter, scf::ForallOp target, - scf::ForallOp source, scf::ForallOp &fused, - IRMapping &mapping) { +void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source, + scf::ForallOp &fused, IRMapping &mapping) { // Fuse the old terminator in_parallel ops into the new one. 
- scf::InParallelOp targetTerm = target.getTerminator(); - scf::InParallelOp sourceTerm = source.getTerminator(); scf::InParallelOp fusedTerm = fused.getTerminator(); - rewriter.setInsertionPointToStart(fusedTerm.getBody()); - for (Operation &op : targetTerm.getYieldingOps()) - rewriter.clone(op, mapping); - for (Operation &op : sourceTerm.getYieldingOps()) + rewriter.setInsertionPointToEnd(fusedTerm.getBody()); + for (Operation &op : source.getTerminator().getYieldingOps()) rewriter.clone(op, mapping); } template <> -void fuseTerminator(RewriterBase &rewriter, scf::ForOp target, - scf::ForOp source, scf::ForOp &fused, IRMapping &mapping) { +void fuseTerminator(RewriterBase &rewriter, scf::ForOp source, + scf::ForOp &fused, IRMapping &mapping) { // Build fused yield results by appropriately mapping original yield operands. - SmallVector yieldResults; - for (Value operand : target.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - for (Value operand : source.getBody()->getTerminator()->getOperands()) - yieldResults.push_back(mapping.lookupOrDefault(operand)); - if (!yieldResults.empty()) - rewriter.create(source.getLoc(), yieldResults); + auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping); + rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm); } template -LoopTy createFused(LoopLikeOpInterface target, LoopLikeOpInterface source, - RewriterBase &rewriter) { - auto targetResults = target.getLoopResults(); - auto sourceResults = source.getLoopResults(); - int64_t numTargetOuts = (*targetResults).size(); - int64_t numSourceOuts = (*sourceResults).size(); - - // Create fused shared_outs. - SmallVector fusedOuts; - llvm::append_range(fusedOuts, *targetResults); - llvm::append_range(fusedOuts, *sourceResults); - - // Create a new scf.forall op after the source loop. 
- rewriter.setInsertionPointAfter(source); - // LoopTy fusedLoop = builder.create( - // source.getLoc(), source.getLoopLowerBounds(), - // source.getLoopUpperBounds(), source.getLoopSteps(), fusedOuts, - // source->getAttrs()); - LoopTy fusedLoop = rewriter.cloneWithoutRegions(cast(source)); - - // Map control operands. - IRMapping mapping; - mapping.map(*target.getLoopInductionVars(), - *fusedLoop.getLoopInductionVars()); - mapping.map(*source.getLoopInductionVars(), - *fusedLoop.getLoopInductionVars()); - - // Map shared outs. - mapping.map(target.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); - mapping.map(source.getRegionIterArgs(), - fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); - - // Append everything except the terminator into the fused operation. - rewriter.setInsertionPointToStart( - &fusedLoop.getLoopRegions().front()->front()); - for (Operation &op : - target.getLoopRegions().front()->front().without_terminator()) - rewriter.clone(op, mapping); - for (Operation &op : - source.getLoopRegions().front()->front().without_terminator()) - rewriter.clone(op, mapping); - - fuseTerminator(rewriter, cast(target), cast(source), - cast(fusedLoop), mapping); - - rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); - rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); - - return fusedLoop; -} - -scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, - scf::ForallOp source, - RewriterBase &rewriter) { +LoopLikeOpInterface +createFused(LoopLikeOpInterface target, LoopLikeOpInterface source, + RewriterBase &rewriter, NewYieldValuesFn newYieldValuesFn) { auto targetIterArgs = target.getRegionIterArgs(); - auto targetInductionVar = target.getInductionVars(); + auto targetInductionVar = *target.getLoopInductionVars(); SmallVector targetYieldOperands(target.getYieldedValues()); auto sourceIterArgs = source.getRegionIterArgs(); - auto sourceInductionVar = 
source.getInductionVars(); - scf::InParallelOp sourceTerm = source.getTerminator(); - auto sourceYieldOps = sourceTerm.getYieldingOps(); - auto sourceBody = source.getBody(); - SmallVector sourceYieldOperands(llvm::map_range( - sourceTerm.getDests(), [](auto arg) { return cast(arg); })); - scf::ForallOp fusedLoop = - cast(*target.replaceWithAdditionalYields( - rewriter, source.getOutputs(), /*replaceInitOperandUsesInLoop=*/false, - [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { - for (Operation &op : sourceYieldOps) - b.clone(op); - return sourceYieldOperands; - })); + auto sourceInductionVar = *source.getLoopInductionVars(); + SmallVector sourceYieldOperands(source.getYieldedValues()); + auto sourceRegion = source.getLoopRegions().front(); + LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields( + rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false, + newYieldValuesFn); + // Map control operands. IRMapping mapping; - mapping.map(targetInductionVar, fusedLoop.getInductionVars()); + mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars()); mapping.map(targetIterArgs, fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); mapping.map(targetYieldOperands, fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); - mapping.map(sourceInductionVar, fusedLoop.getInductionVars()); + mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars()); mapping.map(sourceIterArgs, fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); mapping.map(sourceYieldOperands, fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); // Append everything except the terminator into the fused operation. 
- rewriter.setInsertionPoint(fusedLoop.getBody()->getTerminator()); - for (Operation &op : sourceBody->without_terminator()) + rewriter.setInsertionPoint( + fusedLoop.getLoopRegions().front()->front().getTerminator()); + for (Operation &op : sourceRegion->front().without_terminator()) rewriter.clone(op, mapping); - // Fuse the old terminator in_parallel ops into the new one. - scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); - rewriter.setInsertionPointToEnd(fusedTerm.getBody()); - for (Operation &op : sourceTerm.getYieldingOps()) - rewriter.clone(op, mapping); + fuseTerminator(rewriter, cast(source), + cast(fusedLoop), mapping); + + return fusedLoop; +} +scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, + scf::ForallOp source, + RewriterBase &rewriter) { + scf::ForallOp fusedLoop = cast(createFused( + target, source, rewriter, + [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { + for (Operation &op : source.getTerminator().getYieldingOps()) + b.clone(op); + return source.getYieldedValues(); + })); rewriter.replaceOp(source, fusedLoop.getResults().take_back(source.getNumResults())); @@ -1227,35 +1169,11 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { - auto targetIterArgs = target.getRegionIterArgs(); - auto targetInductionVar = target.getInductionVar(); - SmallVector targetYieldOperands(source.getYieldedValues()); - auto sourceIterArgs = source.getRegionIterArgs(); - auto sourceInductionVar = source.getInductionVar(); - SmallVector sourceYieldOperands(source.getYieldedValues()); - scf::ForOp fusedLoop = cast(*target.replaceWithAdditionalYields( - rewriter, source.getInitArgs(), /*replaceInitOperandUsesInLoop=*/false, + scf::ForOp fusedLoop = cast(createFused( + target, source, rewriter, [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { - return sourceYieldOperands; + return 
source.getYieldedValues(); })); - // Map original induction variables and operands to those of the fused loop. - IRMapping mapping; - mapping.map(targetInductionVar, fusedLoop.getInductionVar()); - mapping.map(targetIterArgs, - fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); - mapping.map(targetYieldOperands, - fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); - mapping.map(sourceInductionVar, fusedLoop.getInductionVar()); - mapping.map(sourceIterArgs, - fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); - mapping.map(sourceYieldOperands, - fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); - // Merge target's body into the new (fused) for loop and then source's body. - rewriter.setInsertionPoint(fusedLoop.getBody()->getTerminator()); - for (Operation &op : source.getBody()->without_terminator()) - rewriter.clone(op, mapping); - auto newTerm = rewriter.clone(*fusedLoop.getBody()->getTerminator(), mapping); - rewriter.replaceOp(fusedLoop.getBody()->getTerminator(), newTerm); rewriter.replaceOp(source, fusedLoop.getResults().take_back(source.getNumResults())); return fusedLoop; From 7a51cb34afd5d8a2b67cceaef457f50c032affbd Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 17 Jun 2024 14:36:44 -0500 Subject: [PATCH 18/33] Decouple concrete loop type from `createFused` function --- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 35 +++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index fb2d1d11fb6ae..910c41b3e3d54 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1111,10 +1111,29 @@ void fuseTerminator(RewriterBase &rewriter, scf::ForOp source, rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm); } -template -LoopLikeOpInterface -createFused(LoopLikeOpInterface target, LoopLikeOpInterface source, - RewriterBase &rewriter, NewYieldValuesFn 
newYieldValuesFn) { +// TODO: We should maybe add this as a method to LoopLikeOpInterface. +// For now, this acts as a placeholder. +template <> +void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source, + LoopLikeOpInterface &fused, IRMapping &mapping) { + if (isa(source) && isa(fused)) { + fuseTerminator(rewriter, cast(source), cast(fused), + mapping); + } else if (isa(source) && isa(fused)) { + fuseTerminator(rewriter, cast(source), + cast(fused), mapping); + } else if (isa(source) && isa(fused)) { + fuseTerminator(rewriter, cast(source), + cast(fused), mapping); + } else { + return; + } +} + +LoopLikeOpInterface createFused(LoopLikeOpInterface target, + LoopLikeOpInterface source, + RewriterBase &rewriter, + NewYieldValuesFn newYieldValuesFn) { auto targetIterArgs = target.getRegionIterArgs(); auto targetInductionVar = *target.getLoopInductionVars(); SmallVector targetYieldOperands(target.getYieldedValues()); @@ -1144,8 +1163,8 @@ createFused(LoopLikeOpInterface target, LoopLikeOpInterface source, for (Operation &op : sourceRegion->front().without_terminator()) rewriter.clone(op, mapping); - fuseTerminator(rewriter, cast(source), - cast(fusedLoop), mapping); + // TODO: Replace with interface method if added + fuseTerminator(rewriter, source, fusedLoop, mapping); return fusedLoop; } @@ -1153,7 +1172,7 @@ createFused(LoopLikeOpInterface target, LoopLikeOpInterface source, scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { - scf::ForallOp fusedLoop = cast(createFused( + scf::ForallOp fusedLoop = cast(createFused( target, source, rewriter, [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { for (Operation &op : source.getTerminator().getYieldingOps()) @@ -1169,7 +1188,7 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { - scf::ForOp 
fusedLoop = cast(createFused( + scf::ForOp fusedLoop = cast(createFused( target, source, rewriter, [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { return source.getYieldedValues(); From 30873263faaab18267109231094af408b819059a Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 17 Jun 2024 15:10:44 -0500 Subject: [PATCH 19/33] Refactor ForallOp::replaceWithAdditionalYields --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 9 +++------ mlir/lib/Dialect/SCF/Utils/Utils.cpp | 5 ++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index b4a16e519a15a..c5a9e18e2610c 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -633,12 +633,7 @@ FailureOr ForallOp::replaceWithAdditionalYields( getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), inits, getMapping()); - // Generate the new yield values and append them to the scf.yield operation. - auto yieldOp = cast(getTerminator()); - ArrayRef newIterArgs = - newLoop.getBody()->getArguments().take_back(newInitOperands.size()); newLoop.getTerminator().erase(); - // Move the loop body to the new op. rewriter.mergeBlocks(getBody(), newLoop.getBody(), newLoop.getBody()->getArguments().take_front( @@ -647,7 +642,9 @@ FailureOr ForallOp::replaceWithAdditionalYields( if (replaceInitOperandUsesInLoop) { // Replace all uses of `newInitOperands` with the corresponding basic block // arguments. 
- for (auto it : llvm::zip(newInitOperands, newIterArgs)) { + for (auto it : + llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back( + newInitOperands.size()))) { rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it), [&](OpOperand &use) { Operation *user = use.getOwner(); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 910c41b3e3d54..5ef6718bc5346 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1175,9 +1175,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp fusedLoop = cast(createFused( target, source, rewriter, [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { - for (Operation &op : source.getTerminator().getYieldingOps()) - b.clone(op); - return source.getYieldedValues(); + // `ForallOp` does not have yields, rather an `InParallelOp` terminator. + return ValueRange{}; })); rewriter.replaceOp(source, fusedLoop.getResults().take_back(source.getNumResults())); From bcf3d4aaed9e425f3a3b2d97660c6e816e333abe Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 17 Jun 2024 15:50:55 -0500 Subject: [PATCH 20/33] revert unnecessary changes --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 7 ++++--- mlir/lib/Dialect/SCF/IR/SCF.cpp | 6 ------ mlir/lib/Dialect/SCF/Utils/Utils.cpp | 1 - 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index b9345f6ecdbb2..bf95fbe6721cf 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -301,9 +301,10 @@ def ForallOp : SCF_Op<"forall", [ AttrSizedOperandSegments, AutomaticAllocationScope, DeclareOpInterfaceMethods, + ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", + "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps", + "replaceWithAdditionalYields", "promoteIfSingleIteration", + 
"yieldTiledValuesAndReplace"]>, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, DeclareOpInterfaceMethods, diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index c5a9e18e2610c..deface43028b1 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -616,8 +616,6 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, regions.push_back(RegionSuccessor(getResults())); } -std::optional ForallOp::getLoopResults() { return getResults(); } - SmallVector ForallOp::getLoopRegions() { return {&getRegion()}; } FailureOr ForallOp::replaceWithAdditionalYields( @@ -1469,10 +1467,6 @@ SmallVector ForallOp::getCombiningOps(BlockArgument bbArg) { return storeOps; } -std::optional> ForallOp::getYieldedValuesMutable() { - return getOutputsMutable(); -} - std::optional> ForallOp::getLoopInductionVars() { return SmallVector{getBody()->getArguments().take_front(getRank())}; } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 5ef6718bc5346..2e61f9998a7d8 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -15,7 +15,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" From 0cb3c4ea08b22eea318fa47634914f921f08f7f2 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 18 Jun 2024 10:35:53 -0500 Subject: [PATCH 21/33] cleanup --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++-- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +++--- mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir | 3 +-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index deface43028b1..2baef9ca45db1 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ 
b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -625,13 +625,13 @@ FailureOr ForallOp::replaceWithAdditionalYields( // Create a new loop before the existing one, with the extra operands. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(getOperation()); - auto inits = llvm::to_vector(getOutputs()); + SmallVector inits(getOutputs()); inits.append(newInitOperands.begin(), newInitOperands.end()); scf::ForallOp newLoop = rewriter.create( getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), inits, getMapping()); - newLoop.getTerminator().erase(); + rewriter.eraseOp(newLoop.getTerminator()); // Move the loop body to the new op. rewriter.mergeBlocks(getBody(), newLoop.getBody(), newLoop.getBody()->getArguments().take_front( diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 2e61f9998a7d8..09da6e6233ffc 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1110,8 +1110,8 @@ void fuseTerminator(RewriterBase &rewriter, scf::ForOp source, rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm); } -// TODO: We should maybe add this as a method to LoopLikeOpInterface. -// For now, this acts as a placeholder. +// TODO: We should maybe add a method to LoopLikeOpInterface that will +// facilitate this transformation. For now, this acts as a placeholder. 
template <> void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source, LoopLikeOpInterface &fused, IRMapping &mapping) { @@ -1162,7 +1162,7 @@ LoopLikeOpInterface createFused(LoopLikeOpInterface target, for (Operation &op : sourceRegion->front().without_terminator()) rewriter.clone(op, mapping); - // TODO: Replace with interface method if added + // TODO: Replace with corresponding interface method if added fuseTerminator(rewriter, source, fusedLoop, mapping); return fusedLoop; diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 47bfe0baa7651..46c6be36c3271 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -189,8 +189,7 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func.func @matmul_fuse_2nd_forall_into_1st -// CHECK-SAME: [[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} +// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { %zero = arith.constant 0.0 : f32 %out_alloc = tensor.empty() : tensor<128x128xf32> From 7e41a549f966956204f6f0971831e0423a9aeb9d Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 21 Jun 2024 15:52:08 -0500 Subject: [PATCH 22/33] address some review comments --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 13 ++++++------- .../Dialect/SCF/TransformOps/SCFTransformOps.cpp | 9 ++++++--- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 13 ++++++------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 2baef9ca45db1..0c967ac68a081 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -626,7 +626,7 @@ FailureOr 
ForallOp::replaceWithAdditionalYields( OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(getOperation()); SmallVector inits(getOutputs()); - inits.append(newInitOperands.begin(), newInitOperands.end()); + llvm::append_range(inits, newInitOperands); scf::ForallOp newLoop = rewriter.create( getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), inits, getMapping()); @@ -640,14 +640,13 @@ FailureOr ForallOp::replaceWithAdditionalYields( if (replaceInitOperandUsesInLoop) { // Replace all uses of `newInitOperands` with the corresponding basic block // arguments. - for (auto it : + for (auto &&[newOperand, oldOperand] : llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back( newInitOperands.size()))) { - rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it), - [&](OpOperand &use) { - Operation *user = use.getOwner(); - return newLoop->isProperAncestor(user); - }); + rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); } } diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 99f92d7e24840..0e13b503098f0 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp, return 1; }; - std::optional ubConstant = getConstantIntValue(forOp.getUpperBound()); - std::optional lbConstant = getConstantIntValue(forOp.getLowerBound()); + std::optional ubConstant = + getConstantIntValue(forOp.getUpperBound()); + std::optional lbConstant = + getConstantIntValue(forOp.getLowerBound()); DenseMap opCycles; std::map> wrappedSchedule; for (Operation &op : forOp.getBody()->getOperations()) { @@ -528,7 +530,8 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, << "operations cannot be fused"; Operation *fusedLoop; - /// 
TODO: Support fusion for loop-like ops besides scf.for and scf.forall. + // TODO: Support fusion for loop-like ops besides scf.for, scf.forall + // and scf.parallel. if (isa(target) && isa(source)) { fusedLoop = fuseIndependentSiblingForLoops( cast(target), cast(source), rewriter); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 09da6e6233ffc..dc15015e9bec2 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1076,7 +1076,7 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, LoopLikeOpInterface &source) { - auto iterSpaceEq = + bool iterSpaceEq = target.getLoopLowerBounds() == source.getLoopLowerBounds() && target.getLoopUpperBounds() == source.getLoopUpperBounds() && target.getLoopSteps() == source.getLoopSteps(); @@ -1125,6 +1125,7 @@ void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source, fuseTerminator(rewriter, cast(source), cast(fused), mapping); } else { + llvm_unreachable("unsupported loop types."); return; } } @@ -1239,13 +1240,11 @@ scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( newRedBlock.begin(), newRedBlock.getArguments()); } - target.replaceAllUsesWith(results.take_front(inits1.size())); - source.replaceAllUsesWith(results.take_back(inits2.size())); } - term1->erase(); - term2->erase(); - target.erase(); - source.erase(); + rewriter.replaceOp(target, results.take_front(inits1.size())); + rewriter.replaceOp(source, results.take_back(inits2.size())); + rewriter.eraseOp(term1); + rewriter.eraseOp(term2); return fusedLoop; } From cc95d75d2cc09f8a33850f3867c8313e374a0dfd Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 24 Jun 2024 14:56:48 -0500 Subject: [PATCH 23/33] move `createFused` to `LoopLikeInterface.h` --- .../mlir/Interfaces/LoopLikeInterface.h | 20 ++++ mlir/lib/Dialect/SCF/Utils/Utils.cpp | 101 ++++-------------- 
mlir/lib/Interfaces/LoopLikeInterface.cpp | 42 ++++++++ 3 files changed, 82 insertions(+), 81 deletions(-) diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h index 42609e824c86a..d862439a07790 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -57,4 +57,24 @@ class HasParallelRegion : public TraitBase { /// Include the generated interface declarations. #include "mlir/Interfaces/LoopLikeInterface.h.inc" +namespace mlir { +/// A function that rewrites `target`'s terminator as a teminator obtained by +/// fusing `source` into `target`. +using FuseTerminatorFn = + std::function; + +/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to +/// `target`. The `NewYieldValuesFn` callback is used to pass to the +/// `replaceWithAdditionalYields` interface method to replace the loop with a +/// new loop with (possibly) additional yields, while the `FuseTerminatorFn` +/// callback is responsible for updating the fused loop terminator.
+LoopLikeOpInterface createFused(LoopLikeOpInterface target, + LoopLikeOpInterface source, + RewriterBase &rewriter, + NewYieldValuesFn newYieldValuesFn, + FuseTerminatorFn fuseTerminatorFn); + +} // namespace mlir + #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_ diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index dc15015e9bec2..93e7a40845b2e 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1082,93 +1082,14 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, target.getLoopSteps() == source.getLoopSteps(); auto forAllTarget = dyn_cast(*target); auto forAllSource = dyn_cast(*source); + // TODO: Decouple checks on concrete loop types and move this function + // somewhere for general utility for `LoopLikeOpInterface` if (forAllTarget && forAllSource) return iterSpaceEq && forAllTarget.getMapping() == forAllSource.getMapping(); return iterSpaceEq; } -template -void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused, - IRMapping &mapping) {} - -template <> -void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source, - scf::ForallOp &fused, IRMapping &mapping) { - // Fuse the old terminator in_parallel ops into the new one. - scf::InParallelOp fusedTerm = fused.getTerminator(); - rewriter.setInsertionPointToEnd(fusedTerm.getBody()); - for (Operation &op : source.getTerminator().getYieldingOps()) - rewriter.clone(op, mapping); -} - -template <> -void fuseTerminator(RewriterBase &rewriter, scf::ForOp source, - scf::ForOp &fused, IRMapping &mapping) { - // Build fused yield results by appropriately mapping original yield operands. - auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping); - rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm); -} - -// TODO: We should maybe add a method to LoopLikeOpInterface that will -// facilitate this transformation. For now, this acts as a placeholder. 
-template <> -void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source, - LoopLikeOpInterface &fused, IRMapping &mapping) { - if (isa(source) && isa(fused)) { - fuseTerminator(rewriter, cast(source), cast(fused), - mapping); - } else if (isa(source) && isa(fused)) { - fuseTerminator(rewriter, cast(source), - cast(fused), mapping); - } else if (isa(source) && isa(fused)) { - fuseTerminator(rewriter, cast(source), - cast(fused), mapping); - } else { - llvm_unreachable("unsupported loop types."); - return; - } -} - -LoopLikeOpInterface createFused(LoopLikeOpInterface target, - LoopLikeOpInterface source, - RewriterBase &rewriter, - NewYieldValuesFn newYieldValuesFn) { - auto targetIterArgs = target.getRegionIterArgs(); - auto targetInductionVar = *target.getLoopInductionVars(); - SmallVector targetYieldOperands(target.getYieldedValues()); - auto sourceIterArgs = source.getRegionIterArgs(); - auto sourceInductionVar = *source.getLoopInductionVars(); - SmallVector sourceYieldOperands(source.getYieldedValues()); - auto sourceRegion = source.getLoopRegions().front(); - LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields( - rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false, - newYieldValuesFn); - - // Map control operands. - IRMapping mapping; - mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars()); - mapping.map(targetIterArgs, - fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); - mapping.map(targetYieldOperands, - fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); - mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars()); - mapping.map(sourceIterArgs, - fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); - mapping.map(sourceYieldOperands, - fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); - // Append everything except the terminator into the fused operation. 
- rewriter.setInsertionPoint( - fusedLoop.getLoopRegions().front()->front().getTerminator()); - for (Operation &op : sourceRegion->front().without_terminator()) - rewriter.clone(op, mapping); - - // TODO: Replace with corresponding interface method if added - fuseTerminator(rewriter, source, fusedLoop, mapping); - - return fusedLoop; -} - scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { @@ -1177,6 +1098,15 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { // `ForallOp` does not have yields, rather an `InParallelOp` terminator. return ValueRange{}; + }, + [&](RewriterBase &b, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping) { + auto sourceForall = cast(source); + auto targetForall = cast(target); + scf::InParallelOp fusedTerm = targetForall.getTerminator(); + b.setInsertionPointToEnd(fusedTerm.getBody()); + for (Operation &op : sourceForall.getTerminator().getYieldingOps()) + b.clone(op, mapping); })); rewriter.replaceOp(source, fusedLoop.getResults().take_back(source.getNumResults())); @@ -1191,12 +1121,21 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, target, source, rewriter, [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { return source.getYieldedValues(); + }, + [&](RewriterBase &b, LoopLikeOpInterface source, + LoopLikeOpInterface &target, IRMapping mapping) { + auto sourceFor = cast(source); + auto targetFor = cast(target); + auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping); + b.replaceOp(targetFor.getBody()->getTerminator(), newTerm); })); rewriter.replaceOp(source, fusedLoop.getResults().take_back(source.getNumResults())); return fusedLoop; } +// TODO: Finish refactoring this a la the above, but likely requires additional +// interface methods. 
scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) { Block *block1 = target.getBody(); diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp index 1e0e87b64e811..aefd388461570 100644 --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -8,6 +8,8 @@ #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/ADT/DenseSet.h" @@ -113,3 +115,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { return success(); } + +LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target, + LoopLikeOpInterface source, + RewriterBase &rewriter, + NewYieldValuesFn newYieldValuesFn, + FuseTerminatorFn fuseTerminatorFn) { + auto targetIterArgs = target.getRegionIterArgs(); + auto targetInductionVar = *target.getLoopInductionVars(); + SmallVector targetYieldOperands(target.getYieldedValues()); + auto sourceIterArgs = source.getRegionIterArgs(); + auto sourceInductionVar = *source.getLoopInductionVars(); + SmallVector sourceYieldOperands(source.getYieldedValues()); + auto sourceRegion = source.getLoopRegions().front(); + LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields( + rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false, + newYieldValuesFn); + + // Map control operands. 
+ IRMapping mapping; + mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars()); + mapping.map(targetIterArgs, + fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); + mapping.map(targetYieldOperands, + fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); + mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars()); + mapping.map(sourceIterArgs, + fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); + mapping.map(sourceYieldOperands, + fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); + // Append everything except the terminator into the fused operation. + rewriter.setInsertionPoint( + fusedLoop.getLoopRegions().front()->front().getTerminator()); + for (Operation &op : sourceRegion->front().without_terminator()) + rewriter.clone(op, mapping); + + // TODO: Replace with corresponding interface method if added + fuseTerminatorFn(rewriter, source, fusedLoop, mapping); + + return fusedLoop; +} From 3430a36fda3c53d466550a7d8fd13b331f96f005 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 26 Jun 2024 13:51:34 -0500 Subject: [PATCH 24/33] address more review comments --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++-- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 0c967ac68a081..1e42376ce58ca 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -629,9 +629,9 @@ FailureOr ForallOp::replaceWithAdditionalYields( llvm::append_range(inits, newInitOperands); scf::ForallOp newLoop = rewriter.create( getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), - inits, getMapping()); + inits, getMapping(), + /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); - rewriter.eraseOp(newLoop.getTerminator()); // Move the loop body to the new op. 
rewriter.mergeBlocks(getBody(), newLoop.getBody(), newLoop.getBody()->getArguments().take_front( diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 93e7a40845b2e..e7496cd97cd63 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1124,7 +1124,6 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, }, [&](RewriterBase &b, LoopLikeOpInterface source, LoopLikeOpInterface &target, IRMapping mapping) { - auto sourceFor = cast(source); auto targetFor = cast(target); auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping); b.replaceOp(targetFor.getBody()->getTerminator(), newTerm); @@ -1151,8 +1150,9 @@ scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( rewriter.setInsertionPoint(source); auto fusedLoop = rewriter.create( - source.getLoc(), source.getLowerBound(), source.getUpperBound(), - source.getStep(), newInitVars); + rewriter.getFusedLoc(target.getLoc(), source.getLoc()), + source.getLowerBound(), source.getUpperBound(), source.getStep(), + newInitVars); Block *newBlock = fusedLoop.getBody(); rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(), newBlock->getArguments()); @@ -1168,8 +1168,8 @@ scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - auto newReduceOp = - rewriter.create(term2.getLoc(), newReduceArgs); + auto newReduceOp = rewriter.create( + rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs); for (auto &&[i, reg] : llvm::enumerate(llvm::concat( term1.getReductions(), term2.getReductions()))) { From 8447c121b95279a283b5e7b25f094f6abb062216 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 26 Jun 2024 20:06:43 -0500 Subject: [PATCH 25/33] switch to function_ref --- mlir/include/mlir/Interfaces/LoopLikeInterface.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff 
--git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h index d862439a07790..cfe2c14b838f6 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -61,8 +61,8 @@ namespace mlir { /// A function that rewrites `target`'s terminator as a teminator obtained by /// fusing `source` into `target`. using FuseTerminatorFn = - std::function; + function_ref; /// Returns a fused `LoopLikeOpInterface` created by fusing `source` to /// `target`. The `NewYieldValuesFn` callback is used to pass to the From fbd7b72bb44c7833a683d93fccaa9d992856ee8b Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 26 Jun 2024 22:52:20 -0500 Subject: [PATCH 26/33] check optional values --- mlir/lib/Interfaces/LoopLikeInterface.cpp | 27 +++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp index aefd388461570..6f0ebec0519be 100644 --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -122,24 +122,37 @@ LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target, NewYieldValuesFn newYieldValuesFn, FuseTerminatorFn fuseTerminatorFn) { auto targetIterArgs = target.getRegionIterArgs(); - auto targetInductionVar = *target.getLoopInductionVars(); + std::optional> targetInductionVar = + target.getLoopInductionVars(); SmallVector targetYieldOperands(target.getYieldedValues()); auto sourceIterArgs = source.getRegionIterArgs(); - auto sourceInductionVar = *source.getLoopInductionVars(); + std::optional> sourceInductionVar = + *source.getLoopInductionVars(); SmallVector sourceYieldOperands(source.getYieldedValues()); auto sourceRegion = source.getLoopRegions().front(); - LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields( - rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false, - newYieldValuesFn); + + 
FailureOr maybeFusedLoop = + target.replaceWithAdditionalYields(rewriter, source.getInits(), + /*replaceInitOperandUsesInLoop=*/false, + newYieldValuesFn); + if (failed(maybeFusedLoop)) + llvm_unreachable("failed to replace loop"); + LoopLikeOpInterface fusedLoop = *maybeFusedLoop; // Map control operands. IRMapping mapping; - mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars()); + std::optional> fusedInductionVar = + fusedLoop.getLoopInductionVars(); + if (fusedInductionVar) { + if (!targetInductionVar || !sourceInductionVar) + llvm_unreachable("expected target and source loops to have induction vars"); + mapping.map(*targetInductionVar, *fusedInductionVar); + mapping.map(*sourceInductionVar, *fusedInductionVar); + } mapping.map(targetIterArgs, fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); mapping.map(targetYieldOperands, fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); - mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars()); mapping.map(sourceIterArgs, fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); mapping.map(sourceYieldOperands, From ffb73a7a76b382414f8f8295f6d6dc14a3edfa99 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 26 Jun 2024 23:34:38 -0500 Subject: [PATCH 27/33] replace equalIterationSpaces with checkFusionStructuralLegality --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 5 +++-- .../SCF/Transforms/ParallelLoopFusion.cpp | 20 +------------------ mlir/lib/Dialect/SCF/Utils/Utils.cpp | 4 ++-- 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index ac4434b337890..ca3ab0aeae1de 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -160,8 +160,9 @@ void getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, // Fusion related helpers //===----------------------------------------------------------------------===// -bool
checkFusionStructuralLegality(LoopLikeOpInterface &target, - LoopLikeOpInterface &source); +/// Check structural compatibility between two loops such as iteration space. +bool checkFusionStructuralLegality(LoopLikeOpInterface target, + LoopLikeOpInterface source); /// Given two scf.forall loops, `target` and `source`, fuses `target` into /// `source`. Assumes that the given loops are siblings and are independent of diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 326a8f93162b9..fd57a9228186e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -38,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) { return walkResult.wasInterrupted(); } -/// Verify equal iteration spaces. -static bool equalIterationSpaces(ParallelOp firstPloop, - ParallelOp secondPloop) { - if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) - return false; - - auto matchOperands = [&](const OperandRange &lhs, - const OperandRange &rhs) -> bool { - // TODO: Extend this to support aliases and equal constants. - return std::equal(lhs.begin(), lhs.end(), rhs.begin()); - }; - return matchOperands(firstPloop.getLowerBound(), - secondPloop.getLowerBound()) && - matchOperands(firstPloop.getUpperBound(), - secondPloop.getUpperBound()) && - matchOperands(firstPloop.getStep(), secondPloop.getStep()); -} - /// Checks if the parallel loops have mixed access to the same buffers. Returns /// `true` if the first parallel loop writes to the same indices that the second /// loop reads. 
@@ -156,7 +138,7 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, llvm::function_ref mayAlias) { return !hasNestedParallelOp(firstPloop) && !hasNestedParallelOp(secondPloop) && - equalIterationSpaces(firstPloop, secondPloop) && + checkFusionStructuralLegality(firstPloop, secondPloop) && succeeded(verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e7496cd97cd63..fab6592d9eb2a 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1074,8 +1074,8 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, // Fusion related helpers //===----------------------------------------------------------------------===// -bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target, - LoopLikeOpInterface &source) { +bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, + LoopLikeOpInterface source) { bool iterSpaceEq = target.getLoopLowerBounds() == source.getLoopLowerBounds() && target.getLoopUpperBounds() == source.getLoopUpperBounds() && From a6d0588da17170b1d3653efb51704b10d770dc58 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 27 Jun 2024 11:31:03 -0500 Subject: [PATCH 28/33] check if isOpSibling in checkFusionStructuralLegality --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 6 +- .../SCF/TransformOps/SCFTransformOps.cpp | 84 +---------------- .../SCF/Transforms/ParallelLoopFusion.cpp | 3 +- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 91 ++++++++++++++++++- .../SCF/transform-loop-fuse-sibling.mlir | 3 +- 5 files changed, 99 insertions(+), 88 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index ca3ab0aeae1de..59aeff2da14ea 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -160,9 +160,11 @@ void 
getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, // Fusion related helpers //===----------------------------------------------------------------------===// -/// Check structural compatibility between two loops such as iteration space. +/// Check structural compatibility between two loops such as iteration space +/// and dominance. bool checkFusionStructuralLegality(LoopLikeOpInterface target, - LoopLikeOpInterface source); + LoopLikeOpInterface source, + Diagnostic &diag); /// Given two scf.forall loops, `target` and `source`, fuses `target` into /// `source`. Assumes that the given loops are siblings and are independent of diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 0e13b503098f0..3e0a483615a3d 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -425,78 +425,6 @@ void transform::TakeAssumedBranchOp::getEffects( // LoopFuseSiblingOp //===----------------------------------------------------------------------===// -/// Check if `target` and `source` are siblings, in the context that `target` -/// is being fused into `source`. -/// -/// This is a simple check that just checks if both operations are in the same -/// block and some checks to ensure that the fused IR does not violate -/// dominance. -static DiagnosedSilenceableFailure isOpSibling(Operation *target, - Operation *source) { - // Check if both operations are same. - if (target == source) - return emitSilenceableFailure(source) - << "target and source need to be different loops"; - - // Check if both operations are in the same block. - if (target->getBlock() != source->getBlock()) - return emitSilenceableFailure(source) - << "target and source are not in the same block"; - - // Check if fusion will violate dominance. 
- DominanceInfo domInfo(source); - if (target->isBeforeInBlock(source)) { - // Since `target` is before `source`, all users of results of `target` - // need to be dominated by `source`. - for (Operation *user : target->getUsers()) { - if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { - return emitSilenceableFailure(target) - << "user of results of target should be properly dominated by " - "source"; - } - } - } else { - // Since `target` is after `source`, all values used by `target` need - // to dominate `source`. - - // Check if operands of `target` are dominated by `source`. - for (Value operand : target->getOperands()) { - Operation *operandOp = operand.getDefiningOp(); - // Operands without defining operations are block arguments. When `target` - // and `source` occur in the same block, these operands dominate `source`. - if (!operandOp) - continue; - - // Operand's defining operation should properly dominate `source`. - if (!domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) - return emitSilenceableFailure(target) - << "operands of target should be properly dominated by source"; - } - - // Check if values used by `target` are dominated by `source`. - bool failed = false; - OpOperand *failedValue = nullptr; - visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { - Operation *operandOp = operand->get().getDefiningOp(); - if (operandOp && !domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) { - // `operand` is not an argument of an enclosing block and the defining - // op of `operand` is outside `target` but does not dominate `source`. 
- failed = true; - failedValue = operand; - } - }); - - if (failed) - return emitSilenceableFailure(failedValue->getOwner()) - << "values used inside regions of target should be properly " - "dominated by source"; - } - - return DiagnosedSilenceableFailure::success(); -} - DiagnosedSilenceableFailure transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -520,14 +448,10 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, return emitSilenceableFailure(target->getLoc()) << "target or source is not a loop op"; - // Check if the target and source are siblings. - DiagnosedSilenceableFailure diag = isOpSibling(target, source); - if (!diag.succeeded()) - return diag; - - if (!mlir::checkFusionStructuralLegality(target, source)) - return emitSilenceableFailure(target->getLoc()) - << "operations cannot be fused"; + // Check if loops can be fused + Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error); + if (!mlir::checkFusionStructuralLegality(target, source, diag)) + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); Operation *fusedLoop; // TODO: Support fusion for loop-like ops besides scf.for, scf.forall diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index fd57a9228186e..b46535078dd8b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -136,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref mayAlias) { + Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark); return !hasNestedParallelOp(firstPloop) && !hasNestedParallelOp(secondPloop) && - checkFusionStructuralLegality(firstPloop, secondPloop) && + 
checkFusionStructuralLegality(firstPloop, secondPloop, diag) && succeeded(verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index fab6592d9eb2a..b1a367281a6ca 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -1074,8 +1075,86 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, // Fusion related helpers //===----------------------------------------------------------------------===// +/// Check if `target` and `source` are siblings, in the context that `target` +/// is being fused into `source`. +/// +/// This is a simple check that just checks if both operations are in the same +/// block and some checks to ensure that the fused IR does not violate +/// dominance. +static bool isOpSibling(Operation *target, Operation *source, + Diagnostic &diag) { + // Check if both operations are same. + if (target == source) { + diag << "target and source need to be different loops"; + return false; + } + + // Check if both operations are in the same block. + if (target->getBlock() != source->getBlock()) { + diag << "target and source are not in the same block"; + return false; + } + + // Check if fusion will violate dominance. + DominanceInfo domInfo(source); + if (target->isBeforeInBlock(source)) { + // Since `target` is before `source`, all users of results of `target` + // need to be dominated by `source`. 
+ for (Operation *user : target->getUsers()) { + if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { + diag << "user of results of target should " + "be properly dominated by source"; + return false; + } + } + } else { + // Since `target` is after `source`, all values used by `target` need + // to dominate `source`. + + // Check if operands of `target` dominate `source`. + for (Value operand : target->getOperands()) { + Operation *operandOp = operand.getDefiningOp(); + // Operands without defining operations are block arguments. When `target` + // and `source` occur in the same block, these operands dominate `source`. + if (!operandOp) + continue; + + // Operand's defining operation should properly dominate `source`. + if (!domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + diag << "operands of target should be properly dominated by source"; + return false; + } + } + + // Check if values used inside regions of `target` dominate `source`. + bool failed = false; + OpOperand *failedValue = nullptr; + visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { + Operation *operandOp = operand->get().getDefiningOp(); + if (operandOp && !domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + // `operand` is not an argument of an enclosing block and the defining + // op of `operand` is outside `target` but does not dominate `source`. 
+ failed = true; + failedValue = operand; + } + }); + + if (failed) { + diag << "values used inside regions of target should be properly " + "dominated by source"; + diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation"; + return false; + } + } + + return true; +} + bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, - LoopLikeOpInterface source) { + LoopLikeOpInterface source, + Diagnostic &diag) { bool iterSpaceEq = target.getLoopLowerBounds() == source.getLoopLowerBounds() && target.getLoopUpperBounds() == source.getLoopUpperBounds() && @@ -1085,9 +1164,13 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, // TODO: Decouple checks on concrete loop types and move this function // somewhere for general utility for `LoopLikeOpInterface` if (forAllTarget && forAllSource) - return iterSpaceEq && - forAllTarget.getMapping() == forAllSource.getMapping(); - return iterSpaceEq; + iterSpaceEq = + iterSpaceEq && forAllTarget.getMapping() == forAllSource.getMapping(); + if (!iterSpaceEq) { + diag << "target and source iteration spaces must be equal"; + return false; + } + return isOpSibling(target, source, diag); } scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 46c6be36c3271..b03aa5cf38bfa 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -335,8 +335,9 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> scf.yield %6 : tensor<128xf32> } - %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { // expected-error @below {{values used inside regions of target should be properly dominated by source}} + %dup1 = scf.for %arg3 = 
%c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { + // expected-note @below {{see operation}} %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> From ff47980d71330f65ecf05451f4d2345145a24e21 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 27 Jun 2024 11:46:02 -0500 Subject: [PATCH 29/33] remove extra dominance check --- mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index b46535078dd8b..95ec8861aee2b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -158,13 +158,6 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, mayAlias)) return; - DominanceInfo dom; - // We are fusing first loop into second, make sure there are no users of the - // first loop results between loops. 
- for (Operation *user : firstPloop->getUsers()) - if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) - return; - IRRewriter rewriter(builder); secondPloop = mlir::fuseIndependentSiblingParallelLoops( firstPloop, secondPloop, rewriter); From c6847ec9212aa1754ad27c16f568a8d16346197d Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 27 Jun 2024 12:34:20 -0500 Subject: [PATCH 30/33] address more review comments --- mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 6 ++---- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 3e0a483615a3d..8c93554f4016e 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -440,10 +440,8 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, << "source handle (got " << llvm::range_size(sourceOps) << ")"; } - LoopLikeOpInterface target = - dyn_cast(*targetOps.begin()); - LoopLikeOpInterface source = - dyn_cast(*sourceOps.begin()); + auto target = dyn_cast(*targetOps.begin()); + auto source = dyn_cast(*sourceOps.begin()); if (!target || !source) return emitSilenceableFailure(target->getLoc()) << "target or source is not a loop op"; diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index b1a367281a6ca..666b67517f4d4 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1220,6 +1220,7 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, // interface methods. 
scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); Block *block1 = target.getBody(); Block *block2 = source.getBody(); auto term1 = cast(block1->getTerminator()); From f50c6aa14b36836950cc47909d4cca03d5ede8e3 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 27 Jun 2024 12:55:08 -0500 Subject: [PATCH 31/33] add more lit tests for scf.parallel --- .../SCF/transform-loop-fuse-sibling.mlir | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index b03aa5cf38bfa..1d46a3d88f47d 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -100,6 +100,116 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @fuse_two_parallel_reverse +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. 
+ %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 +// CHECK: [[SUM:%.*]] = memref.alloc() + %sum = memref.alloc() : memref<2x2xf32> +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: scf.reduce +// CHECK: } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } +// CHECK: memref.dealloc [[SUM]] + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op + 
transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @fuse_reductions_two +// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32) +func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) +// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) +// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32) +// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]] +// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]] +// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) { +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32 + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init1 = arith.constant 1.0 : f32 + %init2 = arith.constant 2.0 : f32 + %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + scf.reduce(%A_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + scf.reduce(%B_elem : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %1 = 
arith.mulf %lhs, %rhs : f32 + scf.reduce.return %1 : f32 + } + } + return %res1, %res2 : f32, f32 +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index @@ -382,3 +492,37 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 + %sum = memref.alloc() : memref<2x2xf32> + // expected-error @below {{target and source iteration spaces must be equal}} + scf.parallel (%i) = (%c0) to (%c2) step (%c1) { + %B_elem = memref.load %B[%i, %c0] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } + memref.dealloc %sum : memref<2x2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence 
@__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} From 6dd68c1b8408f05acaff9d040d4a686044295fcc Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 27 Jun 2024 13:24:11 -0500 Subject: [PATCH 32/33] check for equal loop types in checkFusionStructuralLegality --- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 5 +++ .../SCF/transform-loop-fuse-sibling.mlir | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 666b67517f4d4..0c966bf182cbd 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1155,6 +1155,11 @@ static bool isOpSibling(Operation *target, Operation *source, bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, LoopLikeOpInterface source, Diagnostic &diag) { + if (target->getName() != source->getName()) { + diag << "target and source must be same loop type"; + return false; + } + bool iterSpaceEq = target.getLoopLowerBounds() == source.getLoopLowerBounds() && target.getLoopUpperBounds() == source.getLoopUpperBounds() && diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index 1d46a3d88f47d..505013d328962 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -526,3 +526,36 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) { + %c2 = arith.constant 2 : index + %c0 = 
arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 + %sum = memref.alloc() : memref<2xf32> + // expected-error @below {{target and source must be same loop type}} + scf.for %i = %c0 to %c2 step %c1 { + %B_elem = memref.load %B[%i] : memref<2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i] : memref<2xf32> + } + scf.parallel (%i) = (%c0) to (%c2) step (%c1) { + %sum_elem = memref.load %sum[%i] : memref<2xf32> + %A_elem = memref.load %A[%i] : memref<2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i] : memref<2xf32> + scf.reduce + } + memref.dealloc %sum : memref<2xf32> + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %fused = transform.loop.fuse_sibling %0 into %1 : (!transform.any_op,!transform.any_op) -> !transform.any_op + transform.yield + } +} From 99d821b47cb731ac7a12b60c44d88af5ad2fb0d1 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 27 Jun 2024 13:42:20 -0500 Subject: [PATCH 33/33] address more comments --- mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 1 - mlir/lib/Dialect/SCF/Utils/Utils.cpp | 8 +++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 95ec8861aee2b..b775f988576e3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -161,7 +161,6 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, IRRewriter rewriter(builder); secondPloop = 
mlir::fuseIndependentSiblingParallelLoops( firstPloop, secondPloop, rewriter); - ; } void mlir::scf::naivelyFuseParallelOps( diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 0c966bf182cbd..a79aef34e48b1 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1164,13 +1164,11 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, target.getLoopLowerBounds() == source.getLoopLowerBounds() && target.getLoopUpperBounds() == source.getLoopUpperBounds() && target.getLoopSteps() == source.getLoopSteps(); - auto forAllTarget = dyn_cast(*target); - auto forAllSource = dyn_cast(*source); // TODO: Decouple checks on concrete loop types and move this function // somewhere for general utility for `LoopLikeOpInterface` - if (forAllTarget && forAllSource) - iterSpaceEq = - iterSpaceEq && forAllTarget.getMapping() == forAllSource.getMapping(); + if (auto forAllTarget = dyn_cast(*target)) + iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() == + cast(*source).getMapping(); if (!iterSpaceEq) { diag << "target and source iteration spaces must be equal"; return false;