@@ -2874,6 +2874,22 @@ func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
2874
2874
2875
2875
// -----
2876
2876
2877
+ // CHECK-LABEL: @contiguous_gather_step
2878
+ // CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
2879
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
2880
+ // CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2881
+ // CHECK: return %[[R]]
2882
+ func.func @contiguous_gather_step (%base: memref <?xf32 >, // Test input: a gather whose index vector comes from vector.step; the FileCheck expectations above require it to become one vector.maskedload.
2883
+ %mask: vector <16 xi1 >, %passthru: vector <16 xf32 >) -> vector <16 xf32 > {
2884
+ %c0 = arith.constant 0 : index // Base index into %base; the expected maskedload is also indexed by a constant 0.
2885
+ %indices = vector.step : vector <16 xindex > // Per-lane index vector; presumably 0, 1, ..., 15 per the vector.step definition (confirm against the dialect docs).
2886
+ %1 = vector.gather %base [%c0 ][%indices ], %mask , %passthru : // Gather under test; %base, %mask and %passthru are expected to be forwarded unchanged to the maskedload.
2887
+ memref <?xf32 >, vector <16 xindex >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
2888
+ return %1 : vector <16 xf32 > // The final expectation matches this return against the maskedload result.
2889
+ }
2890
+
2891
+ // -----
2892
+
2877
2893
// CHECK-LABEL: @contiguous_scatter
2878
2894
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
2879
2895
// CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -2902,3 +2918,18 @@ func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
2902
2918
memref <?xf32 >, vector <16 xi32 >, vector <16 xi1 >, vector <16 xf32 >
2903
2919
return
2904
2920
}
2921
+
2922
+ // -----
2923
+
2924
+ // CHECK-LABEL: @contiguous_scatter_step
2925
+ // CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
2926
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
2927
+ // CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
2928
+ func.func @contiguous_scatter_step (%base: memref <?xf32 >, // Test input: a scatter whose index vector comes from vector.step; the FileCheck expectations above require it to become one vector.maskedstore.
2929
+ %mask: vector <16 xi1 >, %value: vector <16 xf32 >) {
2930
+ %c0 = arith.constant 0 : index // Base index into %base; the expected maskedstore is also indexed by a constant 0.
2931
+ %indices = vector.step : vector <16 xindex > // Per-lane index vector; presumably 0, 1, ..., 15 per the vector.step definition (confirm against the dialect docs).
2932
+ vector.scatter %base [%c0 ][%indices ], %mask , %value : // Scatter under test; %base, %mask and %value are expected to be forwarded unchanged to the maskedstore.
2933
+ memref <?xf32 >, vector <16 xindex >, vector <16 xi1 >, vector <16 xf32 >
2934
+ return
2935
+ }
0 commit comments