Skip to content

[mlir][tosa] Change ClampOp's min/max attributes #125197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {

let arguments = (ins
Tosa_Tensor:$input,
I64Attr:$min_int,
I64Attr:$max_int,
Tosa_FloatAttr:$min_fp,
Tosa_FloatAttr:$max_fp,
Tosa_IntOrFloatAttr:$min_val,
Tosa_IntOrFloatAttr:$max_val,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,14 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
let returnType = [{ ::mlir::APFloat }];
}

def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
"arbitrary integer attribute"> {
let storageType = [{ ::mlir::IntegerAttr }];
let returnType = [{ ::llvm::APInt }];
}

def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;

//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::ClampOp
if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
bool losesInfo = false;
APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
Expand All @@ -405,9 +405,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
auto intTy = cast<IntegerType>(elementTy);
int64_t min =
cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
int64_t max =
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

int64_t minRepresentable = std::numeric_limits<int64_t>::min();
int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
Expand Down
129 changes: 94 additions & 35 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,12 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {

if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
auto minClamp = op.getMinFp();
auto maxClamp = op.getMaxFp();
bool isMin = minClamp.isInfinity() && minClamp.isNegative();
bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
auto minClamp =
llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
auto maxClamp =
llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
bool isMin = minClamp.isNegInfinity();
bool isMax = maxClamp.isInfinity();

if (isMin && isMax) {
rewriter.replaceOp(op, input);
Expand All @@ -300,8 +302,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}

if (inputElementType.isUnsignedInteger()) {
int64_t minClamp = op.getMinInt();
int64_t maxClamp = op.getMaxInt();
int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();

int64_t intMin =
APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
Expand All @@ -318,8 +322,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}

if (llvm::isa<IntegerType>(inputElementType)) {
int64_t minClamp = op.getMinInt();
int64_t maxClamp = op.getMaxInt();
int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();

int64_t intMin =
APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
Expand Down Expand Up @@ -374,9 +380,10 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {

LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();

// Check the input to the CLAMP op is itself a CLAMP.
auto clampOp =
dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
if (!clampOp)
return failure();

Expand All @@ -386,34 +393,86 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
return failure();

// Check we have intersecting ranges.
const auto opMinInt = op.getMinInt();
const auto opMaxInt = op.getMaxInt();
const auto clampOpMinInt = clampOp.getMinInt();
const auto clampOpMaxInt = clampOp.getMaxInt();
ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
if (!opRangeIntRange.intersects(clampRangeIntRange))
return failure();
auto maxValAttr = op.getMaxValAttr();
auto minValAttr = op.getMinValAttr();
auto clampOpMaxValAttr = clampOp.getMaxValAttr();
auto clampOpMinValAttr = clampOp.getMinValAttr();

const auto opMinFloat = op.getMinFp();
const auto opMaxFloat = op.getMaxFp();
const auto clampOpMinFloat = clampOp.getMinFp();
const auto clampOpMaxFloat = clampOp.getMaxFp();
ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
if (!opRangeFloatRange.intersects(clampRangeFloatRange))
return failure();
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
inputEType = quantType.getStorageType();
}

Attribute newMinValAttr, newMaxValAttr;
if (mlir::isa<FloatType>(inputEType)) {
auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);

// Check we have intersecting ranges.
const auto opMinFloat = floatMinValAttr.getValue();
const auto opMaxFloat = floatMaxValAttr.getValue();
const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
clampOpMaxFloat);
if (!opRangeFloatRange.intersects(clampRangeFloatRange))
return failure();

// Run the transformation.
auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
} else {
assert(mlir::isa<IntegerType>(inputEType));
auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);

if (inputEType.isUnsignedInteger()) {
// Check we have intersecting ranges.
const auto opMinInt = intMinValAttr.getUInt();
const auto opMaxInt = intMaxValAttr.getUInt();
const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
clampOpMaxInt);
if (!opRangeIntRange.intersects(clampRangeIntRange))
return failure();

// Run the transformation.
auto newMinVal = std::max(opMinInt, clampOpMinInt);
auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
} else {
// Check we have intersecting ranges.
const auto opMinInt = intMinValAttr.getInt();
const auto opMaxInt = intMaxValAttr.getInt();
const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
clampOpMaxInt);
if (!opRangeIntRange.intersects(clampRangeIntRange))
return failure();

