
Commit 3b251cd

[MLIR] Legalize certain vector.transfer_read ops of scalable vectors (#143146)
This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single trailing scalable dimension.
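For context, here is a minimal sketch of how these patterns are typically driven. The wrapper function is an illustration, not part of the commit; only `populateLegalizeVectorStoragePatterns` and MLIR's greedy rewrite driver are existing APIs, and the include paths are assumptions that may differ in a given checkout.

// Illustrative sketch (not from the commit): run the ArmSVE vector-storage
// legalization patterns, which now include LegalizeTransferRead, over an op.
// The include paths below are assumptions; adjust to wherever
// populateLegalizeVectorStoragePatterns is declared.
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult legalizeVectorStorage(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Entry point extended by this commit to also register LegalizeTransferRead.
  mlir::arm_sve::populateLegalizeVectorStoragePatterns(patterns);
  // Rewrite to a fixed point with the greedy driver (named
  // applyPatternsAndFoldGreedily in older MLIR releases).
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}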
1 parent 4f9adb6 commit 3b251cd

File tree

3 files changed: +482 -6 lines changed


mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

Lines changed: 146 additions & 6 deletions
@@ -298,16 +298,156 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
   }
 };
 
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type ("LLVM-legal" type). This is done by
+/// collapsing trailing dimensions so we obtain a vector type with a single
+/// scalable dimension in the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///        {in_bounds = [false, true, true, true]}
+///        : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///        : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///        {in_bounds = [false, true]}
+///        : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a transfer
+    // to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original `vector.create_mask`
+    // operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not-supported");
+
+    // General permutation maps are not supported. The issue is with transpose,
+    // broadcast, and other forms of non-identity mapping in the minor
+    // dimensions, which is impossible to represent after collapsing (at least
+    // because the resulting "collapsed" maps would have a smaller number of
+    // dimension indices).
+    // TODO: We have not yet had the need for it, but some forms of permutation
+    // maps with identity in the minor dimensions could be supported, for
+    // example `(i, j, k, p) -> (j, i, k, p)` where we need to collapse only `k`
+    // and `p`.
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension. This transformation aims to transform an LLVM-illegal type
+    // into an LLVM-legal type, and one-dimensional vectors are already
+    // LLVM-legal, even if scalable. A value of a vector type with more than one
+    // scalable dimension is impossible to represent using a vector type with no
+    // scalable dimensions or a single one. For example a `vector<[4]x[4]xi8>`
+    // would have `4 * 4 * vscale * vscale` elements and this quantity is
+    // impossible to represent as `N` or `N * vscale` (where `N` is a constant).
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || origVT.getNumScalableDims() != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already the
+    // last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The dimensions to collapse (excluding the scalable one) of the vector and
+    // the memref must match. A dynamic memref dimension is considered
+    // non-matching. The transfers from the dimensions to collapse must be
+    // in-bounds (it follows that the corresponding indices would be zero). This
+    // guarantees that the operation transfers a contiguous block
+    // and no padding is necessary.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(
+          readOp, "out-of-bounds transfer from a dimension to collapse");
+
+    // Collapse the trailing dimensions of the memref.
+    SmallVector<ReassociationIndices> reassoc;
+    for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+      reassoc.push_back({i});
+    for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
+         ++i)
+      reassoc.back().push_back(i);
+    if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+      return failure();
+    Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+        readOp.getLoc(), readOp.getBase(), reassoc);
+
+    // Get a vector type with collapsed trailing dimensions.
+    SmallVector<int64_t> shape(origVT.getShape());
+    for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
+      shape[origVRank - numCollapseDims] *= shape[i];
+    shape.pop_back_n(numCollapseDims - 1);
+    auto collapsedVT =
+        VectorType::get(shape, origVT.getElementType(),
+                        origScalableDims.drop_back(numCollapseDims - 1));
+
+    // Drop the extra (zero) indices.
+    auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
+
+    // Create the new `transfer_read`.
+    auto newReadOp = rewriter.create<vector::TransferReadOp>(
+        readOp.getLoc(), collapsedVT, collapsedMem, indices,
+        ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
+
+    // Cast back to the original vector type.
+    auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
+                                                            origVT, newReadOp);
+
+    rewriter.replaceOp(readOp, toOrigShape);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<RelaxScalableVectorAllocaAlignment,
-               LegalizeSVEMaskAllocation<memref::AllocaOp>,
-               LegalizeSVEMaskAllocation<memref::AllocOp>,
-               LegalizeSVEMaskTypeCastConversion,
-               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
-      patterns.getContext());
+  patterns
+      .add<RelaxScalableVectorAllocaAlignment,
+           LegalizeSVEMaskAllocation<memref::AllocaOp>,
+           LegalizeSVEMaskAllocation<memref::AllocOp>,
+           LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
+           LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
+          patterns.getContext());
 }
 
 namespace {
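
As a sanity check on the collapse arithmetic above, here is a small standalone sketch (plain C++, no MLIR dependency, not part of the commit) that reproduces the shape computation for the doc-comment example, collapsing vector<2x[4]x2x8xi8> into vector<2x[64]xi8>:

// Standalone illustration of the trailing-dimension collapse performed by
// LegalizeTransferRead for vector<2x[4]x2x8xi8> over memref<?x?x2x8xi8>.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // vector<2x[4]x2x8xi8>: shape and per-dimension scalability.
  std::vector<int64_t> vecShape = {2, 4, 2, 8};
  std::vector<bool> scalable = {false, true, false, false};

  // Number of trailing dimensions to collapse, counted from the single
  // scalable dimension (mirrors the std::distance(llvm::find(...), end())
  // computation in the pattern).
  int64_t rank = static_cast<int64_t>(vecShape.size());
  int64_t numCollapseDims = 0;
  for (int64_t i = 0; i < rank; ++i)
    if (scalable[i])
      numCollapseDims = rank - i;
  assert(numCollapseDims == 3);

  // Collapsed vector shape: fold the trailing dims into the scalable one.
  for (int64_t i = rank - numCollapseDims + 1; i < rank; ++i)
    vecShape[rank - numCollapseDims] *= vecShape[i];
  vecShape.resize(rank - (numCollapseDims - 1));

  // {2, 64} with scalability {false, true}, i.e. vector<2x[64]xi8>.
  assert(vecShape.size() == 2 && vecShape[0] == 2 && vecShape[1] == 64);
  return 0;
}

On the memref side the same three trailing dimensions get the reassociation [[0], [1, 2, 3]], which is exactly what memref.collapse_shape receives in the rewritten IR shown in the doc comment.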

0 commit comments
