@@ -298,16 +298,156 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};

+ /// Transforms a `transfer_read` operation so it reads a vector of a type that
+ /// can be mapped to an LLVM type ("LLVM-legal" type). This is done by
+ /// collapsing trailing dimensions so we obtain a vector type with a single
+ /// scalable dimension in the rightmost position.
+ ///
+ /// Example:
+ /// ```
+ /// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+ ///   {in_bounds = [false, true, true, true]}
+ ///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+ /// ```
+ /// is rewritten to
+ /// ```
+ /// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+ ///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+ /// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+ ///   {in_bounds = [false, true]}
+ ///   : memref<?x?xi8>, vector<2x[64]xi8>
+ /// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+ /// ```
+ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+   using OpRewritePattern::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                 PatternRewriter &rewriter) const override {
+
+     // Do not try to transform masked reads. For example, if we have a transfer
+     // to a `vector<[4]x4xi8>` we could have a mask like
+     //    1 1 1 0
+     //    1 1 1 0
+     //    1 1 1 0
+     //    0 0 0 0
+     // Flattening this mask would look like
+     //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+     // and we have not yet figured out an efficient way to build such a mask,
+     // neither from the mask operand, nor from the original `vector.create_mask`
+     // operation (if visible at all).
+     if (readOp.isMasked() || readOp.getMask())
+       return rewriter.notifyMatchFailure(readOp,
+                                          "masked transfers not-supported");
+
+     // General permutation maps are not supported. The issue is with transpose,
+     // broadcast, and other forms of non-identity mapping in the minor
+     // dimensions, which is impossible to represent after collapsing (at least
+     // because the resulting "collapsed" maps would have a smaller number of
+     // dimension indices).
+     // TODO: We have not yet had the need for it, but some forms of permutation
+     // maps with identity in the minor dimensions could be supported, for
+     // example `(i, j, k, p) -> (j, i, k, p)`, where we need to collapse only
+     // `k` and `p`.
+     if (!readOp.getPermutationMap().isMinorIdentity())
+       return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+     // We handle transfers of vectors with rank >= 2 and a single scalable
+     // dimension. This transformation aims to turn an LLVM-illegal type into an
+     // LLVM-legal one, and one-dimensional vectors are already LLVM-legal, even
+     // if scalable. A value of a vector type with more than one scalable
+     // dimension is impossible to represent using a vector type with no
+     // scalable dimensions or a single one. For example a `vector<[4]x[4]xi8>`
+     // would have `4 * 4 * vscale * vscale` elements and this quantity is
+     // impossible to represent as `N` or `N * vscale` (where `N` is a constant).
+     VectorType origVT = readOp.getVectorType();
+     ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+     const int64_t origVRank = origVT.getRank();
+     if (origVRank < 2 || origVT.getNumScalableDims() != 1)
+       return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+     // Number of trailing dimensions to collapse, including the scalable
+     // dimension. Nothing to do if the single scalable dimension is already the
+     // last one.
+     const int64_t numCollapseDims = std::distance(
+         llvm::find(origScalableDims, true), origScalableDims.end());
+     if (numCollapseDims < 2)
+       return rewriter.notifyMatchFailure(readOp,
+                                          "scalable dimension is trailing");
+
+     // We want a simple memref (not a tensor) with contiguous elements for at
+     // least all the trailing dimensions up to and including the scalable one.
+     auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+     if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+       return rewriter.notifyMatchFailure(
+           readOp, "non-contiguous memref dimensions to collapse");
+
+     // The dimensions to collapse (excluding the scalable one) of the vector
+     // and the memref must match. A dynamic memref dimension is considered
+     // non-matching. The transfers from the dimensions to collapse must be
+     // in-bounds (it follows that the corresponding indices would be zero).
+     // This guarantees that the operation transfers a contiguous block and that
+     // no padding is necessary.
+     if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                      origVT.getShape().take_back(numCollapseDims - 1)))
+       return rewriter.notifyMatchFailure(
+           readOp, "memref and vector dimensions do not match");
+
+     SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+     if (!llvm::all_of(
+             ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+             [](bool v) { return v; }))
+       return rewriter.notifyMatchFailure(
+           readOp, "out-of-bounds transfer from a dimension to collapse");
+
+     // Collapse the trailing dimensions of the memref.
+     SmallVector<ReassociationIndices> reassoc;
+     for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+       reassoc.push_back({i});
+     for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
+          ++i)
+       reassoc.back().push_back(i);
+     if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+       return failure();
+     Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+         readOp.getLoc(), readOp.getBase(), reassoc);
+
+     // Get a vector type with collapsed trailing dimensions.
+     SmallVector<int64_t> shape(origVT.getShape());
+     for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
+       shape[origVRank - numCollapseDims] *= shape[i];
+     shape.pop_back_n(numCollapseDims - 1);
+     auto collapsedVT =
+         VectorType::get(shape, origVT.getElementType(),
+                         origScalableDims.drop_back(numCollapseDims - 1));
+
+     // Drop the extra (zero) indices.
+     auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
+
+     // Create the new `transfer_read`.
+     auto newReadOp = rewriter.create<vector::TransferReadOp>(
+         readOp.getLoc(), collapsedVT, collapsedMem, indices,
+         ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
+
+     // Cast back to the original vector type.
+     auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
+                                                             origVT, newReadOp);
+
+     rewriter.replaceOp(readOp, toOrigShape);
+     return success();
+   }
+ };
+
} // namespace
void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
    RewritePatternSet &patterns) {
-   patterns.add<RelaxScalableVectorAllocaAlignment,
-                LegalizeSVEMaskAllocation<memref::AllocaOp>,
-                LegalizeSVEMaskAllocation<memref::AllocOp>,
-                LegalizeSVEMaskTypeCastConversion,
-                LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
-       patterns.getContext());
+   patterns
+       .add<RelaxScalableVectorAllocaAlignment,
+            LegalizeSVEMaskAllocation<memref::AllocaOp>,
+            LegalizeSVEMaskAllocation<memref::AllocOp>,
+            LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
+            LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
+           patterns.getContext());
}
namespace {
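A minimal sketch of a regression test exercising the new pattern, assuming it is driven by the same `-arm-sve-legalize-vector-storage` pass as the other patterns registered above; the function name and CHECK lines are illustrative only, derived from the example in the pattern's doc comment:

```mlir
// RUN: mlir-opt --arm-sve-legalize-vector-storage %s | FileCheck %s

// The trailing dimensions of the memref are contiguous and match the trailing
// vector dimensions, so the read should be rewritten to a transfer of
// vector<2x[64]xi8> from a collapsed memref<?x?xi8>, followed by a shape_cast
// back to the original type.
// CHECK: memref.collapse_shape
// CHECK-SAME: memref<?x?x2x8xi8> into memref<?x?xi8>
// CHECK: vector.transfer_read {{.*}} : memref<?x?xi8>, vector<2x[64]xi8>
// CHECK: vector.shape_cast {{.*}} : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
func.func @transfer_read_collapsible(%M : memref<?x?x2x8xi8>, %i : index,
                                     %j : index) -> vector<2x[4]x2x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8
  %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
      {in_bounds = [false, true, true, true]}
      : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
  return %v : vector<2x[4]x2x8xi8>
}
```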