Skip to content

Commit cd0f43b

Browse files
committed
fixup! Refine how bcast dims are handled
Address PR suggestions, update comments, added new tests
1 parent b23b6a9 commit cd0f43b

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4134,24 +4134,23 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
41344134
bool changed = false;
41354135
SmallVector<bool, 4> newInBounds;
41364136
newInBounds.reserve(op.getTransferRank());
4137+
// Idxs of non-bcast dims - used when analysing bcast dims.
41374138
SmallVector<unsigned> nonBcastDims;
4139+
4140+
// 1. Process non-broadcast dims
41384141
for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4139-
// 1. Already marked as in-bounds, nothing to see here.
4142+
// 1.1. Already marked as in-bounds, nothing to see here.
41404143
if (op.isDimInBounds(i)) {
41414144
newInBounds.push_back(true);
41424145
continue;
41434146
}
4144-
// 2. Currently out-of-bounds, check whether we can statically determine it
4145-
// is inBounds.
4147+
// 1.2. Currently out-of-bounds, check whether we can statically determine
4148+
// it is inBounds.
41464149
bool inBounds = false;
41474150
auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
41484151
if (dimExpr) {
4149-
// 2.a Non-broadcast dim
41504152
inBounds = isInBounds(op, /*resultIdx=*/i,
41514153
/*indicesIdx=*/dimExpr.getPosition());
4152-
// 2.b Broadcast dims are handled after processing non-bcast dims
4153-
// FIXME: constant expr != 0 are not broadcasts - should such
4154-
// constants be allowed at all?
41554154
nonBcastDims.push_back(i);
41564155
}
41574156

@@ -4160,15 +4159,17 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
41604159
changed |= inBounds;
41614160
}
41624161

4163-
// Handle broadcast dims: if all non-broadcast dims are "in
4164-
// bounds", then all bcast dims should be "in bounds" as well.
4162+
// 2. Handle broadcast dims
4163+
// If all non-broadcast dims are "in bounds", then all bcast dims should be
4164+
// "in bounds" as well.
41654165
bool allNonBcastDimsInBounds = llvm::all_of(
41664166
nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4167-
if (allNonBcastDimsInBounds)
4168-
llvm::for_each(permutationMap.getBroadcastDims(), [&](unsigned idx) {
4167+
if (allNonBcastDimsInBounds) {
4168+
for (size_t idx : permutationMap.getBroadcastDims()) {
41694169
changed |= !newInBounds[idx];
41704170
newInBounds[idx] = true;
4171-
});
4171+
}
4172+
}
41724173

41734174
if (!changed)
41744175
return failure();

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,15 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
454454

455455
// -----
456456

457+
func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
458+
%c3 = arith.constant 3 : index
459+
%cst = arith.constant 3.0 : f32
460+
// expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}}
461+
%0 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(1)>} : memref<?x?xf32>, vector<128xf32>
462+
}
463+
464+
// -----
465+
457466
func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
458467
%c3 = arith.constant 3 : index
459468
%cst = arith.constant 3.0 : f32
@@ -608,6 +617,15 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
608617

609618
// -----
610619

620+
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
621+
%c3 = arith.constant 3 : index
622+
%cst = arith.constant dense<3.0> : vector<128 x f32>
623+
// expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}}
624+
vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(1)>} : vector<128xf32>, memref<?x?xf32>
625+
}
626+
627+
// -----
628+
611629
func.func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
612630
%c3 = arith.constant 3 : index
613631
%cst = arith.constant dense<3.0> : vector<3 x 7 x f32>

0 commit comments

Comments
 (0)