Skip to content

Commit 93b41a4

Browse files
committed
Merge remote-tracking branch 'origin/feature/fused-ops' into bump_to_9387fd96
2 parents 6928d4a + 42131ee commit 93b41a4

File tree

14 files changed

+376
-192
lines changed

14 files changed

+376
-192
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#define MLIR_CONVERSION_PASSES
1111

1212
include "mlir/Pass/PassBase.td"
13-
13+
include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
1414

1515
//===----------------------------------------------------------------------===//
1616
// ToLLVM
@@ -1436,10 +1436,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14361436
"bool", /*default=*/"false",
14371437
"Enables the use of X86Vector dialect while lowering the vector "
14381438
"dialect.">,
1439-
Option<"vectorTransformsOptions", "vector-transform-options",
1440-
"vector::VectorTransformsOptions",
1441-
/*default=*/"vector::VectorTransformsOptions()",
1442-
"Options to lower some operations like contractions and transposes.">,
1439+
Option<"vectorContractLowering", "vector-contract-lowering",
1440+
"vector::VectorContractLowering",
1441+
/*default=*/"vector::VectorContractLowering::Dot",
1442+
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
1443+
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
1444+
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
1445+
clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
1446+
"Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
1447+
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
1448+
"Lower to `vector.outerproduct`."),
1449+
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
1450+
"Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
1451+
)}]>,
1452+
Option<"vectorTransposeLowering", "vector-transpose-lowering",
1453+
"vector::VectorTransposeLowering",
1454+
/*default=*/"vector::VectorTransposeLowering::EltWise",
1455+
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
1456+
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
1457+
"Lower transpose into element-wise extract and inserts (default)"),
1458+
clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
1459+
"Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
1460+
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
1461+
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
1462+
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
1463+
"Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
1464+
)}]>,
14431465
];
14441466
}
14451467

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
1010
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
1111

12+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1213
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1314

1415
namespace mlir {
@@ -47,7 +48,8 @@ namespace vector {
4748
/// Progressively lower a `vector.contract` with row-major matmul semantics to
4849
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
4950
void populateVectorContractLoweringPatterns(
50-
RewritePatternSet &patterns, VectorTransformsOptions options,
51+
RewritePatternSet &patterns,
52+
VectorContractLowering vectorContractLoweringOption,
5153
PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
5254

5355
/// Populate the pattern set with the following patterns:
@@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
142144
///
143145
/// [TransposeOp2DToShuffleLowering]
144146
///
145-
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
146-
VectorTransformsOptions options,
147-
PatternBenefit benefit = 1);
147+
void populateVectorTransposeLoweringPatterns(
148+
RewritePatternSet &patterns,
149+
VectorTransposeLowering vectorTransposeLowering,
150+
PatternBenefit benefit = 1);
148151

149152
/// Populate the pattern set with the following patterns:
150153
///

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
6969
populateVectorToVectorCanonicalizationPatterns(patterns);
7070
populateVectorBitCastLoweringPatterns(patterns);
7171
populateVectorBroadcastLoweringPatterns(patterns);
72-
populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
72+
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
7373
populateVectorMaskOpLoweringPatterns(patterns);
7474
populateVectorShapeCastLoweringPatterns(patterns);
7575
populateVectorInterleaveLoweringPatterns(patterns);
76-
populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
76+
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
7777
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
7878
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
7979
populateVectorMaskMaterializationPatterns(patterns,

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,15 +274,11 @@ LogicalResult static binaryOp(PatternRewriter &rewriter, PDLResultList &results,
274274
"encounter an unsupported binary operator.");
275275
}
276276

277-
if (operationStatus != APFloat::opOK) {
278-
if (operationStatus != APFloat::opInexact)
279-
return failure();
280-
281-
emitWarning(rewriter.getUnknownLoc())
282-
<< "Binary arithmetic operation between " << lhsVal.convertToFloat()
283-
<< " and " << rhsVal.convertToFloat()
284-
<< " produced an inexact result";
277+
if (operationStatus != APFloat::opOK &&
278+
operationStatus != APFloat::opInexact) {
279+
return failure();
285280
}
281+
286282
results.push_back(rewriter.getFloatAttr(floatType, resultVal));
287283
return success();
288284
}

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
13741374
// further transformations to canonicalize/cancel.
13751375
{
13761376
RewritePatternSet patterns(context);
1377-
auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
1378-
vector::VectorTransposeLowering::EltWise);
1379-
vector::populateVectorTransposeLoweringPatterns(patterns, options);
1377+
vector::populateVectorTransposeLoweringPatterns(
1378+
patterns, vector::VectorTransposeLowering::EltWise);
13801379
vector::populateVectorShapeCastLoweringPatterns(patterns);
13811380
if (failed(applyPatternsGreedily(op, std::move(patterns))))
13821381
return failure();

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,50 +1079,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
10791079
MLIRContext *context, ::std::optional<Location> location,
10801080
SliceOp::Adaptor adaptor,
10811081
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1082-
1083-
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1084-
SmallVector<int64_t> start;
1085-
SmallVector<int64_t> size;
1086-
1087-
if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
1088-
!tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
1089-
auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1090-
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1091-
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1092-
return success();
1093-
}
1094-
1095-
// if size[i] is -1, all remaining elements in dimension i are included
1096-
// in the slice, similar to TF.
1097-
ShapeAdaptor inputShape(adaptor.getInput1().getType());
1098-
// initialize outputShape to all unknown
1099-
SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1100-
if (inputShape.hasRank()) {
1101-
for (size_t i = 0; i < size.size(); i++) {
1102-
if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1103-
(ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1104-
start[i] < inputShape.getDimSize(i))) {
1105-
// size[i] is not 0 and not < -1, and start[i] is in valid range
1106-
if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1107-
// input shape has unknown dim[i] - only valid if size[i] > 0
1108-
if (size[i] > 0) {
1109-
outputShape[i] = size[i];
1110-
}
1111-
} else {
1112-
// input shape has known dim[i]
1113-
if (size[i] == -1) {
1114-
outputShape[i] = inputShape.getDimSize(i) - start[i];
1115-
} else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1116-
// start[i] + size[i] is within bound of input shape's dim[i]
1117-
outputShape[i] = size[i];
1118-
}
1119-
}
1120-
}
1121-
}
1122-
} else {
1123-
outputShape = convertToMlirShape(size);
1124-
}
1125-
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1082+
inferredReturnShapes.push_back(
1083+
ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
11261084
return success();
11271085
}
11281086

