Commit b879ee1

[GPU][Codegen] Expand iteration space based on new expand_dims attribute (#22342)
This patch introduces iteration space expansion for reductions in the VectorDistribute path. Specifically, we:

1. Add a new attribute, `expand_dims`, for reductions.
2. Introduce a new pass, `GPUExpandDimensions`, which uses `expand_dims` to expand the iteration space of the relevant dimensions.
3. Refactor common functionality shared between `GPUExpandDimensions` and `BlockDynamicDimensions` into reusable utilities.
4. Refactor encoding helpers from `EncodingAttrs.cpp` into reusable utilities.

This change also enables [chain FMA](#21855) in matvec codegen as we iterate along the K reduction dimension.

---

**Performance Summary**

**IREE benchmark module**

* Expansion only: ~4% improvement
* Expansion + chain FMA: ~11% improvement

**rocprof**

* Expansion only: ~13% worse
* Expansion + chain FMA: ~9% better

**Register usage**

* 10% reduction (60 → 54 registers for matvec dispatches)

**Instruction latency (post-reduction loop epilogue)**

* 3.5% improvement (340 → 328 total mean latency)

---

**Notes**

* As a follow-up, we can explore applying iteration space expansion to the reduction in attention.
* Right now, we only expand one dimension into two, although the implementation supports expansion to N dimensions.
* Note that this PR changes the reduction order, so expect some minor changes in numerics.
* This does not improve performance by itself and can cause regressions without chain FMA (#21855).

Traces for matvec dispatches are attached for all variations (original, expansion only, and expansion + chain FMA):

* [115_expansion_and_chain.tar.gz](https://github.com/user-attachments/files/23268046/115_expansion_and_chain.tar.gz)
* [115_nothing.tar.gz](https://github.com/user-attachments/files/23268047/115_nothing.tar.gz)
* [115_only_expansion.tar.gz](https://github.com/user-attachments/files/23268048/115_only_expansion.tar.gz)

Fixes: #22153

ci-extra: test_torch

Signed-off-by: Eric Feng <[email protected]>
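For illustration, a minimal sketch of what the pass materializes around an annotated operand, mirroring the example in the `GPUExpandDimensions.cpp` comments below (the shapes and value names are illustrative, not taken from a real dispatch):

```mlir
// Reassociation [[0], [1, 2]] splits the 128-sized reduction dimension
// of %operand into 16x8.
%expanded = tensor.expand_shape %operand [[0], [1, 2]]
    : tensor<4x128xf32> into tensor<4x16x8xf32>
// The barrier keeps the inverse reshape pair from folding away before
// reshape propagation has run.
%barrier = util.optimization_barrier %expanded
%collapsed = tensor.collapse_shape %barrier [[0], [1, 2]]
    : tensor<4x16x8xf32> into tensor<4x128xf32>
// The annotated op now consumes %collapsed; bubbling the expand_shape
// upward then rewrites the op itself over the expanded iteration space.
```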
1 parent 8dadc59 commit b879ee1

27 files changed (+902 −203 lines)

compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp

Lines changed: 1 addition & 16 deletions
```diff
@@ -36,17 +36,6 @@ using TensorDivisibilityInfo =
 
 namespace {
 
-struct RemoveOptimizationBarrier final
-    : public OpRewritePattern<IREE::Util::OptimizationBarrierOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOp(barrierOp, barrierOp.getOperands());
-    return success();
-  }
-};
-
 /// This pass is used to materialize information about dynamic dimensions of
 /// `tensor` operands of an operation in the IR. If a dynamic dimension is
 /// known to be a multiple of a compile-time constant value, this pass
@@ -110,10 +99,6 @@ getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis,
 /// inverses of each other. The `util.optimization.barrier` avoid these from
 /// getting folded away during reshape propagation. Return the result of the
 /// `tensor.collapse_shape` generated.
-struct ReshapeOps {
-  tensor::ExpandShapeOp expandShapeOp;
-  tensor::CollapseShapeOp collapseShapeOp;
-};
 static std::optional<ReshapeOps>
 blockDynamicDimensionsOfValue(RewriterBase &rewriter,
                               const TensorDivisibilityInfo &divisibilityInfo,
@@ -413,7 +398,7 @@ void BlockDynamicDimensionsPass::runOnOperation() {
   // Delete the optimization barrier and run some further cleanup.
   {
     RewritePatternSet removeBarrierOpsPatterns(context);
-    removeBarrierOpsPatterns.insert<RemoveOptimizationBarrier>(context);
+    populateRemoveOptimizationBarrierPatterns(removeBarrierOpsPatterns);
     tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
                                                        context);
     tensor::CollapseShapeOp::getCanonicalizationPatterns(
```
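The `RemoveOptimizationBarrier` pattern deleted above replaces each `util.optimization_barrier` with its operands; the same rewrite is presumably now registered through the shared `populateRemoveOptimizationBarrierPatterns` utility from the refactor mentioned in the commit message. In IR terms, the cleanup it unlocks looks like this (illustrative values and shapes):

```mlir
// Before cleanup: the barrier pins the inverse reshape pair in place.
%e = tensor.expand_shape %v [[0], [1, 2]]
    : tensor<4x128xf32> into tensor<4x16x8xf32>
%b = util.optimization_barrier %e
%c = tensor.collapse_shape %b [[0], [1, 2]]
    : tensor<4x16x8xf32> into tensor<4x128xf32>
// After: %b is replaced by %e, and any expand/collapse pair that survived
// reshape propagation folds away as inverses, leaving direct uses of %v.
```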

compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel

Lines changed: 2 additions & 0 deletions
```diff
@@ -74,6 +74,7 @@ iree_compiler_cc_library(
         "GPUDistributeScfFor.cpp",
         "GPUDistributeSharedMemoryCopy.cpp",
         "GPUDistributionPatterns.cpp",
+        "GPUExpandDimensions.cpp",
         "GPUFuseAndHoistParallelLoops.cpp",
         "GPUGeneralizeNamedOps.cpp",
         "GPUGreedilyDistributeToThreads.cpp",
@@ -125,6 +126,7 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
         "//compiler/src/iree/compiler/Dialect/TensorExt/IR",
+        "//compiler/src/iree/compiler/Dialect/Util/IR",
        "//compiler/src/iree/compiler/Utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AMDGPUDialect",
```
compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
```diff
@@ -67,6 +67,7 @@ iree_cc_library(
     "GPUDistributeScfFor.cpp"
     "GPUDistributeSharedMemoryCopy.cpp"
     "GPUDistributionPatterns.cpp"
+    "GPUExpandDimensions.cpp"
     "GPUFuseAndHoistParallelLoops.cpp"
     "GPUGeneralizeNamedOps.cpp"
     "GPUGreedilyDistributeToThreads.cpp"
@@ -159,6 +160,7 @@ iree_cc_library(
     iree::compiler::Dialect::LinalgExt::Transforms
     iree::compiler::Dialect::LinalgExt::Utils
     iree::compiler::Dialect::TensorExt::IR
+    iree::compiler::Dialect::Util::IR
     iree::compiler::Utils
   PUBLIC
 )
```
compiler/src/iree/compiler/Codegen/Common/GPU/GPUExpandDimensions.cpp

Lines changed: 290 additions & 0 deletions

```cpp
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-gpu-expand-dimensions"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_GPUEXPANDDIMENSIONSPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

namespace {

struct GPUExpandDimensionsPass final
    : impl::GPUExpandDimensionsPassBase<GPUExpandDimensionsPass> {
  using Base::Base;
  void runOnOperation() override;
};
} // namespace

// Compute the expanded shape for a reassociation group. Requires the original
// dimension to be static and evenly divisible by the product of static factors
// in the target shape.
static FailureOr<SmallVector<OpFoldResult>> computeExpandedGroupShape(
    RewriterBase &rewriter, Location loc, OpFoldResult origDimSize,
    ArrayRef<int64_t> groupTargetShape, unsigned iteratorDim) {
  if (groupTargetShape.size() == 1) {
    return SmallVector<OpFoldResult>{origDimSize};
  }

  std::optional<int64_t> staticOrigDim = getConstantIntValue(origDimSize);
  if (!staticOrigDim) {
    return rewriter.notifyMatchFailure(
        loc, "dimension " + Twine(iteratorDim) +
                 " is dynamic, but expand_dims requires static dimensions");
  }

  int64_t staticFactor = llvm::product_of(
      llvm::make_filter_range(groupTargetShape, ShapedType::isStatic));

  if (staticFactor < 1) {
    return rewriter.notifyMatchFailure(
        loc, "invalid expansion factor " + Twine(staticFactor) +
                 " for iterator dimension " + Twine(iteratorDim));
  }

  if (staticOrigDim.value() % staticFactor != 0) {
    return rewriter.notifyMatchFailure(
        loc, "dimension " + Twine(iteratorDim) +
                 " (size=" + Twine(staticOrigDim.value()) +
                 ") not divisible by expansion factor " + Twine(staticFactor));
  }

  return llvm::map_to_vector(
      groupTargetShape, [&](int64_t size) -> OpFoldResult {
        if (ShapedType::isStatic(size)) {
          return rewriter.getIndexAttr(size);
        }
        AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
        return affine::makeComposedFoldedAffineApply(
            rewriter, loc, s0.floorDiv(staticFactor), {origDimSize});
      });
}

// For an operation annotated with the `expand_dims` attribute, replace
// relevant operands with tensor.expand_shape/tensor.collapse_shape pair to
// materialize dimension expansion according to the reassociation and
// output_shape defined in the attribute.
//
// Example:
//
// ```mlir
// %0 = <some_op>(..., %0, ...) {
//   lowering_config = #iree_gpu.lowering_config<{
//     expand_dims = #iree_gpu.expand_dims<
//         [[0], [1, 2]], output_shape = [?, ?, 8]>
//   }>
// } : ... -> tensor<4x128xf32>
// ```
//
// becomes:
//
// ```mlir
// %expanded = tensor.expand_shape %0 [[0], [1, 2]]
//     : tensor<4x128xf32> into tensor<4x16x8xf32>
// %barrier = util.optimization_barrier %expanded
// %collapsed = tensor.collapse_shape %barrier [[0], [1, 2]]
//     : tensor<4x16x8xf32> into tensor<4x128xf32>
// %1 = <some_op>(..., %collapsed, ...) : ... -> tensor<4x128xf32>
// ```
static std::optional<ReshapeOps>
createDimensionExpansionOps(RewriterBase &rewriter,
                            IREE::GPU::DimensionExpansionAttr config, Value v,
                            AffineMap indexingMap, linalg::LinalgOp op) {
  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
  if (!tensorType) {
    return std::nullopt;
  }

  Location loc = v.getLoc();
  MLIRContext *ctx = op.getContext();
  int64_t tensorRank = tensorType.getRank();
  ArrayRef<int64_t> outputShape = config.getOutputShape().asArrayRef();
  SmallVector<OpFoldResult> origShape = tensor::getMixedSizes(rewriter, loc, v);

  // Map each tensor dimension to its expanded shape components.
  SmallVector<SmallVector<OpFoldResult>> expandedShapes(tensorRank);
  for (auto [iterDim, reassocIndices] :
       llvm::enumerate(config.getReassociationIndices())) {
    std::optional<unsigned> tensorDim =
        indexingMap.getResultPosition(getAffineDimExpr(iterDim, ctx));
    if (!tensorDim.has_value()) {
      continue;
    }

    auto groupOutputShape = llvm::map_to_vector(
        reassocIndices, [&](int64_t i) { return outputShape[i]; });

    FailureOr<SmallVector<OpFoldResult>> groupShape = computeExpandedGroupShape(
        rewriter, loc, origShape[tensorDim.value()], groupOutputShape, iterDim);
    if (failed(groupShape)) {
      return std::nullopt;
    }

    expandedShapes[tensorDim.value()] = std::move(groupShape.value());
  }

  // Build reassociation indices and expanded shape in tensor dimension order.
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<OpFoldResult> expandedShape;
  for (auto [tensorDim, expanded] : llvm::enumerate(expandedShapes)) {
    ReassociationIndices &indices = reassociation.emplace_back();
    auto addDim = [&](OpFoldResult dim) {
      indices.push_back(expandedShape.size());
      expandedShape.push_back(dim);
    };
    if (expanded.empty()) {
      addDim(origShape[tensorDim]);
    } else {
      llvm::for_each(expanded, addDim);
    }
  }

  // If no expansion is needed, return early.
  if (llvm::equal(origShape, expandedShape)) {
    return std::nullopt;
  }

  auto staticShape = llvm::map_to_vector(expandedShape, [](OpFoldResult ofr) {
    return getConstantIntValue(ofr).value();
  });

  auto expandedType = RankedTensorType::get(
      staticShape, tensorType.getElementType(), tensorType.getEncoding());

  auto expandOp = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, v,
                                                reassociation, expandedShape);
  Value barrier = IREE::Util::OptimizationBarrierOp::create(
                      rewriter, loc, expandOp.getResult())
                      .getResult(0);
  auto collapseOp = tensor::CollapseShapeOp::create(rewriter, loc, tensorType,
                                                    barrier, reassociation);

  return ReshapeOps{expandOp, collapseOp};
}

static LogicalResult expandIterationSpace(RewriterBase &rewriter,
                                          linalg::LinalgOp op) {
  auto loweringConfig = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
  if (!loweringConfig) {
    return success();
  }
  auto config = IREE::GPU::getDimensionExpansion(loweringConfig);
  if (!config) {
    return success();
  }

  LDBG() << "Expanding dimensions for op: " << *op;

  for (OpOperand &operand : op->getOpOperands()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&operand);
    std::optional<ReshapeOps> reshapes = createDimensionExpansionOps(
        rewriter, config, operand.get(), indexingMap, op);
    if (reshapes.has_value()) {
      rewriter.modifyOpInPlace(
          op, [&]() { operand.set(reshapes.value().collapseShapeOp); });
    }
  }

  return success();
}

void GPUExpandDimensionsPass::runOnOperation() {
  Operation *operation = getOperation();
  MLIRContext *context = &getContext();
  IRRewriter rewriter(context);

  SmallVector<linalg::LinalgOp> worklist;
  operation->walk([&](linalg::LinalgOp op) {
    if (auto cfg = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op)) {
      if (IREE::GPU::getDimensionExpansion(cfg)) {
        worklist.push_back(op);
      }
    }
  });

  for (linalg::LinalgOp op : worklist) {
    rewriter.setInsertionPoint(op);
    if (failed(expandIterationSpace(rewriter, op))) {
      return signalPassFailure();
    }
  }

  LDBG() << "After expanding dimensions: " << *operation;

  ConfigTrackingListener listener;
  GreedyRewriteConfig config;
  config.setListener(&listener);

  {
    RewritePatternSet bubbleExpandShapePatterns(context);
    linalg::ControlFusionFn controlFn = [](OpOperand *opOperand) {
      return !isa_and_nonnull<linalg::FillOp, tensor::EmptyOp>(
          opOperand->get().getDefiningOp());
    };
    linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
                                                      controlFn);
    IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
        bubbleExpandShapePatterns, controlFn);
    tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
    tensor::populateBubbleUpExpandShapePatterns(bubbleExpandShapePatterns);
    linalg::FillOp::getCanonicalizationPatterns(
        bubbleExpandShapePatterns, bubbleExpandShapePatterns.getContext());
    memref::populateResolveRankedShapedTypeResultDimsPatterns(
        bubbleExpandShapePatterns);
    if (failed(applyPatternsGreedily(
            operation, std::move(bubbleExpandShapePatterns), config))) {
      operation->emitOpError(
          "failed in application of bubble up expand shape patterns");
      return signalPassFailure();
    }
  }

  LDBG() << "After reshape propagation: " << *operation;

  {
    RewritePatternSet removeBarrierOpsPatterns(context);
    populateRemoveOptimizationBarrierPatterns(removeBarrierOpsPatterns);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
                                                       context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(
        removeBarrierOpsPatterns, context);
    tensor::populateFoldTensorEmptyPatterns(removeBarrierOpsPatterns);
    linalg::FillOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
                                                context);
    memref::populateResolveRankedShapedTypeResultDimsPatterns(
        removeBarrierOpsPatterns);
    if (failed(applyPatternsGreedily(operation,
                                     std::move(removeBarrierOpsPatterns)))) {
      operation->emitOpError("failed in cleanup patterns");
      return signalPassFailure();
    }
  }

  return;
}

} // namespace mlir::iree_compiler
```
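To make the intended end state concrete, here is a hypothetical before/after for a matvec-style reduction once expansion and the reshape-propagation patterns above have run. The op, maps, values (`%A`, `%x`, `#mapA`, etc.), and shapes are all illustrative, not taken from the patch:

```mlir
// Before: a single reduction dimension K = 128.
#mapA = affine_map<(m, k) -> (m, k)>
#mapX = affine_map<(m, k) -> (k)>
#mapY = affine_map<(m, k) -> (m)>
%y = linalg.generic
    {indexing_maps = [#mapA, #mapX, #mapY],
     iterator_types = ["parallel", "reduction"]}
    ins(%A, %x : tensor<4x128xf32>, tensor<128xf32>)
    outs(%init : tensor<4xf32>) {
^bb0(%a: f32, %b: f32, %acc: f32):
  %p = arith.mulf %a, %b : f32
  %s = arith.addf %acc, %p : f32
  linalg.yield %s : f32
} -> tensor<4xf32>

// After expansion with reassociation [[0], [1, 2]]: K is iterated as
// 16x8, and the serial inner 8-wide reduction is what feeds the
// chained-FMA lowering referenced in the commit message.
#mapA2 = affine_map<(m, k0, k1) -> (m, k0, k1)>
#mapX2 = affine_map<(m, k0, k1) -> (k0, k1)>
#mapY2 = affine_map<(m, k0, k1) -> (m)>
%y2 = linalg.generic
    {indexing_maps = [#mapA2, #mapX2, #mapY2],
     iterator_types = ["parallel", "reduction", "reduction"]}
    ins(%A_exp, %x_exp : tensor<4x16x8xf32>, tensor<16x8xf32>)
    outs(%init : tensor<4xf32>) {
^bb0(%a: f32, %b: f32, %acc: f32):
  %p = arith.mulf %a, %b : f32
  %s = arith.addf %acc, %p : f32
  linalg.yield %s : f32
} -> tensor<4xf32>
```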

compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td

Lines changed: 8 additions & 0 deletions
```diff
@@ -383,6 +383,14 @@ def GPUApplyPaddingLevelPass :
   ];
 }
 
+def GPUExpandDimensionsPass :
+    InterfacePass<"iree-codegen-gpu-expand-dimensions", "mlir::FunctionOpInterface"> {
+  let summary = "Pass to expand tensor op dims based on `expand_dims` lowering_config";
+  let dependentDialects = [
+    "::mlir::iree_compiler::IREE::Util::UtilDialect"
+  ];
+}
+
 def GPUTensorTileToSerialLoopsPass :
     InterfacePass<"iree-codegen-gpu-tensor-tile-to-serial-loops", "mlir::FunctionOpInterface"> {
   let summary = "Pass to tile reduction dimensions for certain GPU ops";
```
compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
```diff
@@ -36,6 +36,7 @@ iree_lit_test_suite(
         "gpu_distribute_forall.mlir",
         "gpu_distribute_scf_for.mlir",
         "gpu_distribute_shared_memory.mlir",
+        "gpu_expand_dimensions.mlir",
         "gpu_fuse_and_hoist_forall.mlir",
         "gpu_generalize_named_ops.mlir",
         "gpu_greedily_distribute_to_threads.mlir",
```

compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@ iree_lit_test_suite(
     "gpu_distribute_forall.mlir"
     "gpu_distribute_scf_for.mlir"
     "gpu_distribute_shared_memory.mlir"
+    "gpu_expand_dimensions.mlir"
     "gpu_fuse_and_hoist_forall.mlir"
     "gpu_generalize_named_ops.mlir"
     "gpu_greedily_distribute_to_threads.mlir"
```
