Skip to content

Commit a5b2ab5

Browse files
committed
[Codegen] add FoldExtractSliceOfFillThroughBlockArg pattern to TileAndDistributeToWorkgroups
Signed-off-by: Bangtian Liu <[email protected]>
1 parent 2fbd13b commit a5b2ab5

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -308,6 +308,7 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
308308
// TODO(Max191): Replace populateSwapExtractWithExpandPattern with upstream
309309
// MLIR version once it is available (llvm-project/pull/126898).
310310
populateSwapExtractWithExpandPattern(cleanupPatterns);
311+
populateFoldExtractSliceOfFillThroughBlockArgPattern(cleanupPatterns);
311312
// When fusing pads we do not want to generate zeroSliceGuards when doing
312313
// workgroup tiling. In `GPUApplyTilingLevelPass` we do have an option called
313314
// `allowZeroSlices` that can control this but we do not want these
@@ -412,6 +413,7 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
412413
{
413414
RewritePatternSet patterns(context);
414415
populateSwapExtractWithCollapsePattern(patterns);
416+
populateFoldExtractSliceOfFillThroughBlockArgPattern(patterns);
415417
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
416418
tensor::populateFoldTensorEmptyPatterns(patterns);
417419
context->getOrLoadDialect<tensor::TensorDialect>()

compiler/src/iree/compiler/Codegen/Common/Transforms.cpp

Lines changed: 85 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
#include "llvm/ADT/ScopeExit.h"
1111
#include "mlir/Analysis/SliceAnalysis.h"
1212
#include "mlir/Dialect/Affine/IR/AffineOps.h"
13+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1314
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1415

1516
#define DEBUG_TYPE "iree-codegen-common-transforms"
@@ -354,6 +355,90 @@ void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns) {
354355
patterns.add<SwapExpandShapeWithSlicePattern>(patterns.getContext());
355356
}
356357

358+
namespace {
359+
/// Pattern to fold extract_slice of a fill through a forall's block argument.
360+
/// When extracting a slice from a block argument where the init value is a
361+
/// linalg.fill, we update the forall's shared_outs to use the fill's
362+
/// destination (the empty tensor), and then create a fill on the extracted
363+
/// slice inside the loop body.
364+
///
365+
/// Example:
366+
/// %empty = tensor.empty() : tensor<4x1xf16>
367+
/// %fill = linalg.fill ins(%cst) outs(%empty) -> tensor<4x1xf16>
368+
/// scf.forall ... shared_outs(%arg = %fill) {
369+
/// %slice = tensor.extract_slice %arg[%i, 0] [1, 1] -> tensor<1x1xf16>
370+
/// ...
371+
/// }
372+
/// ->
373+
/// %empty = tensor.empty() : tensor<4x1xf16>
374+
/// scf.forall ... shared_outs(%arg = %empty) { // Updated to use %empty
375+
/// %extracted = tensor.extract_slice %arg[%i, 0] [1, 1] -> tensor<1x1xf16>
376+
/// %slice = linalg.fill ins(%cst) outs(%extracted) -> tensor<1x1xf16>
377+
/// ...
378+
/// }
379+
struct FoldExtractSliceOfFillThroughBlockArg final
380+
: OpRewritePattern<tensor::ExtractSliceOp> {
381+
using OpRewritePattern::OpRewritePattern;
382+
383+
LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
384+
PatternRewriter &rewriter) const override {
385+
auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
386+
if (!blockArg) {
387+
return rewriter.notifyMatchFailure(extractOp,
388+
"source is not a block argument");
389+
}
390+
391+
auto forallOp = dyn_cast<scf::ForallOp>(blockArg.getOwner()->getParentOp());
392+
if (!forallOp) {
393+
return rewriter.notifyMatchFailure(
394+
extractOp, "block argument is not from an scf.forall");
395+
}
396+
397+
unsigned argNum = blockArg.getArgNumber();
398+
unsigned numIVs = forallOp.getInductionVars().size();
399+
if (argNum < numIVs) {
400+
return rewriter.notifyMatchFailure(
401+
extractOp, "block argument is an induction variable, not shared_out");
402+
}
403+
404+
unsigned outputIdx = argNum - numIVs;
405+
if (outputIdx >= forallOp.getOutputs().size()) {
406+
return rewriter.notifyMatchFailure(extractOp,
407+
"invalid output index for block arg");
408+
}
409+
410+
Value initValue = forallOp.getOutputs()[outputIdx];
411+
412+
auto fillOp = initValue.getDefiningOp<linalg::FillOp>();
413+
if (!fillOp) {
414+
return rewriter.notifyMatchFailure(
415+
extractOp, "init value is not a linalg.fill operation");
416+
}
417+
418+
Value fillValue = fillOp.getInputs()[0];
419+
Value fillDest = fillOp.getOutputs()[0];
420+
rewriter.modifyOpInPlace(forallOp, [&]() {
421+
forallOp.getOutputsMutable()[outputIdx].set(fillDest);
422+
});
423+
424+
rewriter.setInsertionPointAfter(extractOp);
425+
Location loc = extractOp.getLoc();
426+
auto newFillOp =
427+
linalg::FillOp::create(rewriter, loc, fillValue, extractOp.getResult());
428+
newFillOp->setDiscardableAttrs(fillOp->getDiscardableAttrDictionary());
429+
rewriter.replaceAllUsesExcept(extractOp.getResult(), newFillOp.getResult(0),
430+
newFillOp);
431+
return success();
432+
}
433+
};
434+
435+
} // namespace
436+
437+
void populateFoldExtractSliceOfFillThroughBlockArgPattern(
438+
RewritePatternSet &patterns) {
439+
patterns.add<FoldExtractSliceOfFillThroughBlockArg>(patterns.getContext());
440+
}
441+
357442
/// Note the following pattern is adapted from the upstream pattern
358443
/// `BubbleUpCollapseShapeThroughExtractSlice` by allowing some special cases.
359444
///

