Skip to content

Commit 2342b2f

Browse files
committed
fixup! fixup! fixup! [mlir][vector] Restrict DropInnerMostUnitDimsTransferWrite
Extend to xfer_read
1 parent e2a35ec commit 2342b2f

File tree

2 files changed

+67
-9
lines changed

2 files changed

+67
-9
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,9 +1300,9 @@ class DropInnerMostUnitDimsTransferRead
13001300
if (dimsToDrop == 0)
13011301
return failure();
13021302

1303-
// Make sure that the indices to be dropped are equal 0.
1304-
// TODO: Deal with cases when the indices are not 0.
1305-
if (!llvm::all_of(readOp.getIndices().take_back(dimsToDrop), isZeroIndex))
1303+
auto inBounds = readOp.getInBoundsValues();
1304+
auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1305+
if (llvm::is_contained(droppedInBounds, false))
13061306
return failure();
13071307

13081308
auto resultTargetVecType =

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,70 @@ func.func @contiguous_inner_most_outer_dim_dyn_scalable_inner_dim(%a: index, %b:
113113

114114
// -----
115115

116-
// The index to be dropped is != 0 - this is currently not supported.
117-
func.func @negative_contiguous_inner_most_dim_non_zero_idxs(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
118-
%f0 = arith.constant 0.0 : f32
119-
%1 = vector.transfer_read %A[%i, %i], %f0 : memref<16x1xf32>, vector<8x1xf32>
116+
// Test the impact of changing the in_bounds attribute. The behaviour will
117+
// depend on whether the index is == 0 or != 0.
118+
119+
// The index to be dropped is == 0, so it's safe to collapse. The other index
120+
// should be preserved correctly.
121+
func.func @contiguous_inner_most_zero_idx_in_bounds(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
122+
%pad = arith.constant 0.0 : f32
123+
%c0 = arith.constant 0 : index
124+
%1 = vector.transfer_read %A[%i, %c0], %pad {in_bounds = [true, true]} : memref<16x1xf32>, vector<8x1xf32>
125+
return %1 : vector<8x1xf32>
126+
}
127+
// CHECK-LABEL: func.func @contiguous_inner_most_zero_idx_in_bounds(
128+
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
129+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> {
130+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
131+
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
132+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32>
133+
// CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32>
134+
135+
// The index to be dropped is == 0, so it's safe to collapse. The "out of
136+
// bounds" attribute is too conservative and will be folded to "in bounds"
137+
// before the pattern runs. The other index should be preserved correctly.
138+
func.func @contiguous_inner_most_zero_idx_out_of_bounds(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
139+
%pad = arith.constant 0.0 : f32
140+
%c0 = arith.constant 0 : index
141+
%1 = vector.transfer_read %A[%i, %c0], %pad {in_bounds = [true, false]} : memref<16x1xf32>, vector<8x1xf32>
142+
return %1 : vector<8x1xf32>
143+
}
144+
// CHECK-LABEL: func.func @contiguous_inner_most_zero_idx_out_of_bounds(
145+
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
146+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> {
147+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
148+
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
149+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32>
150+
// CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32>
151+
152+
// The index to be dropped is unknown, but since it's "in bounds", it has to be
153+
// == 0 (the dropped dim has size 1). It's safe to collapse the corresponding dim.
154+
func.func @contiguous_inner_most_non_zero_idx_in_bounds(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
155+
%pad = arith.constant 0.0 : f32
156+
%1 = vector.transfer_read %A[%i, %i], %pad {in_bounds = [true, true]} : memref<16x1xf32>, vector<8x1xf32>
157+
return %1 : vector<8x1xf32>
158+
}
159+
// CHECK-LABEL: func.func @contiguous_inner_most_non_zero_idx_in_bounds(
160+
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
161+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> {
162+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
163+
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
164+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32>
165+
// CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32>
166+
167+
// The index to be dropped is unknown and "out of bounds" - not safe to
168+
// collapse.
169+
func.func @negative_contiguous_inner_most_non_zero_idx_out_of_bounds(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
170+
%pad = arith.constant 0.0 : f32
171+
%1 = vector.transfer_read %A[%i, %i], %pad {in_bounds = [true, false]} : memref<16x1xf32>, vector<8x1xf32>
120172
return %1 : vector<8x1xf32>
121173
}
122-
// CHECK-LABEL: func @negative_contiguous_inner_most_dim_non_zero_idxs
174+
// CHECK-LABEL: func.func @negative_contiguous_inner_most_non_zero_idx_out_of_bounds(
123175
// CHECK-NOT: memref.subview
176+
// CHECK-NOT: vector.shape_cast
124177
// CHECK: vector.transfer_read
125178

179+
126180
// -----
127181

128182
func.func @contiguous_inner_most_dim_with_subview(%A: memref<1000x1xf32>, %i:index, %ii:index) -> (vector<4x1xf32>) {
@@ -354,6 +408,9 @@ func.func @contiguous_inner_most_zero_idx_in_bounds(%arg0: memref<16x1xf32>, %ar
354408
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
355409
// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>
356410

411+
// The index to be dropped is == 0, so it's safe to collapse. The "out of
412+
// bounds" attribute is too conservative and will be folded to "in bounds"
413+
// before the pattern runs. The other index should be preserved correctly.
357414
func.func @contiguous_inner_most_zero_idx_out_of_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
358415
%c0 = arith.constant 0 : index
359416
vector.transfer_write %arg1, %arg0[%i, %c0] {in_bounds = [true, false]} : vector<8x1xf32>, memref<16x1xf32>
@@ -369,7 +426,6 @@ func.func @contiguous_inner_most_zero_idx_out_of_bounds(%arg0: memref<16x1xf32>,
369426

370427
// The index to be dropped is unknown, but since it's "in bounds", it has to be
371428
// == 0 (the dropped dim has size 1). It's safe to collapse the corresponding dim.
372-
373429
func.func @contiguous_inner_most_dim_non_zero_idx_in_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
374430
vector.transfer_write %arg1, %arg0[%i, %i] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
375431
return
@@ -382,6 +438,8 @@ func.func @contiguous_inner_most_dim_non_zero_idx_in_bounds(%arg0: memref<16x1xf
382438
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
383439
// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>
384440

441+
// The index to be dropped is unknown and "out of bounds" - not safe to
442+
// collapse.
385443
func.func @negative_contiguous_inner_most_dim_non_zero_idx_out_of_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
386444
vector.transfer_write %arg1, %arg0[%i, %i] {in_bounds = [true, false]} : vector<8x1xf32>, memref<16x1xf32>
387445
return

0 commit comments

Comments
 (0)