Skip to content

Commit 71c2b3c

Browse files
authored
[CPU] Tile all the ops to target vector sizes before vectorization. (iree-org#21900)
The revision introduces a pass that iterates over all the compute ops and tiles them to the target vector sizes if any dimension is not yet tiled to that size. It uses the ValueBounds analysis to infer the tiling sizes. If a size is not inferable, the pass assumes that the op is already within the target vector size, because this usually implies that the op is fused with some tiling config and the size computation is too complicated to analyze; e.g., this can happen with linalg.unpack ops. The revision prevents failures caused by huge vectors and provides a reasonable fallback. The `linalg.fill` op is excluded because it usually goes with the corresponding reduction op and there may be issues in lowering config propagation. It is a fair stopgap in practice. --------- Signed-off-by: hanhanW <[email protected]>
1 parent 5365054 commit 71c2b3c

File tree

8 files changed

+296
-0
lines changed

8 files changed

+296
-0
lines changed

compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ iree_compiler_cc_library(
7171
"LLVMCPUSynchronizeSymbolVisibility.cpp",
7272
"LLVMCPUTile.cpp",
7373
"LLVMCPUTileAndFuseProducerConsumer.cpp",
74+
"LLVMCPUTileToVectorSize.cpp",
7475
"LLVMCPUUnfuseFMAOps.cpp",
7576
"LLVMCPUVectorShapeCastLowering.cpp",
7677
"LLVMCPUVectorTransposeLowering.cpp",

compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ iree_cc_library(
6767
"LLVMCPUSynchronizeSymbolVisibility.cpp"
6868
"LLVMCPUTile.cpp"
6969
"LLVMCPUTileAndFuseProducerConsumer.cpp"
70+
"LLVMCPUTileToVectorSize.cpp"
7071
"LLVMCPUUnfuseFMAOps.cpp"
7172
"LLVMCPUVectorShapeCastLowering.cpp"
7273
"LLVMCPUVectorTransposeLowering.cpp"
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
8+
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
9+
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
10+
#include "llvm/Support/DebugLog.h"
11+
#include "llvm/Support/InterleavedRange.h"
12+
#include "llvm/Support/LogicalResult.h"
13+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14+
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
17+
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
18+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
19+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
20+
#include "mlir/IR/BuiltinTypeInterfaces.h"
21+
#include "mlir/IR/Iterators.h"
22+
#include "mlir/Interfaces/TilingInterface.h"
23+
#include "mlir/Pass/Pass.h"
24+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25+
26+
#define DEBUG_TYPE "iree-llvmcpu-tile-to-vector-size"
27+
28+
namespace mlir::iree_compiler {
29+
30+
#define GEN_PASS_DEF_LLVMCPUTILETOVECTORSIZEPASS
31+
#include "iree/compiler/Codegen/LLVMCPU/Passes.h.inc"
32+
33+
namespace {
34+
35+
/// Pass registered as "iree-llvmcpu-tile-to-vector-size". It tiles the linalg
/// ops of a function to the vector sizes given by their lowering config; the
/// actual work is done in `runOnOperation`, defined out of line.
struct LLVMCPUTileToVectorSizePass final
36+
: impl::LLVMCPUTileToVectorSizePassBase<LLVMCPUTileToVectorSizePass> {
37+
// Inherit the constructors of the generated pass base class.
using impl::LLVMCPUTileToVectorSizePassBase<
38+
LLVMCPUTileToVectorSizePass>::LLVMCPUTileToVectorSizePassBase;
39+
40+
// Declares the arith and scf dialects as dependencies of this pass.
void getDependentDialects(DialectRegistry &registry) const override {
41+
registry.insert<arith::ArithDialect, scf::SCFDialect>();
42+
}
43+
44+
// Entry point of the pass; see the out-of-line definition.
void runOnOperation() override;
45+
};
46+
47+
/// Computes, for every loop dimension of `op`, the tile size needed to bring
/// that dimension down to the vector size configured in the op's lowering
/// config.
///
/// Returns std::nullopt when:
///   - the op carries no lowering config,
///   - any vector dimension is marked scalable,
///   - the configured vector sizes are missing or their count differs from the
///     number of loops of `op`, or
///   - a loop dimension does not map to any operand dimension.
///
/// Otherwise returns one entry per loop dimension. An entry holds the
/// configured vector size when the dimension is (or may be) larger than that
/// size, and stays 0 when the dimension is already within the vector size or
/// its size cannot be inferred.
static std::optional<SmallVector<int64_t>>
getTileSizesForEachDims(linalg::LinalgOp op) {
  IREE::Codegen::LoweringConfigAttrInterface loweringConfig =
      getLoweringConfig(op);
  // Do not dereference a null config; without one there is nothing to tile to.
  if (!loweringConfig) {
    return std::nullopt;
  }
  SmallVector<bool> scalableFlags = loweringConfig.getVectorScalableFlags();
  if (llvm::is_contained(scalableFlags, true)) {
    return std::nullopt;
  }

  unsigned numLoops = op.getNumLoops();
  std::optional<SmallVector<int64_t>> vectorSizes =
      loweringConfig.getVectorSizes();
  if (!vectorSizes || vectorSizes->size() != numLoops) {
    return std::nullopt;
  }
  LDBG() << "configured vector sizes: "
         << llvm::interleaved_array(vectorSizes.value());

  SmallVector<int64_t> result(numLoops, 0);
  for (unsigned dim = 0; dim < numLoops; ++dim) {
    SmallVector<std::pair<Value, unsigned>> operandDimPairs;
    op.mapIterationSpaceDimToAllOperandDims(dim, operandDimPairs);
    if (operandDimPairs.empty()) {
      return std::nullopt;
    }

    Value firstOperand = operandDimPairs[0].first;
    unsigned firstOperandDim = operandDimPairs[0].second;

    // Trivial case: `dim` size is available in the operand type. A static
    // size fully decides whether tiling is needed, so no further analysis is
    // required in either direction.
    int64_t dimSize = llvm::cast<ShapedType>(firstOperand.getType())
                          .getShape()[firstOperandDim];
    int64_t vectorDimSize = vectorSizes.value()[dim];
    if (ShapedType::isStatic(dimSize)) {
      if (dimSize > vectorDimSize) {
        LDBG() << "set dim #" << dim << " size (" << dimSize
               << ") with vector size: " << vectorDimSize;
        result[dim] = vectorDimSize;
      } else {
        LDBG() << "dim #" << dim << " size (" << dimSize
               << ") is already within vector size: " << vectorDimSize;
      }
      continue;
    }

    // `tensor.extract_slice` and `tensor.empty` are tiling artifacts that can
    // be used to infer tile sizes. If the operand is produced by neither, it
    // is not tiled at all, which implies that the dimension is not yet tiled.
    if (!isa_and_present<tensor::EmptyOp, tensor::ExtractSliceOp>(
            firstOperand.getDefiningOp())) {
      LDBG() << "set dim #" << dim
             << " size (untiled) with vector size: " << vectorDimSize;
      result[dim] = vectorDimSize;
      continue;
    }

    // Use ValueBounds analysis to infer `dim` size upper bound. The first
    // operand dimension for which a constant bound can be computed wins.
    std::optional<int64_t> maybeDimSize;
    for (auto [operand, operandDim] : operandDimPairs) {
      FailureOr<int64_t> maybeDimBoundSize =
          ValueBoundsConstraintSet::computeConstantBound(
              presburger::BoundType::UB, {operand, operandDim},
              /*stopCondition=*/nullptr, /*closedUB=*/true);
      if (succeeded(maybeDimBoundSize)) {
        maybeDimSize = maybeDimBoundSize.value();
        break;
      }
    }
    // If no constant upper bound can be inferred, assume the dimension is
    // already tiled: the operand is a tiling artifact, but its tile size is
    // too hard to infer. This usually happens in fusion cases, so the pass
    // assumes that no additional tiling is needed.
    if (maybeDimSize && maybeDimSize.value() > vectorDimSize) {
      LDBG() << "set dim #" << dim << " size (" << maybeDimSize.value()
             << ") with vector size: " << vectorDimSize;
      result[dim] = vectorDimSize;
    } else {
      LDBG() << "dim #" << dim << " either is tiled to vector size ("
             << vectorDimSize << ") or has complex size computation";
    }
  }

  return result;
}
127+
128+
void LLVMCPUTileToVectorSizePass::runOnOperation() {
129+
MLIRContext *context = &getContext();
130+
FunctionOpInterface funcOp = getOperation();
131+
SmallVector<linalg::LinalgOp> candidates;
132+
funcOp.walk([&](linalg::LinalgOp op) {
133+
// XXX(hanchung): linalg.fill usually follow the reduction consumer ops, so
134+
// the additional tiling is not needed. Otherwise, it results in an
135+
// additional loops before converting it to a vector. We may need to fix the
136+
// lowering config issue, but it is a fair stopgap in practice.
137+
if (isa<linalg::FillOp>(op)) {
138+
return;
139+
}
140+
IREE::Codegen::LoweringConfigAttrInterface loweringConfig =
141+
getLoweringConfig(op);
142+
if (!loweringConfig) {
143+
return;
144+
}
145+
if (!loweringConfig.getVectorSizes().has_value()) {
146+
return;
147+
}
148+
candidates.push_back(op);
149+
});
150+
151+
IRRewriter rewriter(context);
152+
for (linalg::LinalgOp op : candidates) {
153+
LDBG() << "candidate: " << op;
154+
std::optional<SmallVector<int64_t>> tileSizes = getTileSizesForEachDims(op);
155+
if (!tileSizes) {
156+
LDBG() << "all the dimensions are either tiled or target scalable tile "
157+
"sizes";
158+
continue;
159+
}
160+
if (llvm::all_of(tileSizes.value(), [](int64_t val) { return val == 0; })) {
161+
LDBG() << "skip the op because tile sizes are all zeros";
162+
continue;
163+
}
164+
LDBG() << "tileSizes: " << llvm::interleaved_array(tileSizes.value());
165+
166+
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
167+
scf::SCFTilingOptions options;
168+
setSCFTileSizes(options, tilingInterfaceOp, std::move(tileSizes.value()),
169+
/*tileScalableFlags=*/{});
170+
FailureOr<scf::SCFTilingResult> tiledResults =
171+
scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
172+
if (failed(tiledResults)) {
173+
LDBG() << "failed to tile the op";
174+
return signalPassFailure();
175+
}
176+
rewriter.replaceOp(op, tiledResults->replacements);
177+
}
178+
179+
RewritePatternSet patterns =
180+
linalg::getLinalgTilingCanonicalizationPatterns(context);
181+
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
182+
tensor::populateFoldTensorEmptyPatterns(patterns);
183+
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
184+
context->getLoadedDialect<tensor::TensorDialect>()
185+
->getCanonicalizationPatterns(patterns);
186+
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
187+
LDBG() << "----- cleanup failed -----";
188+
return signalPassFailure();
189+
}
190+
}
191+
} // namespace
192+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ void addMultiTilingExpertPassPipeline(
293293
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
294294
funcPassManager.addPass(createCSEPass());
295295
}
296+
funcPassManager.addPass(createLLVMCPUTileToVectorSizePass());
296297

297298
GenericVectorizationPassOptions options;
298299
options.useConfiguredVectorSizes = pipelineOpt.useConfiguredVectorSizes;

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,19 @@ def LLVMCPUTilePass :
144144
];
145145
}
146146

147+
def LLVMCPUTileToVectorSizePass :
148+
InterfacePass<"iree-llvmcpu-tile-to-vector-size", "mlir::FunctionOpInterface"> {
149+
let summary = "Tile TilingInterface operations to target vector size.";
150+
let description = [{
151+
Walk through all the TilingInterface operations and tile the dimensions to
152+
target vector sizes, if the lowering config is present and the dimension is
153+
known to be greater than the vector size.
154+
155+
It is intended to be used before vectorization that avoids big vectors and
156+
stack buffers.
157+
}];
158+
}
159+
147160
def LLVMCPUTileAndFuseProducerConsumerPass
148161
: InterfacePass<"iree-llvmcpu-tile-and-fuse-producer-consumer",
149162
"mlir::FunctionOpInterface"> {

compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ iree_lit_test_suite(
6161
"tile.mlir",
6262
"tile_and_fuse_producer_consumer_anchoring_last_op.mlir",
6363
"tile_and_fuse_producer_consumer_anchoring_root_op.mlir",
64+
"tile_to_vector_size.mlir",
6465
"unfused_fma.mlir",
6566
"vector_contract_to_arm_asm.mlir",
6667
"vector_contract_to_arm_intrinsics.mlir",

compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ iree_lit_test_suite(
5656
"tile.mlir"
5757
"tile_and_fuse_producer_consumer_anchoring_last_op.mlir"
5858
"tile_and_fuse_producer_consumer_anchoring_root_op.mlir"
59+
"tile_to_vector_size.mlir"
5960
"unfused_fma.mlir"
6061
"vector_contract_to_arm_asm.mlir"
6162
"vector_contract_to_arm_intrinsics.mlir"
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-to-vector-size))" --split-input-file %s | FileCheck %s
2+
3+
#config = #iree_cpu.lowering_config<vector_common_parallel = [10, 20, 0], vector_reduction = [0, 0, 30]>
4+
func.func @matmul_all_dims_untiled(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
5+
%0 = linalg.matmul {lowering_config = #config}
6+
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
7+
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
8+
return %0 : tensor<?x?xf32>
9+
}
10+
// CHECK-LABEL: func.func @matmul_all_dims_untiled(
11+
// CHECK: scf.for
12+
// CHECK: scf.for
13+
// CHECK: scf.for
14+
// CHECK: linalg.matmul
15+
16+
// -----
17+
18+
#config = #iree_cpu.lowering_config<vector_common_parallel = [10, 20, 0, 0], vector_reduction = [0, 0, 30, 30]>
19+
func.func @invalid_matmul_vector_config(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
20+
%0 = linalg.matmul {lowering_config = #config}
21+
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
22+
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
23+
return %0 : tensor<?x?xf32>
24+
}
25+
// CHECK-LABEL: func.func @invalid_matmul_vector_config(
26+
// CHECK-NOT: scf.for
27+
// CHECK: linalg.matmul
28+
29+
// -----
30+
31+
#config = #iree_cpu.lowering_config<vector_common_parallel = [10, 30, 0], vector_reduction = [0, 0, 20]>
32+
func.func @static_matmul_with_vector_size(%arg0 : tensor<10x20xf32>, %arg1 : tensor<20x30xf32>, %arg2 : tensor<10x30xf32>) -> tensor<10x30xf32> {
33+
%0 = linalg.matmul {lowering_config = #config}
34+
ins(%arg0, %arg1 : tensor<10x20xf32>, tensor<20x30xf32>)
35+
outs(%arg2 : tensor<10x30xf32>) -> tensor<10x30xf32>
36+
return %0 : tensor<10x30xf32>
37+
}
38+
// CHECK-LABEL: func.func @static_matmul_with_vector_size(
39+
// CHECK-NOT: scf.for
40+
// CHECK: linalg.matmul
41+
42+
// -----
43+
44+
#config = #iree_cpu.lowering_config<vector_common_parallel = [10, 30, 0], vector_reduction = [0, 0, 20]>
45+
func.func @static_matmul_with_untiled_K_dim(%arg0 : tensor<10x40xf32>, %arg1 : tensor<40x30xf32>, %arg2 : tensor<10x30xf32>) -> tensor<10x30xf32> {
46+
%0 = linalg.matmul {lowering_config = #config}
47+
ins(%arg0, %arg1 : tensor<10x40xf32>, tensor<40x30xf32>)
48+
outs(%arg2 : tensor<10x30xf32>) -> tensor<10x30xf32>
49+
return %0 : tensor<10x30xf32>
50+
}
51+
// CHECK-LABEL: func.func @static_matmul_with_untiled_K_dim(
52+
// CHECK: %[[C20:.+]] = arith.constant 20 : index
53+
// CHECK: scf.for
54+
// CHECK-SAME: step %[[C20]]
55+
// CHECK-NOT: scf.for
56+
// CHECK: linalg.matmul
57+
58+
// -----
59+
60+
#map = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
61+
#map1 = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
62+
#map2 = affine_map<(d0)[s0] -> (-d0 + s0, 60)>
63+
#config = #iree_cpu.lowering_config<vector_common_parallel = [10, 20, 0], vector_reduction = [0, 0, 30]>
64+
func.func @matmul_tiled_MxNxK_to_10x20x60(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
65+
%c0 = arith.constant 0 : index
66+
%c1 = arith.constant 1 : index
67+
%M = tensor.dim %arg0, %c0 : tensor<?x?xf32>
68+
%N = tensor.dim %arg1, %c1 : tensor<?x?xf32>
69+
%K = tensor.dim %arg0, %c1 : tensor<?x?xf32>
70+
%mSize = affine.min #map(%c0)[%M]
71+
%nSize = affine.min #map1(%c0)[%N]
72+
%kSize = affine.min #map2(%c0)[%K]
73+
%lhs = tensor.extract_slice %arg0 [0, 0][%mSize, %kSize][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
74+
%rhs = tensor.extract_slice %arg1 [0, 0][%kSize, %nSize][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
75+
%acc = tensor.extract_slice %arg2 [0, 0][%mSize, %nSize][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
76+
%0 = linalg.matmul {lowering_config = #config}
77+
ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
78+
outs(%acc : tensor<?x?xf32>) -> tensor<?x?xf32>
79+
return %0 : tensor<?x?xf32>
80+
}
81+
// CHECK-LABEL: func.func @matmul_tiled_MxNxK_to_10x20x60(
82+
// CHECK: %[[C30:.+]] = arith.constant 30 : index
83+
// CHECK: scf.for
84+
// CHECK-SAME: step %[[C30]]
85+
// CHECK-NOT: scf.for
86+
// CHECK: linalg.matmul

0 commit comments

Comments
 (0)