diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index ea4a02f2f2e77..b2005e56b1617 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1225,11 +1225,19 @@ struct FoldI1Select : public OpRewritePattern { /// Returns the number of dims can be folded away from transfer ops. It returns /// a failure if it can not determine the number of dims to be folded. -/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and -/// `vectorType` is vector<16x16x1x1xf32>. Because there two inner most dims -/// can be dropped by memref.subview ops. -/// Example 2: it returns "1" if `srcType` is the same memref type with -/// [8192, 16, 8, 1] strides. +/// +/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and +/// `vectorType` is vector<16x16x1x1xf32> +/// (the two inner most dims can be dropped by memref.subview ops) +/// +/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with +/// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32> +/// (only the inner most unit dim of `srcType` can be dropped) +/// +/// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and +/// `vectorType` is vector<16x16x1x[1]xf32> +/// (the inner most dim in `vectorType` is not a unit dim (it's a "scalable +/// unit")) static FailureOr getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { SmallVector srcStrides; @@ -1351,6 +1359,8 @@ class DropInnerMostUnitDimsTransferRead /// vector.transfer_write %0, %subview[%c0, %arg2, %c0] /// {in_bounds = [true, true, true]} /// : vector<1x16x16xf32>, memref<1x512x16xf32> +/// +/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`). 
class DropInnerMostUnitDimsTransferWrite : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index bb37d5b45520c..5183205db1b47 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -41,27 +41,27 @@ func.func @contiguous_inner_most_scalable_inner_dim(%in: memref<1x1x8x1xf32, str // Same as the top example within this split, but the trailing unit dim was // replaced with a dyn dim - not supported -func.func @non_unit_trailing_dim(%in: memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ +func.func @negative_dynamic_trailing_dim(%in: memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ %c0 = arith.constant 0 : index %cst = arith.constant 0.0 : f32 %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32> return %0 : vector<1x8x1xf32> } -// CHECK-LABEL: func @non_unit_trailing_dim +// CHECK-LABEL: func @negative_dynamic_trailing_dim // CHECK-NOT: memref.subview // CHECK-NOT: vector.shape_cast -// Same as the top example within this split, but with a scalable unit dim in -// the output vector - not supported (scalable 1 is _not_ a unit dimension). +// Same as the top example within this split, but with a "scalable unit" dim in +// the output vector - not supported (scalable 1, [1], is _not_ a unit dimension). 
-func.func @negative_scalable_unit_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x[1]xf32>{ +func.func @negative_scalable_one_trailing_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x[1]xf32>{ %c0 = arith.constant 0 : index %cst = arith.constant 0.0 : f32 %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x[1]xf32> return %0 : vector<1x8x[1]xf32> } -// CHECK-LABEL: func @negative_scalable_unit_dim +// CHECK-LABEL: func @negative_scalable_one_trailing_dim // CHECK-NOT: memref.subview // CHECK-NOT: vector.shape_cast @@ -254,14 +254,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector< // 2. vector.transfer_write //----------------------------------------------------------------------------- -func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) { +func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true]} : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32> return } -// CHECK: func.func @drop_two_inner_most_dim_for_transfer_write +// CHECK: func.func @drop_two_inner_most_dim // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] @@ -272,16 +272,67 @@ func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1 // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]] // CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]] +// Same as the top example within this split, but with the inner vector +// dim scalable. Note that this example only makes sense when "16 = [16]" (i.e. +// vscale = 1). 
This is assumed (implicitly) via the `in_bounds` attribute. + +func.func @drop_two_inner_most_dim_scalable_inner_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x[16]x1x1xf32>, %arg2: index) { + %c0 = arith.constant 0 : index + vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] + {in_bounds = [true, true, true, true, true]} + : vector<1x16x[16]x1x1xf32>, memref<1x512x16x1x1xf32> + return +} +// CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]] +// CHECK-SAME: memref<1x512x16x1x1xf32> to memref<1x512x16xf32, strided<[8192, 16, 1]>> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x[16]x1x1xf32> to vector<1x16x[16]xf32> +// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]] +// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]] + +// Same as the top example within this split, but the trailing unit dim was +// replaced with a dyn dim - not supported + +func.func @negative_dynamic_trailing_dim(%arg0: memref<1x512x16x1x?xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) { + %c0 = arith.constant 0 : index + vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] + {in_bounds = [true, true, true, true, true]} + : vector<1x16x16x1x1xf32>, memref<1x512x16x1x?xf32> + return +} +// CHECK: func.func @negative_dynamic_trailing_dim +// CHECK-NOT: memref.subview +// CHECK-NOT: vector.shape_cast + +// Same as the top example within this split, but with a "scalable unit" dim in +// the input vector - not supported (scalable 1, [1], is _not_ a unit dimension). 
+ +func.func @negative_scalable_one_trailing_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x[1]xf32>, %arg2: index) { + %c0 = arith.constant 0 : index + vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] + {in_bounds = [true, true, true, true, true]} + : vector<1x16x16x1x[1]xf32>, memref<1x512x16x1x1xf32> + return +} + +// CHECK: func.func @negative_scalable_one_trailing_dim +// CHECK-NOT: memref.subview +// CHECK-NOT: vector.shape_cast + // ----- -func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { +func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> return } -// CHECK: func.func @drop_inner_most_dim_for_transfer_write +// CHECK: func.func @drop_inner_most_dim // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] @@ -294,14 +345,14 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, // ----- -func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { +func.func @outer_dyn_drop_inner_most_dim(%arg0: memref>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x16x16x1xf32>, memref> return } -// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write +// CHECK: func.func @outer_dyn_drop_inner_most_dim // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] @@ -325,30 +376,3 @@ func.func 
@non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o // The inner most unit dims can not be dropped if the strides are not ones. // CHECK: func.func @non_unit_strides // CHECK-NOT: memref.subview - -// ----- - -func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>) { - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32> - return -} -// CHECK: func.func @leading_scalable_dimension_transfer_write -// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>> -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32> -// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>> - -// ----- - -// Negative test: [1] (scalable 1) is _not_ a unit dimension. -func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) { - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>, memref<24x1xf32> - return -} -// CHECK: func.func @trailing_scalable_one_dim_transfer_write -// CHECK-NOT: vector.shape_cast -// CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32> -// CHECK-NOT: vector.shape_cast