[mlir][vector] Canonicalize gathers/scatters with trivial offsets #117939
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes: Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.

Full diff: https://github.com/llvm/llvm-project/pull/117939.diff (2 files affected)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0c0a7bc98d8b5e..21e62085be5a49 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5181,6 +5181,19 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
+static LogicalResult isContiguousIndices(Value val) {
+ auto vecType = dyn_cast<VectorType>(val.getType());
+ if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
+ return failure();
+
+ DenseIntElementsAttr elements;
+ if (!matchPattern(val, m_Constant(&elements)))
+ return failure();
+
+ return success(
+ llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
+}
+
namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
@@ -5199,11 +5212,26 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
}
};
+
+class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(GatherOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(isContiguousIndices(op.getIndexVec())))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
+ op.getIndices(), op.getMask(),
+ op.getPassThru());
+ return success();
+ }
+};
} // namespace
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<GatherFolder>(context);
+ results.add<GatherFolder, GatherTrivialIndices>(context);
}
//===----------------------------------------------------------------------===//
@@ -5245,11 +5273,25 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
}
};
+
+class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ScatterOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(isContiguousIndices(op.getIndexVec())))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<MaskedStoreOp>(
+ op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+ return success();
+ }
+};
} // namespace
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ScatterFolder>(context);
+ results.add<ScatterFolder, ScatterTrivialIndices>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..b4f9d98e729771 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2826,3 +2826,34 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
return %1 : vector<1x1x2x1x1x1xi32>
}
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: return %[[R]]
+func.func @contiguous_gather(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+ memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+func.func @contiguous_scatter(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>){
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ vector.scatter %base[%c0][%indices], %mask, %value :
+ memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+ return
+}
Awesome, thanks! I took a quick look and left some comments for now.
static LogicalResult isContiguousIndices(Value val) {
Could you add documentation about the currently supported cases and limitations?
nit: val -> indices, indexVec, ...?
DenseIntElementsAttr elements;
if (!matchPattern(val, m_Constant(&elements)))
We may want to do something for ConstantMaskOp
return success(
    llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
What about contiguous indices with a different start number?
return success(
    llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
@banach-space, is there a common utility that we can use here and for the extract op in the Linalg vectorizer?
Not yet - the vectorizer looks at the scalar indices before vectorization. However, this patch makes me think that we could do better 🤔 Let me look into this!
@Hardcode84 Would you have some linalg examples that vectorize into these contiguous gathers? That would be helpful, but no worries if Vector is your actual starting point here.
func.func @contiguous_gather(%base: memref<?xf32>,
    %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
Could you add tests for:
- Start index != 0
- ConstantMaskOp
- constant indices that describe a broadcast (e.g., [3, 3, 3, 3, 3, 3, ... 3])
- We don't need any special handling for a constant mask, as it is already handled by the existing masked -> non-masked canonicalizations; added a couple of tests.
- I can add support for a non-zero start, but broadcast is more involved.
- For scatters, duplicated indices are undefined per the current spec.
- For gathers we would need reduce(mask) + a 1-element vector.maskedload + extract + splat, and I would rather not do this as part of this PR.
broadcast is more involved
Could you add a test and a TODO? The test would be "negative" (i.e. the folder would leave the code unchanged). For scatter we'd only need to make sure that invalid.mlir contains a relevant test. Could you check that?
I can add support for non-zero start
Could you add a negative test to exercise this case? And a TODO to extend the pattern :)
Invalid scatter indices (and invalid dynamic indices in general) should not fail validation (see https://mlir.llvm.org/getting_started/DeveloperGuide/#ir-verifier), so nothing to add to invalid.mlir
Great, thank you for working on this!
Two high level requests:
- add support for vector.step
- either add support for multi-dim gathers/scatters, or document the limitations.
Maybe, instead of adding a special case for ... UPD: Found the other thread #113655, but either way the situation is not ideal, as we now have semantically identical code represented in 2 different forms, the exact situation which folders/canonicalizers were intended to avoid. (IMO, we should always represent this as constants and ...
I would say the other way around. We may want ...
Force-pushed from 06e4f95 to 8b1f69c
Added
If we really want ...
Actually, I think a better approach may be to have a special ...
Force-pushed from 8b1f69c to 08fc937
Supporting ...
In my specific case (https://github.com/iree-org/iree-turbine/blob/main/iree/turbine/kernel/wave/codegen.py#L808), I generate gathers directly, without going through linalg. Also, as a side note, the indices are constructed from user-provided sympy exprs, so we don't know beforehand if it's a step or not.
Force-pushed from 08fc937 to 8ef0b0e
updated
Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.
Force-pushed from 8ef0b0e to c079482
@Hardcode84 Is this ready for another round of reviews? If yes, could you follow LLVM's code-review guidelines and ping reviewers?
Note, reviewers get a notification every time a PR is updated (and there's a lot of PRs). "ping" is the usual LLVM way to let people know that this is ready for another round :) And, specifically, comments like "updated" tend to be interpreted as noise (it's not clear to me whether "updated" means "I've addressed all PR comments" or just a casual "I've made some changes, but might do some more sometime soon."). Thanks :)
@banach-space yes, PTAL. I understand everyone is busy, but this PR was intended as a trivial improvement and I never planned for it to take multiple months or to support broadcasts/step/scalable vectors/non-1D vectors (all of which have zero benefit for my specific use case). At this point I would prefer to either merge it in its current form or just drop it and move on.
LGTM, thanks!
@dcaballe , looks like tests in canonicalize.mlir cover all TODOs (apart from scalable vectors, but I can handle that). WDYT?
Thanks, sorry for the rant
Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.