diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index e053e6c97e143..c12ed7f5d0180 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -18,6 +18,9 @@ class FuncOp; namespace bufferization { struct OneShotBufferizationOptions; +/// Maps from symbol table to its corresponding dealloc helper function. +using DeallocHelperMap = llvm::DenseMap; + //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// @@ -46,7 +49,7 @@ std::unique_ptr createLowerDeallocationsPass(); /// Adds the conversion pattern of the `bufferization.dealloc` operation to the /// given pattern set for use in other transformation passes. void populateBufferizationDeallocLoweringPattern( - RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc); + RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap); /// Construct the library function needed for the fully generic /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass. diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index 2aae39f51b940..f9903071be084 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -132,27 +132,30 @@ struct BufferizationToMemRefPass return; } - func::FuncOp helperFuncOp; + bufferization::DeallocHelperMap deallocHelperFuncMap; if (auto module = dyn_cast(getOperation())) { OpBuilder builder = OpBuilder::atBlockBegin(&module.getBodyRegion().front()); - SymbolTable symbolTable(module); // Build dealloc helper function if there are deallocs. getOperation()->walk([&](bufferization::DeallocOp deallocOp) { - if (deallocOp.getMemrefs().size() > 1) { - helperFuncOp = bufferization::buildDeallocationLibraryFunction( - builder, getOperation()->getLoc(), symbolTable); - return WalkResult::interrupt(); + Operation *symtableOp = + deallocOp->getParentWithTrait(); + if (deallocOp.getMemrefs().size() > 1 && + !deallocHelperFuncMap.contains(symtableOp)) { + SymbolTable symbolTable(symtableOp); + func::FuncOp helperFuncOp = + bufferization::buildDeallocationLibraryFunction( + builder, getOperation()->getLoc(), symbolTable); + deallocHelperFuncMap[symtableOp] = helperFuncOp; } - return WalkResult::advance(); }); } RewritePatternSet patterns(&getContext()); patterns.add(patterns.getContext()); - bufferization::populateBufferizationDeallocLoweringPattern(patterns, - helperFuncOp); + bufferization::populateBufferizationDeallocLoweringPattern( + patterns, deallocHelperFuncMap); ConversionTarget target(getContext()); target.addLegalDialectgetParentWithTrait(); rewriter.create( - op.getLoc(), deallocHelperFunc, + op.getLoc(), deallocHelperFuncMap.lookup(symtableOp), SmallVector{castedDeallocMemref, castedRetainMemref, castedCondsMemref, castedDeallocCondsMemref, castedRetainCondsMemref}); @@ -338,9 +339,11 @@ class DeallocOpConversion } public: - DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc) + DeallocOpConversion( + MLIRContext *context, + const bufferization::DeallocHelperMap &deallocHelperFuncMap) : OpConversionPattern(context), - deallocHelperFunc(deallocHelperFunc) {} + deallocHelperFuncMap(deallocHelperFuncMap) {} LogicalResult matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor, @@ -360,7 +363,8 @@ class DeallocOpConversion if (adaptor.getMemrefs().size() == 1) return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter); - if (!deallocHelperFunc) + Operation *symtableOp = op->getParentWithTrait(); + if (!deallocHelperFuncMap.contains(symtableOp)) return op->emitError( "library function required for generic lowering, but cannot be " "automatically inserted when operating on functions"); @@ -369,7 +373,7 @@ class DeallocOpConversion } private: - func::FuncOp deallocHelperFunc; + const bufferization::DeallocHelperMap &deallocHelperFuncMap; }; } // namespace @@ -385,26 +389,29 @@ struct LowerDeallocationsPass return; } - func::FuncOp helperFuncOp; + bufferization::DeallocHelperMap deallocHelperFuncMap; if (auto module = dyn_cast(getOperation())) { OpBuilder builder = OpBuilder::atBlockBegin(&module.getBodyRegion().front()); - SymbolTable symbolTable(module); // Build dealloc helper function if there are deallocs. getOperation()->walk([&](bufferization::DeallocOp deallocOp) { - if (deallocOp.getMemrefs().size() > 1) { - helperFuncOp = bufferization::buildDeallocationLibraryFunction( - builder, getOperation()->getLoc(), symbolTable); - return WalkResult::interrupt(); + Operation *symtableOp = + deallocOp->getParentWithTrait(); + if (deallocOp.getMemrefs().size() > 1 && + !deallocHelperFuncMap.contains(symtableOp)) { + SymbolTable symbolTable(symtableOp); + func::FuncOp helperFuncOp = + bufferization::buildDeallocationLibraryFunction( + builder, getOperation()->getLoc(), symbolTable); + deallocHelperFuncMap[symtableOp] = helperFuncOp; } - return WalkResult::advance(); }); } RewritePatternSet patterns(&getContext()); - bufferization::populateBufferizationDeallocLoweringPattern(patterns, - helperFuncOp); + bufferization::populateBufferizationDeallocLoweringPattern( + patterns, deallocHelperFuncMap); ConversionTarget target(getContext()); target.addLegalDialect(patterns.getContext(), deallocLibraryFunc); + RewritePatternSet &patterns, + const bufferization::DeallocHelperMap &deallocHelperFuncMap) { + patterns.add(patterns.getContext(), + deallocHelperFuncMap); } std::unique_ptr mlir::bufferization::createLowerDeallocationsPass() { diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir index 5fedd45555fcd..edffcbdd0ba7d 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir @@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32> // CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]] // CHECK-NEXT: } // CHECK-NEXT: return + +// ----- + +// This test check dealloc_helper function is generated on each nested symbol +// table operation when needed and only generated once. +module @conversion_nest_module_dealloc_helper { + func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) + func.return %0#0, %0#1 : i1, i1 + } + module @nested_module_not_need_dealloc_helper { + func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>) + return %0#0, %0#1 : i1, i1 + } + } + module @nested_module_need_dealloc_helper { + func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) + func.return %0#0, %0#1 : i1, i1 + } + func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) + func.return %0#0, %0#1 : i1, i1 + } + } +} + +// CHECK: module @conversion_nest_module_dealloc_helper { +// CHECK: func.func @top_level_func +// CHECK: call @dealloc_helper +// CHECK: module @nested_module_not_need_dealloc_helper { +// CHECK: func.func @nested_module_not_need_dealloc_helper_func +// CHECK-NOT: @dealloc_helper +// CHECK: module @nested_module_need_dealloc_helper { +// CHECK: func.func @nested_module_need_dealloc_helper_func0 +// CHECK: call @dealloc_helper +// CHECK: func.func @nested_module_need_dealloc_helper_func1 +// CHECK: call @dealloc_helper +// CHECK: func.func private @dealloc_helper +// CHECK: func.func private @dealloc_helper