@@ -1131,7 +1089,7 @@ LogicalResult tosa::SliceOp::verify() {
11311089
if (!inputType)
11321090
return success();
11331091

1134-
auto startShapeRank =
1092+
auto startShapeRank =
11351093
llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
11361094
if (inputType.getRank() != startShapeRank)
11371095
return emitOpError(

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,6 +2120,120 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
21202120
const bool aggressiveReduceConstant;
21212121
};
21222122

2123+
template <typename ElementStorageType>
2124+
DenseElementsAttr
2125+
concatenateAttrs(const ShapedType outputType, ArrayRef<ElementsAttr> inputAttrs,
2126+
const uint32_t concatAxis, PatternRewriter &rewriter,
2127+
const Type elementType) {
2128+
2129+
static_assert(std::is_same<ElementStorageType, APInt>::value ||
2130+
std::is_same<ElementStorageType, APFloat>::value,
2131+
"ElementStorageType must be either APInt or APFloat");
2132+
2133+
SmallVector<ElementStorageType> resultValues;
2134+
if constexpr (std::is_same<ElementStorageType, APInt>::value) {
2135+
resultValues.resize_for_overwrite(outputType.getNumElements());
2136+
} else {
2137+
resultValues.resize(
2138+
outputType.getNumElements(),
2139+
APFloat::getZero(cast<FloatType>(elementType).getFloatSemantics()));
2140+
}
2141+
const auto outputShape = outputType.getShape();
2142+
2143+
int64_t concatDimOffset = 0;
2144+
for (const auto &inputAttr : inputAttrs) {
2145+
const auto inputShape = cast<ShapedType>(inputAttr.getType()).getShape();
2146+
const auto inputValues = inputAttr.getValues<ElementStorageType>();
2147+
2148+
for (const auto &[inputLinearIdx, val] : llvm::enumerate(inputValues)) {
2149+
// TODO: Could be optimized to work on slices instead of single value
2150+
SmallVector<int64_t> multiDimIndex =
2151+
offsetToIndex(inputShape, inputLinearIdx);
2152+
multiDimIndex[concatAxis] += concatDimOffset;
2153+
2154+
const int64_t outputLinearIndex =
2155+
indexToOffset(outputShape, multiDimIndex);
2156+
resultValues[outputLinearIndex] = val;
2157+
}
2158+
concatDimOffset += inputShape[concatAxis];
2159+
}
2160+
return DenseElementsAttr::get(outputType, resultValues);
2161+
}
2162+
2163+
struct TosaFoldConstantConcat : public TosaFoldConstantBase<tosa::ConcatOp> {
2164+
using TosaFoldConstantBase::TosaFoldConstantBase;
2165+
2166+
LogicalResult matchAndRewrite(tosa::ConcatOp op,
2167+
PatternRewriter &rewriter) const override {
2168+
auto inputs = op->getOperands();
2169+
const uint32_t concatAxis = op.getAxis();
2170+
const auto outputType = cast<ShapedType>(op.getType());
2171+
if (!outputType.hasStaticShape()) {
2172+
return rewriter.notifyMatchFailure(
2173+
op, "Output type must have static shape for concat folding.");
2174+
}
2175+
if (llvm::any_of(inputs, [](Value v) {
2176+
return !cast<ShapedType>(v.getType()).hasStaticShape();
2177+
})) {
2178+
return rewriter.notifyMatchFailure(
2179+
op, "All inputs to ConcatOp must have static shape for folding.");
2180+
}
2181+
2182+
const Type elementType = outputType.getElementType();
2183+
if (!elementType.isIntOrIndexOrFloat()) {
2184+
// Sanity check, this should always be the case
2185+
return rewriter.notifyMatchFailure(
2186+
op, "Output element type must be int, index, or float for folding.");
2187+
}
2188+
2189+
SmallVector<ElementsAttr> inputAttrs;
2190+
inputAttrs.reserve(inputs.size());
2191+
2192+
for (Value inputVal : inputs) {
2193+
ElementsAttr inputAsAttr;
2194+
if (!matchPattern(inputVal, m_Constant(&inputAsAttr))) {
2195+
// TODO: This could be extended to handle partial non-const inputs
2196+
return rewriter.notifyMatchFailure(
2197+
op, "All inputs to ConcatOp must be constant for folding.");
2198+
}
2199+
2200+
if (inputAsAttr.isSplat()) {
2201+
const ShapedType inputType = cast<ShapedType>(inputAsAttr.getType());
2202+
if (isa<IntegerType>(elementType)) {
2203+
inputAsAttr = DenseElementsAttr::get(
2204+
inputType, inputAsAttr.getSplatValue<APInt>());
2205+
} else {
2206+
inputAsAttr = DenseElementsAttr::get(
2207+
inputType, inputAsAttr.getSplatValue<APFloat>());
2208+
}
2209+
}
2210+
if (foldSplatOrSingleUseOnly && !inputVal.hasOneUse() &&
2211+
!inputAsAttr.isSplat()) {
2212+
return rewriter.notifyMatchFailure(
2213+
op, "Concat folding heuristic: non-splat constant inputs must have "
2214+
"only a single use.");
2215+
}
2216+
inputAttrs.push_back(inputAsAttr);
2217+
}
2218+
2219+
DenseElementsAttr resultAttr;
2220+
if (auto intType = dyn_cast<IntegerType>(elementType)) {
2221+
// TODO: This could be optimized to not go to APInt if the int size
2222+
// matches c++ native types
2223+
resultAttr = concatenateAttrs<APInt>(outputType, inputAttrs, concatAxis,
2224+
rewriter, elementType);
2225+
} else {
2226+
resultAttr = concatenateAttrs<APFloat>(outputType, inputAttrs, concatAxis,
2227+
rewriter, elementType);
2228+
}
2229+
2230+
assert(resultAttr && "Result attribute should not be null.");
2231+
2232+
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
2233+
return success();
2234+
}
2235+
};
2236+
21232237
} // namespace
21242238

