Skip to content

Commit 358f5da

Browse files
[fixup] Add/change a few tests
1 parent 0761dcc commit 358f5da

File tree

1 file changed

+130
-67
lines changed

1 file changed

+130
-67
lines changed

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 130 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -70,41 +70,10 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070

7171
// -----
7272

73-
// The shape of the memref and the vector don't match, but the vector is a
74-
// contiguous subset of the memref, so "flattenable". The leading unit dimensions
75-
// of the vector have no effect on the memref area read even if they
76-
// span a non-contiguous part of the memref.
77-
78-
func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
79-
%mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
80-
81-
%c0 = arith.constant 0 : index
82-
%cst = arith.constant 0 : i8
83-
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
84-
memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
85-
return %res : vector<1x1x2x2xi8>
86-
}
87-
88-
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
89-
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
90-
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
91-
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
92-
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
93-
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
94-
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
95-
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
96-
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
97-
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
98-
99-
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
100-
// CHECK-128B: memref.collapse_shape
101-
102-
// -----
103-
10473
// The shape of the memref and the vector don't match, but the vector is a
10574
// contiguous subset of the memref, so "flattenable"
10675

107-
func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
76+
func.func @transfer_read_dims_mismatch_contiguous(
10877
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
10978

11079
%c0 = arith.constant 0 : index
@@ -114,7 +83,7 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
11483
return %res : vector<2x3x2xi8>
11584
}
11685

117-
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
86+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
11887
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
11988
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
12089
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -126,9 +95,73 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
12695
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
12796
// CHECK: return %[[VEC]] : vector<2x3x2xi8>
12897

129-
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
98+
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
99+
// CHECK-128B: memref.collapse_shape
100+
101+
// -----
102+
103+
// The shape of the memref and the vector don't match, but the mismatch is only
104+
// at the leading unit dimensions of the vector.
105+
106+
func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
107+
%mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>) -> vector<1x1x4x3x2xi8> {
108+
109+
%c0 = arith.constant 0 : index
110+
%cst = arith.constant 0 : i8
111+
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0, %c0], %cst :
112+
memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>, vector<1x1x4x3x2xi8>
113+
return %res : vector<1x1x4x3x2xi8>
114+
}
115+
116+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
117+
// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
118+
// CHECK-SAME: -> vector<1x1x4x3x2xi8>
119+
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
120+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
121+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
122+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
123+
// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
124+
// CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
125+
// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
126+
// CHECK-SAME: {in_bounds = [true]} : memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>, vector<24xi8>
127+
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<24xi8> to vector<1x1x4x3x2xi8>
128+
// CHECK: return %[[VEC]] : vector<1x1x4x3x2xi8>
129+
130+
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
131+
// CHECK-128B: memref.collapse_shape
132+
133+
// -----
134+
135+
// The memref is non-contiguous, but the vector is a contiguous subset of the
136+
// memref, so "flattenable". The leading unit dimensions of the vector have no
137+
// effect on the memref area read even if they span the non-contiguous part of
138+
// the memref.
139+
140+
func.func @transfer_read_non_contiguous_unit_dims(
141+
%mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
142+
143+
%c0 = arith.constant 0 : index
144+
%cst = arith.constant 0 : i8
145+
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
146+
memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x3x2xi8>
147+
return %res : vector<1x1x3x2xi8>
148+
}
149+
150+
// CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
151+
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
152+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
153+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
154+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
155+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
156+
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
157+
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<6xi8>
158+
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<6xi8> to vector<1x1x3x2xi8>
159+
// CHECK: return %[[VAL_5]] : vector<1x1x3x2xi8>
160+
161+
// CHECK-128B-LABEL: func @transfer_read_non_contiguous_unit_dims(
130162
// CHECK-128B: memref.collapse_shape
131163

164+
132165
// -----
133166

134167
func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -414,61 +447,92 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
414447
// -----
415448

416449
// The shape of the memref and the vector don't match, but the vector is a
417-
// contiguous subset of the memref, so "flattenable". The leading unit dimensions
418-
// of the vector have no effect on the memref area written even if they
419-
// span a non-contiguous part of the memref.
450+
// contiguous subset of the memref, so "flattenable".
420451

421-
func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
422-
%mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
423-
%vec : vector<1x1x2x2xi8>) {
452+
func.func @transfer_write_dims_mismatch_contiguous(
453+
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
454+
%vec : vector<2x2xi8>) {
424455

425456
%c0 = arith.constant 0 : index
426457
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
427-
vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
458+
vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
428459
return
429460
}
430461

431-
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
432-
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
433-
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
434-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
435-
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
462+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
463+
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
464+
// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
465+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
466+
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
436467
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
437-
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
438-
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
439-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
440-
// CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
468+
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
469+
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
470+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
471+
// CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
472+
473+
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
474+
// CHECK-128B: memref.collapse_shape
475+
476+
// -----
477+
478+
// The shape of the memref and the vector don't match, but the mismatch is only
479+
// at the leading unit dimensions of the vector.
480+
481+
func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
482+
%mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>,
483+
%vec : vector<1x1x4x3x2xi8>) {
484+
485+
%c0 = arith.constant 0 : index
486+
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0, %c0] :
487+
vector<1x1x4x3x2xi8>, memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
488+
489+
return
490+
}
491+
492+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
493+
// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
494+
// CHECK-SAME: %[[VEC:.+]]: vector<1x1x4x3x2xi8>
495+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
496+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
497+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
498+
// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
499+
// CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
500+
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x4x3x2xi8> to vector<24xi8>
501+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
502+
// CHECK-SAME: {in_bounds = [true]} : vector<24xi8>, memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
441503

442504
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
443505
// CHECK-128B: memref.collapse_shape
444506

445507
// -----
446508

447-
// The shape of the memref and the vector don't match, but the vector is a
448-
// contiguous subset of the memref, so "flattenable".
509+
// The memref is non-contiguous, but the vector is a contiguous subset of the
510+
// memref, so "flattenable". The leading unit dimensions of the vector have no
511+
// effect on the memref area read even if they span the non-contiguous part of
512+
// the memref.
449513

450-
func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
451-
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
452-
%vec : vector<2x2xi8>) {
514+
func.func @transfer_write_non_contiguous_unit_dims(
515+
%mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
516+
%vec : vector<1x1x3x2xi8>) {
453517

454518
%c0 = arith.constant 0 : index
455519
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
456-
vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
520+
vector<1x1x3x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
457521
return
458522
}
459523

460-
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
461-
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
462-
// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
463-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
464-
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
524+
// CHECK-LABEL: func.func @transfer_write_non_contiguous_unit_dims
525+
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
526+
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x3x2xi8>) {
527+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
528+
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
465529
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
466-
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
467-
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
468-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
469-
// CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
530+
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
531+
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x3x2xi8> to vector<6xi8>
532+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
533+
// CHECK-SAME: {in_bounds = [true]} : vector<6xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
470534

471-
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
535+
// CHECK-128B-LABEL: func @transfer_write_non_contiguous_unit_dims(
472536
// CHECK-128B: memref.collapse_shape
473537

474538
// -----
@@ -714,4 +778,3 @@ func.func @negative_out_of_bound_transfer_write(
714778
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
715779
// CHECK-128B-NOT: memref.collapse_shape
716780
// CHECK-128B-NOT: vector.shape_cast
717-

0 commit comments

Comments
 (0)