@@ -70,41 +70,10 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
// -----
- // The shape of the memref and the vector don't match, but the vector is a
- // contiguous subset of the memref, so "flattenable". The leading unit dimensions
- // of the vector have no effect on the memref area read even if they
- // span a non-contiguous part of the memref.
-
- func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
- %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
- return %res : vector<1x1x2x2xi8>
- }
-
- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
- // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
- // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
- // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
- // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
- // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
- // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
- // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
-
- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
- // CHECK-128B: memref.collapse_shape
-
- // -----
-
// The shape of the memref and the vector don't match, but the vector is a
// contiguous subset of the memref, so "flattenable"

- func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+ func.func @transfer_read_dims_mismatch_contiguous(
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {

%c0 = arith.constant 0 : index
@@ -114,7 +83,7 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
return %res : vector<2x3x2xi8>
}

- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -126,9 +95,73 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
// CHECK: return %[[VEC]] : vector<2x3x2xi8>

- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+ // CHECK-128B: memref.collapse_shape
+
+ // -----
+
+ // The shape of the memref and the vector don't match, but the mismatch is only
+ // at the leading unit dimensions of the vector.
+
+ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
+ %mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>) -> vector<1x1x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0, %c0], %cst :
+ memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>, vector<1x1x4x3x2xi8>
+ return %res : vector<1x1x4x3x2xi8>
+ }
+
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
+ // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
+ // CHECK-SAME: -> vector<1x1x4x3x2xi8>
+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
+ // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+ // CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+ // CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
+ // CHECK-SAME: {in_bounds = [true]} : memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>, vector<24xi8>
+ // CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<24xi8> to vector<1x1x4x3x2xi8>
+ // CHECK: return %[[VEC]] : vector<1x1x4x3x2xi8>
+
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
+ // CHECK-128B: memref.collapse_shape
+
+ // -----
+
+ // The memref is non-contiguous, but the vector is a contiguous subset of the
+ // memref, so "flattenable". The leading unit dimensions of the vector have no
+ // effect on the memref area read even if they span the non-contiguous part of
+ // the memref.
+
+ func.func @transfer_read_non_contiguous_unit_dims(
+ %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x3x2xi8>
+ return %res : vector<1x1x3x2xi8>
+ }
+
+ // CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
+ // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
+ // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+ // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<6xi8>
+ // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<6xi8> to vector<1x1x3x2xi8>
+ // CHECK: return %[[VAL_5]] : vector<1x1x3x2xi8>
+
+ // CHECK-128B-LABEL: func @transfer_read_non_contiguous_unit_dims(
// CHECK-128B: memref.collapse_shape
+
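For context, the contiguity arithmetic these tests rely on can be sketched as follows. This is a plain-Python illustration only, not the MLIR implementation; the helper name `trailing_dims_contiguous` is made up for the example.

# The trailing dims of a strided memref can be collapsed iff the innermost
# stride is 1 and every stride equals the next-inner stride times that dim's size.
def trailing_dims_contiguous(shape, strides, n):
    if strides[-1] != 1:
        return False
    for i in range(len(shape) - n, len(shape) - 1):
        if strides[i] != strides[i + 1] * shape[i + 1]:
            return False
    return True

# memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>>: fully contiguous, so a
# vector<2x3x2xi8> transfer collapses to one vector<12xi8> transfer.
assert trailing_dims_contiguous([5, 4, 3, 2], [24, 6, 2, 1], 4)

# memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>>: 48 != 4 * 6, so the memref as a
# whole is non-contiguous, but the trailing 3x2 block (strides [2, 1]) still is;
# a vector<1x1x3x2xi8> transfer therefore flattens to vector<6xi8>.
assert not trailing_dims_contiguous([5, 4, 3, 2], [48, 6, 2, 1], 4)
assert trailing_dims_contiguous([5, 4, 3, 2], [48, 6, 2, 1], 2)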
// -----

func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -414,61 +447,92 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
// -----
// The shape of the memref and the vector don't match, but the vector is a
- // contiguous subset of the memref, so "flattenable". The leading unit dimensions
- // of the vector have no effect on the memref area written even if they
- // span a non-contiguous part of the memref.
+ // contiguous subset of the memref, so "flattenable".

- func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
- %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
- %vec : vector<1x1x2x2xi8>) {
+ func.func @transfer_write_dims_mismatch_contiguous(
+ %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+ %vec : vector<2x2xi8>) {

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
- vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
+ vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
return
}

- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
- // CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+ // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+ // CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
- // CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
- // CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+ // CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
+
+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+ // CHECK-128B: memref.collapse_shape
+
+ // -----
+
+ // The shape of the memref and the vector don't match, but the mismatch is only
+ // at the leading unit dimensions of the vector.
+
+ func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
+ %mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>,
+ %vec : vector<1x1x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0, %c0] :
+ vector<1x1x4x3x2xi8>, memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+
+ return
+ }
+
+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
+ // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+ // CHECK-SAME: %[[VEC:.+]]: vector<1x1x4x3x2xi8>
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
+ // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+ // CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x4x3x2xi8> to vector<24xi8>
+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+ // CHECK-SAME: {in_bounds = [true]} : vector<24xi8>, memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
// CHECK-128B: memref.collapse_shape
// -----
- // The shape of the memref and the vector don't match, but the vector is a
- // contiguous subset of the memref, so "flattenable".
+ // The memref is non-contiguous, but the vector is a contiguous subset of the
+ // memref, so "flattenable". The leading unit dimensions of the vector have no
+ // effect on the memref area written even if they span the non-contiguous part of
+ // the memref.

- func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
- %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
- %vec : vector<2x2xi8>) {
+ func.func @transfer_write_non_contiguous_unit_dims(
+ %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
+ %vec : vector<1x1x3x2xi8>) {

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
- vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ vector<1x1x3x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
return
}

- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
- // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
- // CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
- // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+ // CHECK-LABEL: func.func @transfer_write_non_contiguous_unit_dims
+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
+ // CHECK-SAME: %[[VEC:.*]]: vector<1x1x3x2xi8>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
- // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
- // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
- // CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+ // CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x3x2xi8> to vector<6xi8>
+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+ // CHECK-SAME: {in_bounds = [true]} : vector<6xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>

- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
+ // CHECK-128B-LABEL: func @transfer_write_non_contiguous_unit_dims(
// CHECK-128B: memref.collapse_shape
// -----
@@ -714,4 +778,3 @@ func.func @negative_out_of_bound_transfer_write(
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast
-