Skip to content

Commit 7c8f3a2

Browse files
committed
[Codegen] add FoldExtractSliceOfFillThroughBlockArg pattern to TileAndDistributeToWorkgroups
Signed-off-by: Bangtian Liu <[email protected]>
1 parent 7f7e190 commit 7c8f3a2

File tree

4 files changed

+153
-0
lines changed

4 files changed

+153
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
308308
// TODO(Max191): Replace populateSwapExtractWithExpandPattern with upstream
309309
// MLIR version once it is available (llvm-project/pull/126898).
310310
populateSwapExtractWithExpandPattern(cleanupPatterns);
311+
populateFoldExtractSliceOfFillThroughBlockArgPattern(cleanupPatterns);
311312
// When fusing pads we do not want to generate zeroSliceGuards when doing
312313
// workgroup tiling. In `GPUApplyTilingLevelPass` we do have an option called
313314
// `allowZeroSlices` that can control this but we do not want these
@@ -412,6 +413,7 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
412413
{
413414
RewritePatternSet patterns(context);
414415
populateSwapExtractWithCollapsePattern(patterns);
416+
populateFoldExtractSliceOfFillThroughBlockArgPattern(patterns);
415417
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
416418
tensor::populateFoldTensorEmptyPatterns(patterns);
417419
context->getOrLoadDialect<tensor::TensorDialect>()

compiler/src/iree/compiler/Codegen/Common/Transforms.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "llvm/ADT/ScopeExit.h"
1111
#include "mlir/Analysis/SliceAnalysis.h"
1212
#include "mlir/Dialect/Affine/IR/AffineOps.h"
13+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1314
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1415

1516
#define DEBUG_TYPE "iree-codegen-common-transforms"
@@ -354,6 +355,89 @@ void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns) {
354355
patterns.add<SwapExpandShapeWithSlicePattern>(patterns.getContext());
355356
}
356357

358+
namespace {
359+
/// Pattern to fold extract_slice of a fill through a forall's block argument.
360+
/// When extracting a slice from a block argument where the init value is a
361+
/// linalg.fill, we update the forall's shared_outs to use the fill's
362+
/// destination (the empty tensor), and then create a fill on the extracted
363+
/// slice inside the loop body.
364+
///
365+
/// Example:
366+
/// %empty = tensor.empty() : tensor<4x1xf16>
367+
/// %fill = linalg.fill ins(%cst) outs(%empty) -> tensor<4x1xf16>
368+
/// scf.forall ... shared_outs(%arg = %fill) {
369+
/// %slice = tensor.extract_slice %arg[%i, 0] [1, 1] -> tensor<1x1xf16>
370+
/// ...
371+
/// }
372+
/// ->
373+
/// %empty = tensor.empty() : tensor<4x1xf16>
374+
/// scf.forall ... shared_outs(%arg = %empty) { // Updated to use %empty
375+
/// %extracted = tensor.extract_slice %arg[%i, 0] [1, 1] -> tensor<1x1xf16>
376+
/// %slice = linalg.fill ins(%cst) outs(%extracted) -> tensor<1x1xf16>
377+
/// ...
378+
/// }
379+
struct FoldExtractSliceOfFillThroughBlockArg final
380+
: OpRewritePattern<tensor::ExtractSliceOp> {
381+
using OpRewritePattern::OpRewritePattern;
382+
383+
LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
384+
PatternRewriter &rewriter) const override {
385+
auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
386+
if (!blockArg) {
387+
return rewriter.notifyMatchFailure(extractOp,
388+
"source is not a block argument");
389+
}
390+
391+
auto forallOp = dyn_cast<scf::ForallOp>(blockArg.getOwner()->getParentOp());
392+
if (!forallOp) {
393+
return rewriter.notifyMatchFailure(
394+
extractOp, "block argument is not from an scf.forall");
395+
}
396+
397+
unsigned argNum = blockArg.getArgNumber();
398+
unsigned numIVs = forallOp.getInductionVars().size();
399+
if (argNum < numIVs) {
400+
return rewriter.notifyMatchFailure(
401+
extractOp, "block argument is an induction variable, not shared_out");
402+
}
403+
404+
unsigned outputIdx = argNum - numIVs;
405+
if (outputIdx >= forallOp.getOutputs().size()) {
406+
return rewriter.notifyMatchFailure(extractOp,
407+
"invalid output index for block arg");
408+
}
409+
410+
Value initValue = forallOp.getOutputs()[outputIdx];
411+
412+
auto fillOp = initValue.getDefiningOp<linalg::FillOp>();
413+
if (!fillOp) {
414+
return rewriter.notifyMatchFailure(
415+
extractOp, "init value is not a linalg.fill operation");
416+
}
417+
418+
Value fillValue = fillOp.getInputs()[0];
419+
Value fillDest = fillOp.getOutputs()[0];
420+
rewriter.modifyOpInPlace(forallOp, [&]() {
421+
forallOp.getOutputsMutable()[outputIdx].set(fillDest);
422+
});
423+
424+
rewriter.setInsertionPointAfter(extractOp);
425+
Location loc = extractOp.getLoc();
426+
auto newFillOp =
427+
linalg::FillOp::create(rewriter, loc, fillValue, extractOp.getResult());
428+
rewriter.replaceAllUsesExcept(extractOp.getResult(), newFillOp.getResult(0),
429+
newFillOp);
430+
return success();
431+
}
432+
};
433+
434+
} // namespace
435+
436+
void populateFoldExtractSliceOfFillThroughBlockArgPattern(
437+
RewritePatternSet &patterns) {
438+
patterns.add<FoldExtractSliceOfFillThroughBlockArg>(patterns.getContext());
439+
}
440+
357441
/// Note the following pattern is adapted from the upstream pattern
358442
/// `BubbleUpCollapseShapeThroughExtractSlice` by allowing some special cases.
359443
///

