Skip to content

Commit 8b1f69c

Browse files
committed
vector.step support
1 parent facf473 commit 8b1f69c

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5182,6 +5182,9 @@ static LogicalResult isContiguousIndices(Value indexVec) {
51825182
if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
51835183
return failure();
51845184

5185+
if (indexVec.getDefiningOp<StepOp>())
5186+
return success();
5187+
51855188
DenseIntElementsAttr elements;
51865189
if (!matchPattern(indexVec, m_Constant(&elements)))
51875190
return failure();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2874,6 +2874,22 @@ func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
28742874

28752875
// -----
28762876

2877+
// CHECK-LABEL: @contiguous_gather_step
2878+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
2879+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2880+
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2881+
// CHECK: return %[[R]]
2882+
func.func @contiguous_gather_step(%base: memref<?xf32>,
2883+
%mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
2884+
%c0 = arith.constant 0 : index
2885+
%indices = vector.step : vector<16xindex>
2886+
%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
2887+
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2888+
return %1 : vector<16xf32>
2889+
}
2890+
2891+
// -----
2892+
28772893
// CHECK-LABEL: @contiguous_scatter
28782894
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
28792895
// CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -2902,3 +2918,18 @@ func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
29022918
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
29032919
return
29042920
}
2921+
2922+
// -----
2923+
2924+
// CHECK-LABEL: @contiguous_scatter_step
2925+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
2926+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2927+
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
2928+
func.func @contiguous_scatter_step(%base: memref<?xf32>,
2929+
%mask: vector<16xi1>, %value: vector<16xf32>) {
2930+
%c0 = arith.constant 0 : index
2931+
%indices = vector.step : vector<16xindex>
2932+
vector.scatter %base[%c0][%indices], %mask, %value :
2933+
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
2934+
return
2935+
}

0 commit comments

Comments
 (0)