diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..189bf7f619888 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1208,11 +1208,22 @@ static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
   if (xferOp.getMaskType().getRank() > 1) {
     // Unpack one dimension of the mask.
     OpBuilder::InsertionGuard guard(b);
+    Location loc = xferOp.getLoc();
     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
 
+    auto expr = dyn_cast<AffineDimExpr>(
+        compressUnusedDims(xferOp.getPermutationMap()).getResult(0));
+    assert(expr && "cannot extract from dimension");
+    // Transpose dim to be the outermost dimension, so we can use
+    // vector.extract on it.
+    TypedValue<VectorType> mask = xferOp.getMask();
+    SmallVector<int64_t> perm =
+        llvm::to_vector(llvm::seq<int64_t>(mask.getType().getRank()));
+    std::swap(perm[0], perm[expr.getPosition()]);
+    mask = b.create<vector::TransposeOp>(loc, mask, perm);
+    // Extract from the transposed mask.
     llvm::SmallVector<int64_t, 1> indices({i});
-    Location loc = xferOp.getLoc();
-    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
+    auto newMask = b.create<vector::ExtractOp>(loc, mask, indices);
     newXferOp.getMaskMutable().assign(newMask);
   }
 
diff --git a/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir
index 7d97829c06599..8aa72086e4e0e 100644
--- a/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir
@@ -84,3 +84,33 @@ func.func @transfer_read_mask(%A : memref<?x?x?xf32>, %mask : vector<2x3x4xi1>)
   %vec = vector.transfer_read %A[%c0, %c0, %c0], %f0, %mask {in_bounds = [true, true, true]}: memref<?x?x?xf32>, vector<2x3x4xf32>
   return %vec : vector<2x3x4xf32>
 }
+
+// -----
+
+func.func @transfer_read_perm_mask(%A : memref<?x?x?x?xf32>, %mask : vector<3x2x4xi1>) -> (vector<2x3x4xf32>) {
+  %f0 = arith.constant 0.0: f32
+  %c0 = arith.constant 0: index
+
+  // CHECK: vector.extract %{{.*}}[0, 0] : vector<4xi1> from vector<3x2x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[1, 0] : vector<4xi1> from vector<3x2x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[2, 0] : vector<4xi1> from vector<3x2x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[0, 1] : vector<4xi1> from vector<3x2x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[1, 1] : vector<4xi1> from vector<3x2x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 1] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[2, 1] : vector<4xi1> from vector<3x2x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 2] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NOT: scf.if
+  // CHECK-NOT: scf.for
+  %vec = vector.transfer_read %A[%c0, %c0, %c0, %c0], %f0, %mask {permutation_map = affine_map<(d0, d1, d2, d4) -> (d2, d0, d4)>, in_bounds = [true, true, true]}: memref<?x?x?x?xf32>, vector<2x3x4xf32>
+  return %vec : vector<2x3x4xf32>
+}
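
The transpose-then-extract sequence added to maybeAssignMask can be illustrated with a small sketch (the %t and %m names below are illustrative, not taken from the patch). For the test's permutation_map affine_map<(d0, d1, d2, d4) -> (d2, d0, d4)>, compressUnusedDims drops the unused second dim, so the unpacked result dimension sits at position 1 of the vector<3x2x4xi1> mask; for unrolled index 0 the new code then emits roughly:

  %t = vector.transpose %mask, [1, 0, 2] : vector<3x2x4xi1> to vector<2x3x4xi1>
  %m = vector.extract %t[0] : vector<3x4xi1> from vector<2x3x4xi1>

In the fully unrolled output the transpose is folded away, which is why the CHECK lines above expect plain vector.extract ops with permuted indices taken directly from the vector<3x2x4xi1> mask.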