
[mlir][vector] Canonicalize gathers/scatters with trivial offsets #117939


Merged: 7 commits merged into llvm:main from gather-canon on Jan 24, 2025

Conversation

Hardcode84 (Contributor):

Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.

llvmbot (Member) commented Nov 27, 2024:

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.
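For illustration, a minimal before/after sketch of the rewrite (a 4-lane variant of the tests in the diff below; %base, %c0, %mask and %passthru are assumed to be defined):

// Before: a gather whose offsets are the trivial sequence [0, 1, 2, 3].
%indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
%v = vector.gather %base[%c0][%indices], %mask, %passthru :
  memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>

// After: contiguous offsets starting at 0 are exactly a masked load.
%v = vector.maskedload %base[%c0], %mask, %passthru :
  memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>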


Full diff: https://github.com/llvm/llvm-project/pull/117939.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+44-2)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+31)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0c0a7bc98d8b5e..21e62085be5a49 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5181,6 +5181,19 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+static LogicalResult isContiguousIndices(Value val) {
+  auto vecType = dyn_cast<VectorType>(val.getType());
+  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
+    return failure();
+
+  DenseIntElementsAttr elements;
+  if (!matchPattern(val, m_Constant(&elements)))
+    return failure();
+
+  return success(
+      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
+}
+
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:
@@ -5199,11 +5212,26 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
     llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
   }
 };
+
+class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(isContiguousIndices(op.getIndexVec())))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
+                                              op.getIndices(), op.getMask(),
+                                              op.getPassThru());
+    return success();
+  }
+};
 } // namespace
 
 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<GatherFolder>(context);
+  results.add<GatherFolder, GatherTrivialIndices>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5245,11 +5273,25 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
     llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
   }
 };
+
+class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(isContiguousIndices(op.getIndexVec())))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
+        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+    return success();
+  }
+};
 } // namespace
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ScatterFolder>(context);
+  results.add<ScatterFolder, ScatterTrivialIndices>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..b4f9d98e729771 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2826,3 +2826,34 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
   return %1 : vector<1x1x2x1x1x1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+//       CHECK:   return %[[R]]
+func.func @contiguous_gather(%base: memref<?xf32>,
+                             %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+func.func @contiguous_scatter(%base: memref<?xf32>,
+                              %mask: vector<16xi1>, %value: vector<16xf32>){
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  vector.scatter %base[%c0][%indices], %mask, %value :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  return
+}

dcaballe (Contributor) left a comment:

Awesome, thanks! I took a quick look and left some comments for now.

@@ -5181,6 +5181,19 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

static LogicalResult isContiguousIndices(Value val) {
Contributor:

Could you add documentation about the currently supported cases and limitations?

nit: val -> indices, indexVec ... ?

return failure();

DenseIntElementsAttr elements;
if (!matchPattern(val, m_Constant(&elements)))
Contributor:

We may want to do something for ConstantMaskOp
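(For reference, a minimal sketch of the op in question:)

// vector.constant_mask sets a leading prefix of lanes, here the first 8 of 16.
%mask = vector.constant_mask [8] : vector<16xi1>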

return failure();

return success(
llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
Contributor:

What about contiguous indices with a different start number?
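For illustration, a hypothetical index vector that is contiguous but starts at 4, which the check above currently rejects:

// Contiguous with start 4; in principle this could still fold into a masked
// load/store by adding the start offset to the base index.
%indices = arith.constant dense<[4, 5, 6, 7]> : vector<4xi32>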


return success(
llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
Contributor:

@banach-space, is there a common utility that we can use here and for the extract op in the Linalg vectorizer?

banach-space (Contributor):

Not yet - the vectorizer looks at the scalar indices before vectorization. However, this patch makes me think that we could do better 🤔 Let me look into this!

Contributor:

@Hardcode84 Would you have some linalg examples that vectorize into these contiguous gathers? That would be helpful, but no worries if Vector is your actual starting point here.

func.func @contiguous_gather(%base: memref<?xf32>,
%mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
Contributor:

Could you add tests for:

  • Start index != 0
  • ConstantMaskOp
  • constant indices that describe a broadcast (e.g., [3, 3, 3, 3, 3, 3... 3])

Hardcode84 (Author):

  • We don't need any special handling for constant masks, as that is already handled by the existing masked -> non-masked canonicalizations; added a couple of tests.
  • I can add support for non-zero start, but broadcast is more involved:
    • For scatters, duplicated indices are undefined per the current spec.
    • For gathers, we need reduce(mask) + a 1-element vector.maskedload + extract + splat, and I would rather not do this as part of this PR (a sketch follows this list).
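A rough sketch of the gather-of-broadcast expansion described in the last bullet, assuming a 16-lane gather with all indices equal to 3 (all names are hypothetical):

// Load the one element only if at least one lane is active...
%any = vector.reduction <or>, %mask : vector<16xi1> into i1
%m1 = vector.broadcast %any : i1 to vector<1xi1>
%pt1 = arith.constant dense<0.0> : vector<1xf32>
%c3 = arith.constant 3 : index
%ld = vector.maskedload %base[%c3], %m1, %pt1 :
  memref<?xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
// ...then splat it and re-select the original pass-through lanes.
%elem = vector.extract %ld[0] : f32 from vector<1xf32>
%splat = vector.splat %elem : vector<16xf32>
%res = arith.select %mask, %splat, %passthru : vector<16xi1>, vector<16xf32>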

Contributor:

broadcast is more involved

Could you add a test and a TODO? The test would be "negative" (i.e. the folder would leave the code unchanged). For the "scatter" case, we would only need to make sure that invalid.mlir contains a relevant test. Could you check that?

I can add support for non-zero start

Could you add a negative test to exercise this case? And a TODO to extend the pattern :)

Hardcode84 (Author):

Invalid scatter indices (and invalid dynamic indices in general) should not fail verification (see https://mlir.llvm.org/getting_started/DeveloperGuide/#ir-verifier), so there is nothing to add to invalid.mlir.

banach-space (Contributor) left a comment:

Great, thank you for working on this!

Two high level requests:

  • add support for vector.step
  • either add support for multi-dim gathers/scatters, or document the limitations.

Hardcode84 (Author) commented Nov 29, 2024:

add support for vector.step

Maybe, instead of adding a special case for vector.step, just fold all non-scalable vector.step ops into arith.constant? It would probably help other canonicalizations as well.

UPD: Found the other thread, #113655, but either way the situation is not ideal: we now have semantically identical code represented in two different forms, exactly the situation folders/canonicalizers were intended to avoid. (IMO, we should always represent this as constants, and the constant -> step optimization should happen much later, probably at the LLVM/SPIR-V level, but I'm not willing to die on this hill.)
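For fixed-size vectors, the two forms in question are semantically identical, e.g.:

// A step op and its dense-constant equivalent.
%s = vector.step : vector<4xindex>
%c = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>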

dcaballe (Contributor):

Maybe, instead of adding a special case for vector.step, just fold all non-scalable vector.step ops into arith.constant? It would probably help other canonicalizations as well.

I would say the other way around. We may want vector.step to be the canonical form for vector constants that fall into that category, as vector.step holds valuable information that otherwise has to be inferred by inspecting the values of a constant op.

Hardcode84 (Author) commented Dec 27, 2024:

Added vector.step support and rebased, but I don't want to handle non-zero start offsets or broadcasts as part of this PR.

I would say the other way around. We may want vector.step to be the canonical form for vector constants that fall into that category, as vector.step holds valuable information that otherwise has to be inferred by inspecting the values of a constant op.

If we really want vector.step to be the canonical representation, we will need to always canonicalize [0, 1, 2, ...] constants to it (and that has to be part of the canonicalize pass, not just some random set of patterns); otherwise, relevant patterns like this one will always need to check both forms. I'm still not convinced vector.step (for non-scalable vectors) is useful, as checking for the relevant constant pattern is trivial and actually gives you more freedom (e.g., you can check for the non-zero start offsets mentioned earlier), but I will leave this fight to someone else.

Hardcode84 (Author):

Actually, I think a better approach may be a special StepElementsAttr, analogous to the existing SplatElementsAttr, which would be "canonicalized" on DenseElementsAttr construction just like SplatElementsAttr is. This way we don't need any special ops, and users that don't know or don't care about StepElementsAttr can still work with it as a normal dense constant.


banach-space (Contributor):

Actually, I think a better approach may be a special StepElementsAttr, analogous to the existing SplatElementsAttr, which would be "canonicalized" on DenseElementsAttr construction just like SplatElementsAttr is. This way we don't need any special ops, and users that don't know or don't care about StepElementsAttr can still work with it as a normal dense constant.

Supporting vector.step is important - that's the only option for scalable vectors. In fact, since you added support for vector.step, your changes should also work for scalable vectors.
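For illustration, a hypothetical scalable-vector variant, where vector.step is the only way to express the [0, 1, 2, ...] indices:

func.func @scalable_gather(%base: memref<?xf32>, %mask: vector<[16]xi1>,
                           %passthru: vector<[16]xf32>) -> vector<[16]xf32> {
  %c0 = arith.constant 0 : index
  // The runtime vector length is unknown, so the indices cannot be a
  // dense constant; vector.step produces [0, 1, ..., vscale * 16 - 1].
  %indices = vector.step : vector<[16]xindex>
  %0 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<[16]xindex>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
  return %0 : vector<[16]xf32>
}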

Hardcode84 (Author):

@Hardcode84 Would you have some linalg examples that vectorize into these contiguous gathers? That would be helpful, but no worries if Vector is your actual starting point here.

In my specific case (https://github.com/iree-org/iree-turbine/blob/main/iree/turbine/kernel/wave/codegen.py#L808), I generate gathers directly, without going through linalg. As a side note, the indices are constructed from user-provided sympy expressions, so we don't know beforehand whether they form a step pattern or not.

Hardcode84 (Author):

updated

banach-space (Contributor) commented Jan 21, 2025:

@Hardcode84 Is this ready for another round of reviews? If yes, could you follow LLVM's code-review guidelines and ping reviewers?

Ping the patch. If it is urgent, provide reasons why it is important to you to get this patch landed and ping it every couple of days. If it is not urgent, the common courtesy ping rate is one week. Remember that you’re asking for valuable time from other professional developers.

Note, reviewers get a notification every time a PR is updated (and there are a lot of PRs). "ping" is the usual LLVM way to let people know that a PR is ready for another round :) And, specifically, things like:

updated

tend to be interpreted as noise (it's not clear to me whether "updated" means "I've addressed all PR comments" or just a casual "I've made some changes, but might do some more sometime soon").

Thanks :)
-Andrzej

Hardcode84 (Author):

@banach-space yes, PTAL.

I understand everyone is busy, but this PR was intended as a trivial improvement; I never planned for it to take multiple months or to support broadcasts/step/scalable vectors/non-1D vectors (all of which have zero benefit for my specific use case). At this point I would prefer to either merge it in its current form or just drop it and move on.

banach-space (Contributor) left a comment:

LGTM, thanks!

@dcaballe, it looks like the tests in canonicalize.mlir cover all the TODOs (apart from scalable vectors, but I can handle that). WDYT?

Hardcode84 (Author):

Thanks, sorry for the rant

Hardcode84 merged commit 88136f9 into llvm:main on Jan 24, 2025 (8 checks passed).
Hardcode84 deleted the gather-canon branch on Jan 24, 2025, 11:14.