Merged
Changes from 3 commits
220 changes: 210 additions & 10 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,9 +15,12 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>

@@ -52,6 +55,21 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
return map;
}

static int getDistributedDim(VectorType origType, VectorType distributedType) {
assert(origType.getRank() == distributedType.getRank() &&
"sequential and distributed vector types must have the same rank");
int64_t distributedDim = -1;
for (int64_t i = 0; i < origType.getRank(); ++i) {
if (distributedType.getDimSize(i) != origType.getDimSize(i)) {
// Keep this assert here in case WarpExecuteOnLane0Op gets extended to
// support distributing multiple dimensions in the future.
assert(distributedDim == -1 && "found multiple distributed dims");
Contributor: How about returning a failure if there is more than one dim mismatch? It could avoid crashing the pass.

Contributor (Author): This code was already there; I only moved it into a function for reuse. The motivation for the assert is that the pass strictly assumes only one dimension is distributed; the assert is there so that more support can be added later. Let's keep it for now so that the crash stays isolated to this pass.

distributedDim = i;
}
}
return distributedDim;
}
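
To make the helper's contract concrete, here is a small illustration (a hypothetical example, not part of the patch; the function name is made up): with a warp of 32 lanes, a value yielded as vector<16x32xf32> but returned from the warp op as vector<16x1xf32> differs only in dimension 1, so getDistributedDim would return 1.

```mlir
// Hypothetical illustration: the yielded (sequential) type vector<16x32xf32>
// and the warp-op result (distributed) type vector<16x1xf32> differ only in
// dim 1, so getDistributedDim(vector<16x32xf32>, vector<16x1xf32>) == 1.
func.func @distributed_dim_example(%laneid: index) -> (vector<16x1xf32>) {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>) {
    %v = "some_def"() : () -> (vector<16x32xf32>)
    gpu.yield %v : vector<16x32xf32>
  }
  return %r : vector<16x1xf32>
}
```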

namespace {

/// Helper struct to create the load / store operations that permit transit
@@ -1076,6 +1094,195 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
}
};

/// Sink out insert_strided_slice op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
/// ...
/// %src = ... : vector<4x16xf32>
/// %dest = ... : vector<8x16xf32>
/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
/// strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
/// gpu.yield %insert : vector<8x16xf32>
/// }
/// ```
/// To
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
/// vector<8x1xf32>) {
/// ...
/// %src = ... : vector<4x16xf32>
/// %dest = ... : vector<8x16xf32>
/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
/// }
/// %insert = vector.insert_strided_slice %0#0, %0#1,
/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
/// ```
/// NOTE: Current support assumes that both the src and dest vectors are
/// distributed to lanes and that sinking the insert op does not require any
/// cross-lane communication.
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto insertOp =
operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
auto distributedType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
// Distributed type must be 2D or higher.
// TODO: Support 1D distributed types.
if (distributedType.getRank() < 2)
return rewriter.notifyMatchFailure(
insertOp, "result vector type must be 2D or higher");
// Find the distributed dimension of the dest vector. There should be
// exactly one.
auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t destDistributedDim =
getDistributedDim(yieldedType, distributedType);
assert(destDistributedDim != -1 && "could not find distributed dimension");
Contributor: How about returning failure or notifyMatchFailure?

Contributor (Author): Existing patterns always assume that exactly one dimension is distributed (see WarpOpExtract). Let's keep the assert for now because of that assumption.

(void)destDistributedDim;
VectorType srcType = insertOp.getSourceVectorType();
VectorType destType = insertOp.getDestVectorType();
// Currently we require that both source (kD) and dest (nD) vectors are
// distributed. This requires that distributedDim (d) is contained in the
// last k dims of the dest vector (d >= n - k).
// TODO: Add support for case where source vector is not distributed.
int64_t sourceDistributedDim =
destDistributedDim - (destType.getRank() - srcType.getRank());
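// For example (illustration only): inserting a 1-D vector<32xf32> (k = 1) into
// a 2-D vector<64x32xf32> (n = 2) with destDistributedDim = 1 gives
// sourceDistributedDim = 1 - (2 - 1) = 0.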
if (sourceDistributedDim < 0)
return rewriter.notifyMatchFailure(
insertOp, "distributed dimension must be in the last k dims");
// Distributed dimension must be fully inserted.
if (srcType.getDimSize(sourceDistributedDim) !=
Contributor (@Jianhui-Li, Jun 24, 2025): What is the reason we disallow distributing the following case? I think the distribution should work as long as the offsets are a multiple of the subgroup size.

%insert = vector.insert_strided_slice %src, %dest, offsets = [0, 32],
  strides = [1, 1] : vector<8x32xf32> into vector<8x64xf32>

=> suppose subgroup size = 32:

%insert = vector.insert_strided_slice %src, %dest, offsets = [0, 1],
  strides = [1, 1] : vector<8x1xf32> into vector<8x2xf32>

Contributor (Author): As discussed, this will be added in a separate PR after some investigation into other upstream patterns. The current support makes no assumption about which data is owned by which lane.

destType.getDimSize(destDistributedDim))
return rewriter.notifyMatchFailure(
insertOp, "distributed dimension must be fully inserted");
SmallVector<int64_t> newSourceDistShape(
insertOp.getSourceVectorType().getShape()),
newDestDistShape(insertOp.getDestVectorType().getShape());
Contributor: Is newDestDistShape equivalent to the shape of distributedType?

Contributor (Author): Good catch, I removed it. Thanks.

newSourceDistShape[sourceDistributedDim] =
distributedType.getDimSize(destDistributedDim);
newDestDistShape[destDistributedDim] =
distributedType.getDimSize(destDistributedDim);
auto newSourceTy =
VectorType::get(newSourceDistShape, distributedType.getElementType());
auto newDestTy =
Contributor: Is newDestTy the same as the distributedType?

Contributor (Author): Fixed.

VectorType::get(newDestDistShape, distributedType.getElementType());
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
{newSourceTy, newDestTy}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
// Create a new insert strided slice op that inserts distributed source into
// distributed dest.
Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
insertOp.getLoc(), distributedDest.getType(), distributedSource,
distributedDest, insertOp.getOffsets(), insertOp.getStrides());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
return success();
}
};

/// Sink out extract_strided_slice op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
/// ...
/// %src = ... : vector<32x16xf32>
/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
/// strides = [1] : vector<32x16xf32> to vector<16x16xf32>
/// gpu.yield %extract : vector<16x16xf32>
/// }
/// ```
/// To
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
/// ...
/// %src = ... : vector<32x16xf32>
/// gpu.yield %src : vector<32x16xf32>
/// }
/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
/// strides = [1] : vector<32x1xf32> to vector<16x1xf32>
/// ```
/// NOTE: Current support assumes that the extraction happens only on
/// non-distributed dimensions (i.e., it does not require cross-lane
/// communication).
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp =
operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
auto distributedType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
// Distributed type must be 2D or higher.
// TODO: Support 1D distributed types.
if (distributedType.getRank() < 2)
return rewriter.notifyMatchFailure(
extractOp, "result vector type must be 2D or higher");

// Find the distributed dimension. There should be exactly one.
auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
assert(distributedDim != -1 && "could not find distributed dimension");
Contributor: How about returning failure or notifyMatchFailure?

Contributor (Author): Addressed above.

(void)distributedDim;

// Distributed dimension must be fully extracted.
// TODO: Partial extraction from distributed dimension require cross lane
// communication.
if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
Contributor: Consider giving this expression a proper name to improve readability: "static_cast<int64_t>(extractOp.getSizes().size())". Something like extractedVecRank.

Contributor (Author): Renamed to extractedDimsRank.

Contributor (Author): I felt numOfExtractedDims was a more appropriate name, so I changed it again.

Contributor: What about the "else" case here?

Contributor (Author): Good question. The else case means the distributed dimension is already fully extracted, so we are good to go anyway. We only need the check when the distributed dim is among the partially extractable dims: in vector.extract_strided_slice, only the first k dims of an n-D vector can be partially extracted; the remaining n-k dims are fully extracted by default (n >= k).
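
To illustrate that semantics (a hypothetical example, not part of the patch; the function name is made up): vector.extract_strided_slice lists offsets/sizes/strides only for the leading k dims, and any trailing dims are implicitly extracted in full.

```mlir
// Hypothetical illustration: the attributes are given only for dim 0 (k = 1),
// so dim 1 of the 2-D source (n = 2) is implicitly extracted in full.
func.func @partial_extract_example(%src: vector<32x64xf32>) -> vector<16x64xf32> {
  %slice = vector.extract_strided_slice %src
    {offsets = [8], sizes = [16], strides = [1]}
    : vector<32x64xf32> to vector<16x64xf32>
  return %slice : vector<16x64xf32>
}
```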

int64_t distributedDimOffset =
llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
.getInt();
int64_t distributedDimSize =
llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
.getInt();
if (distributedDimOffset != 0 ||
distributedDimSize != yieldedType.getDimSize(distributedDim))
return rewriter.notifyMatchFailure(
extractOp, "distributed dimension must be fully extracted");
}
SmallVector<int64_t> newDistributedShape(
extractOp.getSourceVectorType().getShape());
newDistributedShape[distributedDim] =
distributedType.getDimSize(distributedDim);
auto newDistributedType =
VectorType::get(newDistributedShape, distributedType.getElementType());
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
extractOp.getSizes(), [](Attribute attr) { return attr; });
// Update the distributed sizes to match the distributed type.
if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
distributedType.getDimSize(distributedDim));
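// For example (illustration only): sizes = [32, 8] taken from a source that is
// distributed along dim 0 down to size 1 become sizes = [1, 8].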

// Create a new extract strided slice op that extracts from the
// distributed vector.
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
extractOp.getLoc(), distributedType, distributedVec,
extractOp.getOffsets(),
ArrayAttr::get(rewriter.getContext(), distributedSizes),
extractOp.getStrides());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
}
};

/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public WarpDistributionPattern {
@@ -1122,15 +1329,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
auto distributedType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distributedDim = -1;
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
// Keep this assert here in case WarpExecuteOnLane0Op gets extended to
// support distributing multiple dimensions in the future.
assert(distributedDim == -1 && "found multiple distributed dims");
distributedDim = i;
}
}
int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
assert(distributedDim != -1 && "could not find distributed dimension");
(void)distributedDim;

@@ -1764,7 +1963,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
80 changes: 80 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1296,6 +1296,86 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
return %r : vector<4x96xf32>
}

// -----
// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_outer(
// CHECK-PROP-SAME:    %[[LANEID:.*]]: index
// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<64x1xf32>) {
// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<64x32xf32>
// CHECK-PROP: gpu.yield %[[VEC]] : vector<64x32xf32>
// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
// CHECK-PROP-SAME: {offsets = [8], sizes = [24], strides = [1]} : vector<64x1xf32> to vector<24x1xf32>
// CHECK-PROP: return %[[EXTRACT]] : vector<24x1xf32>
func.func @vector_extract_strided_slice_2d_distr_outer(%laneid: index) -> (vector<24x1xf32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<24x1xf32>) {
%0 = "some_def"() : () -> (vector<64x32xf32>)
%1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [24], strides = [1]}
: vector<64x32xf32> to vector<24x32xf32>
gpu.yield %1 : vector<24x32xf32>
}
return %r : vector<24x1xf32>
}

// -----
// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_inner(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index
// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x64xf32>) {
// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<32x64xf32>
// CHECK-PROP: gpu.yield %[[VEC]] : vector<32x64xf32>
// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
// CHECK-PROP-SAME: {offsets = [0, 12], sizes = [1, 8], strides = [1, 1]} : vector<1x64xf32> to vector<1x8xf32>
// CHECK-PROP: return %[[EXTRACT]] : vector<1x8xf32>
func.func @vector_extract_strided_slice_2d_distr_inner(%laneid: index) -> (vector<1x8xf32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x8xf32>) {
%0 = "some_def"() : () -> (vector<32x64xf32>)
%1 = vector.extract_strided_slice %0 { offsets = [0, 12], sizes = [32, 8], strides = [1, 1]}
: vector<32x64xf32> to vector<32x8xf32>
gpu.yield %1 : vector<32x8xf32>
}
return %r : vector<1x8xf32>
}

// -----
// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_1d_to_2d(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}} -> (vector<1xf32>, vector<64x1xf32>) {
// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<32xf32>
// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<32xf32>, vector<64x32xf32>
// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1
// CHECK-PROP-SAME: {offsets = [18, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
func.func @vector_insert_strided_slice_1d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
%0 = "some_def"() : () -> (vector<32xf32>)
%1 = "some_def"() : () -> (vector<64x32xf32>)
%2 = vector.insert_strided_slice %0, %1 { offsets = [18, 0], strides = [1]}
: vector<32xf32> into vector<64x32xf32>
gpu.yield %2 : vector<64x32xf32>
}
return %r : vector<64x1xf32>
}

// -----
// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_2d_to_2d(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<16x1xf32>, vector<64x1xf32>) {
// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<16x32xf32>
// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<16x32xf32>, vector<64x32xf32>
// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1 {offsets = [36, 0], strides = [1, 1]} :
// CHECK-PROP-SAME: vector<16x1xf32> into vector<64x1xf32>
// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
%0 = "some_def"() : () -> (vector<16x32xf32>)
%1 = "some_def"() : () -> (vector<64x32xf32>)
%2 = vector.insert_strided_slice %0, %1 { offsets = [36, 0], strides = [1, 1]}
Contributor: The offset along the distribution dim should be restricted to a multiple of the subgroup size. For example, offsets = [36, 1] should be rejected.

Contributor (Author): In this version the distributed dimension is always fully inserted, so its offset is always 0. I will add support for other cases in separate PRs. Example:

func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x2xf32>) {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x2xf32>) {
    %0 = "some_def"() : () -> (vector<16x32xf32>)
    %1 = "some_def"() : () -> (vector<64x64xf32>)
    %2 = vector.insert_strided_slice %0, %1 { offsets = [36, 1],  strides = [1, 1]}
      : vector<16x32xf32> into vector<64x64xf32>
    gpu.yield %2 : vector<64x64xf32>
  }
  return %r : vector<64x2xf32>
}

The lowering filters out this case with the check:

    // Distributed dimension must be fully inserted.
    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");

: vector<16x32xf32> into vector<64x32xf32>
gpu.yield %2 : vector<64x32xf32>
}
return %r : vector<64x1xf32>
}

// -----

// Make sure that all operands of the transfer_read op are properly propagated.