21252239
void mlir::tosa::populateTosaFoldConstantPatterns(
@@ -2167,6 +2281,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
21672281
patterns.add<TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly);
21682282
patterns.add<TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly);
21692283
patterns.add<TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly);
2284+
patterns.add<TosaFoldConstantConcat>(ctx, options.foldSplatOrSingleUseOnly);
21702285
if (options.enableTileFolding)
21712286
patterns.add<TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly);
21722287
}

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
197197

198198
// AMD: Picked from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18
199199
// Get accumulator type for TOSA convolution ops
200-
LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter,
201-
RankedTensorType inputTy,
202-
RankedTensorType weightTy,
203-
RankedTensorType outputTy,
204-
TypeAttr &accType) {
200+
LogicalResult mlir::tosa::getConvOpsAccType(PatternRewriter &rewriter,
201+
RankedTensorType inputTy,
202+
RankedTensorType weightTy,
203+
RankedTensorType outputTy,
204+
TypeAttr &accType) {
205205
auto inputElemTy = inputTy.getElementType();
206206
auto weightElemTy = weightTy.getElementType();
207207
auto outputElemTy = outputTy.getElementType();
@@ -231,8 +231,8 @@ LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter,
231231
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
232232
outputElemTy.isInteger(48)) {
233233
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
234-
} else if ((isa<Float8E4M3FNType>(inputElemTy) &&
235-
isa<Float8E4M3FNType>(weightElemTy) && outputElemTy.isF16()) ||
234+
} else if ((isa<Float8E4M3Type>(inputElemTy) &&
235+
isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
236236
(isa<Float8E5M2Type>(inputElemTy) &&
237237
isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
238238
accType = mlir::TypeAttr::get(rewriter.getF16Type());

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
102102

103103
void transform::ApplyLowerContractionPatternsOp::populatePatterns(
104104
RewritePatternSet &patterns) {
105-
vector::VectorTransformsOptions vectorTransformOptions;
106-
vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
107-
populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
105+
populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
108106
/*benefit=*/1,
109107
/*disableOuterProductLowering=*/true);
110108
}
@@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(
161159

162160
void transform::ApplyLowerTransposePatternsOp::populatePatterns(
163161
RewritePatternSet &patterns) {
164-
vector::populateVectorTransposeLoweringPatterns(
165-
patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
166-
getLoweringStrategy()));
162+
vector::populateVectorTransposeLoweringPatterns(patterns,
163+
getLoweringStrategy());
167164
if (getAvx2LoweringStrategy()) {
168165
auto avx2LoweringOptions =
169166
x86vector::avx2::LoweringOptions().setTransposeOptions(

0 commit comments

Comments (0)