// Run the transformation.
auto newMinVal = std::max(opMinInt, clampOpMinInt);
auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
}
}

// Run the transformation.
const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
const auto minInt = std::max(opMinInt, clampOpMinInt);
const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(),
rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
: opNanMode));
return success();
Expand Down
36 changes: 25 additions & 11 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,26 +464,40 @@ LogicalResult tosa::ClampOp::verify() {
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
inputETy = quantType.getStorageType();
}
mlir::Type maxFpType = getMaxFpAttr().getType();
mlir::Type minFpType = getMinFpAttr().getType();
mlir::Type outputETy =
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
outputETy = quantType.getStorageType();
}
unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");

// If input datatype is float, check that the two min/max_fp attributes
// share the same type and that their type is either the same of the input's
// datatype, or a float type whose bitwidth > input datatype bitwidth.
if (!inputETy.isInteger(dataTypeBitWidth)) {
if (((maxFpType != minFpType) ||
(maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
inputETy.getIntOrFloatBitWidth())))
auto maxValAttr = getMaxValAttr();
auto minValAttr = getMinValAttr();

unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

if (inputETy.isInteger(dataTypeBitWidth)) {
// if input datatype is integer, check that the min_val/max_val attributes
// are integer attributes, and that their type is the same as the input's
// datatype
auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
if (!intMaxValAttr || !intMinValAttr ||
(intMaxValAttr.getType() != intMinValAttr.getType()) ||
(intMaxValAttr.getType() != inputETy))
return emitOpError("min/max attributes types are incompatible with "
"input/output element types.");
} else {
// otherwise, input datatype is float, check that the min_val/max_val
// attributes share the same type and that their type is the same as the
// input's datatype
auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
if (!floatMaxValAttr || !floatMinValAttr ||
(floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
(floatMaxValAttr.getType() != inputETy))
return emitOpError("min/max attributes types are incompatible with "
"input/output element types.");
}
Expand Down
41 changes: 6 additions & 35 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.minimumf
// CHECK: arith.maximumf
%18 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
%18 = tosa.clamp %0 {min_val = 1.0 : f32, max_val = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: arith.negf
Expand Down Expand Up @@ -729,35 +729,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
// CHECK: linalg.generic
// CHECK-DAG: arith.maxsi
// CHECK-DAG: arith.minsi
%19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
%19 = tosa.clamp %0 {min_val = 1 : i32, max_val = 5 : i32} : (tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
%u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
// CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
%u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
%u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
%u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
%u0 = tosa.clamp %unsigned {min_val = 4 : ui32, max_val = 32 : ui32} : (tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK: arith.trunci
Expand Down Expand Up @@ -807,15 +786,7 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK-DAG: %[[C126:.+]] = arith.constant 126
// CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
// CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
%0 = tosa.clamp %arg0 {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>

// CHECK: linalg.generic
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C128:.+]] = arith.constant -128
// CHECK-DAG: %[[C127:.+]] = arith.constant 127
// CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C128]], %[[ARG1]]
// CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C127]], %[[LOWER]]
%1 = tosa.clamp %arg0 {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
%0 = tosa.clamp %arg0 {min_val = -127 : i8, max_val = 126 : i8} : (tensor<1xi8>) -> tensor<1xi8>

return
}
Expand All @@ -830,7 +801,7 @@ func.func @test_i64(%arg0: tensor<1xi64>) -> () {
// CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
// CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
// CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
%0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
%0 = tosa.clamp %arg0 {min_val = -9223372036854775808 : i64, max_val = 9223372036854775807 : i64} : (tensor<1xi64>) -> tensor<1xi64>

return
}
Expand All @@ -845,7 +816,7 @@ func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
// CHECK-DAG: %[[MIN:.+]] = arith.minimumf %[[ARG1]], %[[C6]]
// CHECK-DAG: %[[MAX:.+]] = arith.maximumf %[[MIN]], %[[C0]]
%0 = tosa.clamp %arg0 {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
%0 = tosa.clamp %arg0 {min_val = 0.0 : f16, max_val = 6.0 : f16} : (tensor<1xf16>) -> tensor<1xf16>

return
}
Expand Down
Loading