@@ -4134,24 +4134,23 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4134
4134
bool changed = false ;
4135
4135
SmallVector<bool , 4 > newInBounds;
4136
4136
newInBounds.reserve (op.getTransferRank ());
4137
+ // Idxs of non-bcast dims - used when analysing bcast dims.
4137
4138
SmallVector<unsigned > nonBcastDims;
4139
+
4140
+ // 1. Process non-broadcast dims
4138
4141
for (unsigned i = 0 ; i < op.getTransferRank (); ++i) {
4139
- // 1. Already marked as in-bounds, nothing to see here.
4142
+ // 1.1. Already marked as in-bounds, nothing to see here.
4140
4143
if (op.isDimInBounds (i)) {
4141
4144
newInBounds.push_back (true );
4142
4145
continue ;
4143
4146
}
4144
- // 2. Currently out-of-bounds, check whether we can statically determine it
4145
- // is inBounds.
4147
+ // 1. 2. Currently out-of-bounds, check whether we can statically determine
4148
+ // it is inBounds.
4146
4149
bool inBounds = false ;
4147
4150
auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult (i));
4148
4151
if (dimExpr) {
4149
- // 2.a Non-broadcast dim
4150
4152
inBounds = isInBounds (op, /* resultIdx=*/ i,
4151
4153
/* indicesIdx=*/ dimExpr.getPosition ());
4152
- // 2.b Broadcast dims are handled after processing non-bcast dims
4153
- // FIXME: constant expr != 0 are not broadcasts - should such
4154
- // constants be allowed at all?
4155
4154
nonBcastDims.push_back (i);
4156
4155
}
4157
4156
@@ -4160,15 +4159,17 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4160
4159
changed |= inBounds;
4161
4160
}
4162
4161
4163
- // Handle broadcast dims: if all non-broadcast dims are "in
4164
- // bounds", then all bcast dims should be "in bounds" as well.
4162
+ // 2. Handle broadcast dims
4163
+ // If all non-broadcast dims are "in bounds", then all bcast dims should be
4164
+ // "in bounds" as well.
4165
4165
bool allNonBcastDimsInBounds = llvm::all_of (
4166
4166
nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4167
- if (allNonBcastDimsInBounds)
4168
- llvm::for_each ( permutationMap.getBroadcastDims (), [&]( unsigned idx ) {
4167
+ if (allNonBcastDimsInBounds) {
4168
+ for ( size_t idx : permutationMap.getBroadcastDims ()) {
4169
4169
changed |= !newInBounds[idx];
4170
4170
newInBounds[idx] = true ;
4171
- });
4171
+ }
4172
+ }
4172
4173
4173
4174
if (!changed)
4174
4175
return failure ();
0 commit comments