Skip to content

Commit eddfd30

Browse files
committed
[MLIR][Vector] Add Lowering for vector.step
Currently, the lowering for vector.step lives under a folder. This is not ideal if we want to do transformation on it and defer the materizaliztion of the constants much later. This commits adds a rewrite pattern + transform op to do this instead. Thus enabling more control on the lowering.
1 parent e19a5fc commit eddfd30

File tree

13 files changed

+163
-30
lines changed

13 files changed

+163
-30
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2946,7 +2946,6 @@ def Vector_StepOp : Vector_Op<"step", [Pure]> {
29462946
%1 = vector.step : vector<[4]xindex> ; [0, 1, .., <vscale * 4 - 1>]
29472947
```
29482948
}];
2949-
let hasFolder = 1;
29502949
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
29512950
let assemblyFormat = "attr-dict `:` type($result)";
29522951
}

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,4 +453,13 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
453453
let assemblyFormat = "attr-dict";
454454
}
455455

456+
def ApplyLowerStepToArithConstantOp : Op<Transform_Dialect,
457+
"apply_patterns.vector.step_to_arith_constant",
458+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
459+
let description = [{
460+
Convert vector.step to arith if not using scalable vectors.
461+
}];
462+
let assemblyFormat = "attr-dict";
463+
}
464+
456465
#endif // VECTOR_TRANSFORM_OPS

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ void populateVectorTransferPermutationMapLoweringPatterns(
235235
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
236236
PatternBenefit benefit = 1);
237237

238+
/// Populate the pattern set with the following patterns:
239+
///
240+
/// [StepToArithConstantOp]
241+
/// Convert vector.step op into arith ops if not using scalable vectors
242+
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
243+
PatternBenefit benefit = 1);
244+
238245
/// Populate the pattern set with the following patterns:
239246
///
240247
/// [FlattenGather]

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1886,6 +1886,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
18861886
MLIRContext *ctx = converter.getDialect()->getContext();
18871887
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
18881888
populateVectorInsertExtractStridedSliceTransforms(patterns);
1889+
populateVectorStepLoweringPatterns(patterns);
18891890
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
18901891
patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
18911892
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,

mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Dialect/SCF/IR/SCF.h"
2828
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
2929
#include "mlir/Dialect/Vector/IR/VectorOps.h"
30+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
3031
#include "mlir/IR/Matchers.h"
3132

3233
using namespace mlir;
@@ -664,6 +665,7 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
664665
bool enableVLAVectorization,
665666
bool enableSIMDIndex32) {
666667
assert(vectorLength > 0);
668+
vector::populateVectorStepLoweringPatterns(patterns);
667669
patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
668670
enableVLAVectorization, enableSIMDIndex32);
669671
patterns.add<ReducChainRewriter<vector::InsertElementOp>,

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6423,20 +6423,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
64236423
return SplatElementsAttr::get(getType(), {constOperand});
64246424
}
64256425

6426-
//===----------------------------------------------------------------------===//
6427-
// StepOp
6428-
//===----------------------------------------------------------------------===//
6429-
6430-
OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
6431-
auto resultType = cast<VectorType>(getType());
6432-
if (resultType.isScalable())
6433-
return nullptr;
6434-
SmallVector<APInt> indices;
6435-
for (unsigned i = 0; i < resultType.getNumElements(); i++)
6436-
indices.push_back(APInt(/*width=*/64, i));
6437-
return DenseElementsAttr::get(resultType, indices);
6438-
}
6439-
64406426
//===----------------------------------------------------------------------===//
64416427
// WarpExecuteOnLane0Op
64426428
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
207207
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
208208
}
209209

210+
void transform::ApplyLowerStepToArithConstantOp::populatePatterns(
211+
RewritePatternSet &patterns) {
212+
populateVectorStepLoweringPatterns(patterns);
213+
}
214+
210215
//===----------------------------------------------------------------------===//
211216
// Transform op registration
212217
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
99
LowerVectorMultiReduction.cpp
1010
LowerVectorScan.cpp
1111
LowerVectorShapeCast.cpp
12+
LowerVectorStep.cpp
1213
LowerVectorTransfer.cpp
1314
LowerVectorTranspose.cpp
1415
SubsetOpInterfaceImpl.cpp
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.step' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
16+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
#define DEBUG_TYPE "vector-step-lowering"
20+
21+
using namespace mlir;
22+
using namespace mlir::vector;
23+
24+
namespace {
25+
26+
struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
27+
using OpRewritePattern::OpRewritePattern;
28+
29+
LogicalResult matchAndRewrite(vector::StepOp stepOp,
30+
PatternRewriter &rewriter) const override {
31+
auto resultType = cast<VectorType>(stepOp.getType());
32+
if (resultType.isScalable()) {
33+
return failure();
34+
}
35+
int64_t elementCount = resultType.getNumElements();
36+
SmallVector<APInt> indices =
37+
llvm::map_to_vector(llvm::seq(elementCount),
38+
[](int64_t i) { return APInt(/*width=*/64, i); });
39+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
40+
stepOp, DenseElementsAttr::get(resultType, indices));
41+
return success();
42+
}
43+
};
44+
} // namespace
45+
46+
void mlir::vector::populateVectorStepLoweringPatterns(
47+
RewritePatternSet &patterns, PatternBenefit benefit) {
48+
patterns.add<StepToArithConstantOpRewrite>(patterns.getContext(), benefit);
49+
}

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ module attributes {transform.with_named_sequence} {
180180
%func = transform.structured.match ops{["func.func"]} in %arg1
181181
: (!transform.any_op) -> !transform.any_op
182182
transform.apply_patterns to %func {
183+
transform.apply_patterns.vector.step_to_arith_constant
183184
transform.apply_patterns.canonicalization
184185
transform.apply_patterns.linalg.tiling_canonicalization
185186
} : !transform.any_op

0 commit comments

Comments
 (0)