diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..b949b06631484 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -199,6 +199,24 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
   return true;
 }
 
+static bool skipViewLike(Operation *source0, Operation *source1) {
+  bool viewLikeCheck = true;
+  auto assumeAlignOp = dyn_cast_or_null<memref::AssumeAlignmentOp>(source0);
+  if (assumeAlignOp && source0 == source1) {
+    Value sourceMemRef = assumeAlignOp.getMemref();
+    Operation *sourceOp = sourceMemRef.getDefiningOp();
+    return isa_and_nonnull<ViewLikeOpInterface>(sourceOp);
+  }
+
+  if (source0 && isa_and_nonnull<ViewLikeOpInterface>(source0))
+    return true;
+
+  if (source1 && isa_and_nonnull<ViewLikeOpInterface>(source1))
+    return true;
+
+  return false;
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
                                                 bool verifyNonZeroTrip) {
   bool changed = true;
@@ -312,12 +330,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
         transferRead.getPermutationMap() != transferWrite.getPermutationMap())
       return WalkResult::advance();
 
-    auto *source = transferRead.getBase().getDefiningOp();
-    if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
-      return WalkResult::advance();
+    auto *source0 = transferRead.getBase().getDefiningOp();
+    auto *source1 = transferWrite.getBase().getDefiningOp();
 
-    source = transferWrite.getBase().getDefiningOp();
-    if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
+    if (skipViewLike(source0, source1))
       return WalkResult::advance();
 
     // TODO: may want to memoize this information for performance but it
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..c58074e40c5f4 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -802,3 +802,55 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Test hoisting of 
vector.transfer_read/transfer_write pairs with same location
+// and this location is marked with assume_align.
+
+// CHECK-LABEL: func.func @hoist_vector_transfer_read_write() {
+// CHECK: %c0 = arith.constant 0 : index
+// CHECK-NEXT: %c256 = arith.constant 256 : index
+// CHECK-NEXT: %c4096 = arith.constant 4096 : index
+// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f16
+// CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
+// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT: %1 = scf.for %arg0 = %c256 to %c4096 step %c256 iter_args(%arg1 = %0) -> (vector<16x16xf16>) {
+// CHECK-NEXT: %2 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT: %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %arg1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// CHECK-NEXT: scf.yield %3 : vector<16x16xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: vector.transfer_write %1, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+func.func @hoist_vector_transfer_read_write() {
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %c4096 = arith.constant 4096 : index
+  %cst_0 = arith.constant 0.000000e+00 : f16
+  %m0 = memref.alloc() : memref<4096x4096xf16>
+  %m1 = memref.alloc() : memref<4096x4096xf16>
+  %assume_align_0 = memref.assume_alignment %m0, 64 : memref<4096x4096xf16>
+  %assume_align_1 = memref.assume_alignment %m1, 64 : memref<4096x4096xf16>
+  scf.for %arg0 = %c256 to %c4096 step 
%c256 {
+    %1 = vector.transfer_read %assume_align_0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+    %2 = vector.transfer_read %m1[%arg0, %arg0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+    %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+    vector.transfer_write %3, %assume_align_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}