compiler/src/iree/compiler/Codegen/Common/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,12 @@ void populateReplaceSlowMinMaxOpsPatterns(RewritePatternSet &patterns);
193193
/// `tensor.expand_shape(tensor.extract_slice)`.
194194
void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns);
195195

196+
/// Populate pattern to fold `tensor.extract_slice` of a `linalg.fill` through
197+
/// a forall's block argument. Rewrites the forall's shared_outs to use the
198+
/// fill's destination and creates a linalg.fill on the extracted slice
/// inside the loop body.
199+
void populateFoldExtractSliceOfFillThroughBlockArgPattern(
200+
RewritePatternSet &patterns);
201+
196202
/// Populate pattern to convert `tensor.extract_slice(tensor.collapse_shape)` to
197203
/// `tensor.collapse_shape(tensor.extract_slice)`.
198204
void populateSwapExtractWithCollapsePattern(RewritePatternSet &patterns);

compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1364,3 +1364,57 @@ attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPU
13641364
// CHECK: scf.forall.in_parallel
13651365
// CHECK: tensor.parallel_insert_slice %[[RES]] into %[[OUT0]][%[[OFFSET0]], 0, %[[OFFSET1]]]
13661366
// CHECK: {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
1367+
1368+
// -----
1369+
1370+
#config_fill_fold = #iree_codegen.lowering_config<tile_sizes = [[1, 8]]>
1371+
func.func @fold_fill_through_block_arg(%arg0 : tensor<4x16x128xf16>) -> (tensor<4x16xf16>, tensor<4x16xi32>) {
1372+
%cst = arith.constant 0xFC00 : f16
1373+
%c0_i32 = arith.constant 0 : i32
1374+
%c0 = arith.constant 0 : index
1375+
%empty_f16 = tensor.empty() : tensor<4x16xf16>
1376+
%empty_i32 = tensor.empty() : tensor<4x16xi32>
1377+
%fill_f16 = linalg.fill {lowering_config = #config_fill_fold}
1378+
ins(%cst : f16) outs(%empty_f16 : tensor<4x16xf16>) -> tensor<4x16xf16>
1379+
%fill_i32 = linalg.fill {lowering_config = #config_fill_fold}
1380+
ins(%c0_i32 : i32) outs(%empty_i32 : tensor<4x16xi32>) -> tensor<4x16xi32>
1381+
%result:2 = scf.forall (%iv0, %iv1) = (0, 0) to (4, 16) step (1, 8)
1382+
shared_outs(%out_f16 = %fill_f16, %out_i32 = %fill_i32) -> (tensor<4x16xf16>, tensor<4x16xi32>) {
1383+
%in_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [1, 8, 128] [1, 1, 1]
1384+
: tensor<4x16x128xf16> to tensor<1x8x128xf16>
1385+
%slice_f16 = tensor.extract_slice %out_f16[%iv0, %iv1] [1, 8] [1, 1]
1386+
: tensor<4x16xf16> to tensor<1x8xf16>
1387+
%slice_i32 = tensor.extract_slice %out_i32[%iv0, %iv1] [1, 8] [1, 1]
1388+
: tensor<4x16xi32> to tensor<1x8xi32>
1389+
%compare:2 = iree_linalg_ext.arg_compare {lowering_config = #config_fill_fold}
1390+
dimension(2) ins(%in_slice : tensor<1x8x128xf16>)
1391+
outs(%slice_f16, %slice_i32 : tensor<1x8xf16>, tensor<1x8xi32>)
1392+
index_base(%c0 : index) {
1393+
^bb0(%lhs: f16, %rhs: f16):
1394+
%cmp = arith.cmpf ogt, %lhs, %rhs : f16
1395+
iree_linalg_ext.yield %cmp : i1
1396+
} -> tensor<1x8xf16>, tensor<1x8xi32>
1397+
scf.forall.in_parallel {
1398+
tensor.parallel_insert_slice %compare#0 into %out_f16[%iv0, %iv1] [1, 8] [1, 1]
1399+
: tensor<1x8xf16> into tensor<4x16xf16>
1400+
tensor.parallel_insert_slice %compare#1 into %out_i32[%iv0, %iv1] [1, 8] [1, 1]
1401+
: tensor<1x8xi32> into tensor<4x16xi32>
1402+
}
1403+
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
1404+
return %result#0, %result#1 : tensor<4x16xf16>, tensor<4x16xi32>
1405+
}
1406+
1407+
// CHECK-LABEL: func.func @fold_fill_through_block_arg
1408+
// CHECK-DAG: %[[CST_F16:.+]] = arith.constant 0xFC00 : f16
1409+
// CHECK-DAG: %[[CST_I32:.+]] = arith.constant 0 : i32
1410+
// CHECK-DAG: %[[EMPTY_F16:.+]] = tensor.empty() : tensor<4x16xf16>
1411+
// CHECK-DAG: %[[EMPTY_I32:.+]] = tensor.empty() : tensor<4x16xi32>
1412+
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) = (0, 0) to (4, 16) step (1, 8)
1413+
// CHECK-SAME: shared_outs(%[[OUT_F16:.+]] = %[[EMPTY_F16]], %[[OUT_I32:.+]] = %[[EMPTY_I32]])
1414+
// CHECK: %[[SLICE_F16:.+]] = tensor.extract_slice %[[OUT_F16]][%[[IV0]], %[[IV1]]] [1, 8] [1, 1]
1415+
// CHECK: %[[FILLED_F16:.+]] = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = {{\[\[}}1, 8]]>} ins(%[[CST_F16]] : f16) outs(%[[SLICE_F16]] : tensor<1x8xf16>)
1416+
// CHECK: %[[SLICE_I32:.+]] = tensor.extract_slice %[[OUT_I32]][%[[IV0]], %[[IV1]]] [1, 8] [1, 1]
1417+
// CHECK: %[[FILLED_I32:.+]] = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = {{\[\[}}1, 8]]>} ins(%[[CST_I32]] : i32) outs(%[[SLICE_I32]] : tensor<1x8xi32>)
1418+
// CHECK: scf.forall
1419+
// CHECK-SAME: shared_outs({{.*}} = %[[FILLED_F16]], {{.*}} = %[[FILLED_I32]])
1420+
// CHECK: iree_linalg_ext.arg_compare

0 commit comments

Comments (0)