@@ -5176,13 +5176,14 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5176
5176
return llvm::to_vector<4 >(getVectorType ().getShape ());
5177
5177
}
5178
5178
5179
- static LogicalResult isContiguousIndices (Value val) {
5180
- auto vecType = dyn_cast<VectorType>(val.getType ());
5179
+ // / Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
5180
+ static LogicalResult isContiguousIndices (Value indexVec) {
5181
+ auto vecType = dyn_cast<VectorType>(indexVec.getType ());
5181
5182
if (!vecType || vecType.getRank () != 1 || vecType.isScalable ())
5182
5183
return failure ();
5183
5184
5184
5185
DenseIntElementsAttr elements;
5185
- if (!matchPattern (val , m_Constant (&elements)))
5186
+ if (!matchPattern (indexVec , m_Constant (&elements)))
5186
5187
return failure ();
5187
5188
5188
5189
return success (
@@ -5208,6 +5209,8 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
5208
5209
}
5209
5210
};
5210
5211
5212
+ // / Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5213
+ // / maskedload. Only 1D non-scalable vectors are supported for now.
5211
5214
class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
5212
5215
public:
5213
5216
using OpRewritePattern::OpRewritePattern;
@@ -5269,6 +5272,8 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5269
5272
}
5270
5273
};
5271
5274
5275
+ // / Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5276
+ // / maskedstore. Only 1D non-scalable vectors are supported for now.
5272
5277
class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
5273
5278
public:
5274
5279
using OpRewritePattern::OpRewritePattern;
0 commit comments