diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index f29eba90c3ceb..ea4a02f2f2e77 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1293,6 +1293,11 @@ class DropInnerMostUnitDimsTransferRead if (dimsToDrop == 0) return failure(); + // Make sure that the indices to be dropped are equal 0. + // TODO: Deal with cases when the indices are not 0. + if (!llvm::all_of(readOp.getIndices().take_back(dimsToDrop), isZeroIndex)) + return failure(); + auto resultTargetVecType = VectorType::get(targetType.getShape().drop_back(dimsToDrop), targetType.getElementType(), diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index a50c01898c62e..bb37d5b45520c 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -111,31 +111,41 @@ func.func @contiguous_inner_most_outer_dim_dyn_scalable_inner_dim(%a: index, %b: // ----- -func.func @contiguous_inner_most_dim_non_zero_idxs(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) { +func.func @contiguous_inner_most_dim_non_zero_idx(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) { %c0 = arith.constant 0 : index %f0 = arith.constant 0.0 : f32 - %1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<8x1xf32> + %1 = vector.transfer_read %A[%i, %c0], %f0 : memref<16x1xf32>, vector<8x1xf32> return %1 : vector<8x1xf32> } -// CHECK: func @contiguous_inner_most_dim_non_zero_idxs(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index, %[[J:.+]]: index) -> vector<8x1xf32> +// CHECK: func @contiguous_inner_most_dim_non_zero_idx(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index) -> vector<8x1xf32> // CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]] // CHECK-SAME: memref<16x1xf32> to memref<16xf32, strided<[1]>> // CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]] // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32> // CHECK: return %[[RESULT]] +// The index to be dropped is != 0 - this is currently not supported. +func.func @negative_contiguous_inner_most_dim_non_zero_idxs(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) { + %f0 = arith.constant 0.0 : f32 + %1 = vector.transfer_read %A[%i, %i], %f0 : memref<16x1xf32>, vector<8x1xf32> + return %1 : vector<8x1xf32> +} +// CHECK-LABEL: func @negative_contiguous_inner_most_dim_non_zero_idxs +// CHECK-NOT: memref.subview +// CHECK: vector.transfer_read + // Same as the top example within this split, but with the outer vector // dim scalable. Note that this example only makes sense when "8 = [8]" (i.e. // vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute. -func.func @contiguous_inner_most_dim_non_zero_idxs_scalable_inner_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<[8]x1xf32>) { +func.func @contiguous_inner_most_dim_non_zero_idx_scalable_inner_dim(%A: memref<16x1xf32>, %i:index) -> (vector<[8]x1xf32>) { %c0 = arith.constant 0 : index %f0 = arith.constant 0.0 : f32 - %1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<[8]x1xf32> + %1 = vector.transfer_read %A[%i, %c0], %f0 : memref<16x1xf32>, vector<[8]x1xf32> return %1 : vector<[8]x1xf32> } -// CHECK-LABEL: func @contiguous_inner_most_dim_non_zero_idxs_scalable_inner_dim( -// CHECK-SAME: %[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index, %[[J:.+]]: index) -> vector<[8]x1xf32> +// CHECK-LABEL: func @contiguous_inner_most_dim_non_zero_idx_scalable_inner_dim( +// CHECK-SAME: %[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index) -> vector<[8]x1xf32> // CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]] // CHECK-SAME: memref<16x1xf32> to memref<16xf32, strided<[1]>> // CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]]