Skip to content

[mlir][vector] Clarify the semantics of BroadcastOp #101928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ enum class BroadcastableToResult {
DimensionMismatch = 2,
SourceTypeNotAVector = 3
};
struct VectorDim {
int64_t dim;
bool scalableFlag;
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was this MR from @MacDue , implementing similar features. I dont know why it got closed though.
#96236

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is unrelated to that discussion. I'm only adding this here to avoid adding new set of params to isBroadcastableTo.

I believe that before we commit to any new wider API, we should discuss the internal representation of VectorType and how scalable dimensions are represented. I am working on a proposal, but that's not yet ready to share 😅 I'm hoping to have something in the coming weeks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouuh, exciting.

BroadcastableToResult
isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims = nullptr);
std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);

/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ def Vector_BroadcastOp :
s_1 x .. x s_j x .. x s_k
<duplication> <potential stretch>
```
* a scalable unit dimeension, `[1]`, must match exactly.

The source operand is duplicated over all the missing leading dimensions
and stretched over the trailing dimensions where the source has a non-equal
dimension of 1. These rules imply that any scalar broadcast (k=0) to any
Expand Down
45 changes: 35 additions & 10 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
return res;
}

BroadcastableToResult
mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims) {
BroadcastableToResult mlir::vector::isBroadcastableTo(
Type srcType, VectorType dstVectorType,
std::pair<VectorDim, VectorDim> *mismatchingDims) {
// Broadcast scalar to vector of the same element type.
if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
Expand All @@ -2391,12 +2391,28 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
// (all leading dimensions are simply duplicated).
int64_t lead = dstRank - srcRank;
for (int64_t r = 0; r < srcRank; ++r) {
bool mismatch = false;

// Check fixed-width dims
int64_t srcDim = srcVectorType.getDimSize(r);
int64_t dstDim = dstVectorType.getDimSize(lead + r);
if (srcDim != 1 && srcDim != dstDim) {
if ((srcDim != 1 && srcDim != dstDim))
mismatch = true;

// Check scalable flags
bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
(srcDimScalableFlag && !dstDimScalableFlag))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It got me thinking, what would be the expected behaviour of something like:

 %0 = vector.broadcast %arg0 : vector<nxf32> to vector<[n]xf32>

IMO it should not be supported as physically equivalent to a usecase

%1 = vector.broadcast %arg0 : vector<nxf32> to vector<vscale*nxf32>

Which is not invalid for fixed dimensions. Do you think this handles the cases ?

Suggested change
(srcDimScalableFlag && !dstDimScalableFlag))
(srcDimScalableFlag != dstDimScalableFlag))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have e.g. [2] and [4] (i.e. vscale * 2 and vscale * 4), then that's already "rejected" as "mismatching dims":

Is that the case you had in mind?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case I pointed out was more src = 2 and dest = [2]. srcDim == dstDim, so no mismatch on line 2399. and we have !srcDimScalableFlag so no mismatch on line 2406. Whereas I think this is wrong.

 %0 = vector.broadcast %arg0 : vector<2xf32> to vector<[2]xf32>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, nice, great catch! In my head I had one case that wouldn't work with !=, but now I am failing to recall that 😂

Let me send an update - thanks very much for pointing this out 🙏🏻

mismatch = true;

if (mismatch) {
if (mismatchingDims) {
mismatchingDims->first = srcDim;
mismatchingDims->second = dstDim;
mismatchingDims->first.dim = srcDim;
mismatchingDims->first.scalableFlag = srcDimScalableFlag;

mismatchingDims->second.dim = dstDim;
mismatchingDims->second.scalableFlag = dstDimScalableFlag;
}
return BroadcastableToResult::DimensionMismatch;
}
Expand All @@ -2406,16 +2422,25 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
}

LogicalResult BroadcastOp::verify() {
std::pair<int, int> mismatchingDims;
std::pair<VectorDim, VectorDim> mismatchingDims;
BroadcastableToResult res = isBroadcastableTo(
getSourceType(), getResultVectorType(), &mismatchingDims);
if (res == BroadcastableToResult::Success)
return success();
if (res == BroadcastableToResult::SourceRankHigher)
return emitOpError("source rank higher than destination rank");
if (res == BroadcastableToResult::DimensionMismatch)
return emitOpError("dimension mismatch (")
<< mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
if (res == BroadcastableToResult::DimensionMismatch) {
std::string msg =
(Twine("dimension mismatch (") +
(mismatchingDims.first.scalableFlag ? "[" : "") +
std::to_string(mismatchingDims.first.dim) +
(mismatchingDims.first.scalableFlag ? "]" : "") + " vs. " +
(mismatchingDims.second.scalableFlag ? "[" : "") +
std::to_string(mismatchingDims.second.dim) +
(mismatchingDims.second.scalableFlag ? "]" : "") + ")")
.str();
return emitOpError(msg);
}
if (res == BroadcastableToResult::SourceTypeNotAVector)
return emitOpError("source type is not a vector");
llvm_unreachable("unexpected vector.broadcast op error");
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ func.func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {

// -----

func.func @broadcast_scalable_unit_dim(%arg0: vector<[1]xf32>) {
// expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. [4])}}
%0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
}

// -----

func.func @broadcast_scalable_to_fixed(%arg0: vector<[1]xf32>) {
// expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. 1)}}
%0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x1xf32>
}

// -----

func.func @broadcast_unknown(%arg0: memref<4x8xf32>) {
// expected-error@+1 {{'vector.broadcast' op source type is not a vector}}
%1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>
Expand Down
Loading