compiler/src/iree/compiler/Codegen/Common/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,12 @@ void populateReplaceSlowMinMaxOpsPatterns(RewritePatternSet &patterns);
193193
/// `tensor.expand_shape(tensor.extract_slice)`.
194194
void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns);
195195

196+
/// Populate pattern to fold `tensor.extract_slice` of a `linalg.fill` through
197+
/// an `scf.forall` shared_out block argument: the forall's init is redirected
198+
/// to the fill's destination and a `linalg.fill` is created on the extracted
/// slice inside the loop body.
199+
void populateFoldExtractSliceOfFillThroughBlockArgPattern(
200+
RewritePatternSet &patterns);
201+
196202
/// Populate pattern to convert `tensor.extract_slice(tensor.collapse_shape)` to
197203
/// `tensor.collapse_shape(tensor.extract_slice)`.
198204
void populateSwapExtractWithCollapsePattern(RewritePatternSet &patterns);

compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,3 +1364,64 @@ attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPU
13641364
// CHECK: scf.forall.in_parallel
13651365
// CHECK: tensor.parallel_insert_slice %[[RES]] into %[[OUT0]][%[[OFFSET0]], 0, %[[OFFSET1]]]
13661366
// CHECK: {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
1367+
1368+
// -----
1369+
1370+
// Test for FoldExtractSliceOfFillThroughBlockArgPattern:
1371+
// When a forall's shared_out init is a linalg.fill, and we extract a slice
1372+
// from the block argument, the pattern should:
1373+
// 1. Change the forall's init to use the fill's destination (empty tensor)
1374+
// 2. Create a new fill on the extracted slice inside the loop
1375+
1376+
#config_fill_fold = #iree_codegen.lowering_config<tile_sizes = [[1, 8]]>
1377+
1378+
// A shared_out initialized by a linalg.fill, read only through an
// extract_slice and written back via parallel_insert_slice: the pattern may
// fold the fill onto the extracted tile.
func.func @fold_fill_through_block_arg(%arg0 : tensor<4x16x128xf16>) -> (tensor<4x16xf16>, tensor<4x16xi32>) {
  %cst = arith.constant 0xFC00 : f16
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %empty_f16 = tensor.empty() : tensor<4x16xf16>
  %empty_i32 = tensor.empty() : tensor<4x16xi32>
  %fill_f16 = linalg.fill {lowering_config = #config_fill_fold}
      ins(%cst : f16) outs(%empty_f16 : tensor<4x16xf16>) -> tensor<4x16xf16>
  %fill_i32 = linalg.fill {lowering_config = #config_fill_fold}
      ins(%c0_i32 : i32) outs(%empty_i32 : tensor<4x16xi32>) -> tensor<4x16xi32>
  %result:2 = scf.forall (%iv0, %iv1) = (0, 0) to (4, 16) step (1, 8)
      shared_outs(%out_f16 = %fill_f16, %out_i32 = %fill_i32) -> (tensor<4x16xf16>, tensor<4x16xi32>) {
    %in_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [1, 8, 128] [1, 1, 1]
        : tensor<4x16x128xf16> to tensor<1x8x128xf16>
    %slice_f16 = tensor.extract_slice %out_f16[%iv0, %iv1] [1, 8] [1, 1]
        : tensor<4x16xf16> to tensor<1x8xf16>
    %slice_i32 = tensor.extract_slice %out_i32[%iv0, %iv1] [1, 8] [1, 1]
        : tensor<4x16xi32> to tensor<1x8xi32>
    %compare:2 = iree_linalg_ext.arg_compare {lowering_config = #config_fill_fold}
        dimension(2) ins(%in_slice : tensor<1x8x128xf16>)
        outs(%slice_f16, %slice_i32 : tensor<1x8xf16>, tensor<1x8xi32>)
        index_base(%c0 : index) {
    ^bb0(%lhs: f16, %rhs: f16):
      %cmp = arith.cmpf ogt, %lhs, %rhs : f16
      iree_linalg_ext.yield %cmp : i1
    } -> tensor<1x8xf16>, tensor<1x8xi32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %compare#0 into %out_f16[%iv0, %iv1] [1, 8] [1, 1]
          : tensor<1x8xf16> into tensor<4x16xf16>
      tensor.parallel_insert_slice %compare#1 into %out_i32[%iv0, %iv1] [1, 8] [1, 1]
          : tensor<1x8xi32> into tensor<4x16xi32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  return %result#0, %result#1 : tensor<4x16xf16>, tensor<4x16xi32>
}
1413+
1414+
// CHECK-LABEL: func.func @fold_fill_through_block_arg
1415+
// CHECK-DAG: %[[CST_F16:.+]] = arith.constant 0xFC00 : f16
1416+
// CHECK-DAG: %[[CST_I32:.+]] = arith.constant 0 : i32
1417+
// CHECK-DAG: %[[EMPTY_F16:.+]] = tensor.empty() : tensor<4x16xf16>
1418+
// CHECK-DAG: %[[EMPTY_I32:.+]] = tensor.empty() : tensor<4x16xi32>
1419+
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) = (0, 0) to (4, 16) step (1, 8)
1420+
// CHECK-SAME: shared_outs(%[[OUT_F16:.+]] = %[[EMPTY_F16]], %[[OUT_I32:.+]] = %[[EMPTY_I32]])
1421+
// CHECK: %[[SLICE_F16:.+]] = tensor.extract_slice %[[OUT_F16]][%[[IV0]], %[[IV1]]] [1, 8] [1, 1]
1422+
// CHECK: %[[FILLED_F16:.+]] = linalg.fill ins(%[[CST_F16]] : f16) outs(%[[SLICE_F16]] : tensor<1x8xf16>)
1423+
// CHECK: %[[SLICE_I32:.+]] = tensor.extract_slice %[[OUT_I32]][%[[IV0]], %[[IV1]]] [1, 8] [1, 1]
1424+
// CHECK: %[[FILLED_I32:.+]] = linalg.fill ins(%[[CST_I32]] : i32) outs(%[[SLICE_I32]] : tensor<1x8xi32>)
1425+
// CHECK: scf.forall
1426+
// CHECK-SAME: shared_outs({{.*}} = %[[FILLED_F16]], {{.*}} = %[[FILLED_I32]])
1427+
// CHECK: iree_linalg_ext.arg_compare

0 commit comments

Comments
 (0)