@@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
173
173
// Helper functions to create constants.
174
174
// ----------------------------------------------------------------------------//
175
175
176
+ static Value boolCst (ImplicitLocOpBuilder &builder, bool value) {
177
+ return builder.create <arith::ConstantOp>(builder.getBoolAttr (value));
178
+ }
179
+
176
180
static Value floatCst (ImplicitLocOpBuilder &builder, float value,
177
181
Type elementType) {
178
182
assert ((elementType.isF16 () || elementType.isF32 ()) &&
@@ -1118,12 +1122,102 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
1118
1122
return success ();
1119
1123
}
1120
1124
1125
+ // Approximates erfc(x) with
1126
+ LogicalResult
1127
+ ErfcPolynomialApproximation::matchAndRewrite (math::ErfcOp op,
1128
+ PatternRewriter &rewriter) const {
1129
+ Value x = op.getOperand ();
1130
+ Type et = getElementTypeOrSelf (x);
1131
+
1132
+ if (!et.isF32 ())
1133
+ return rewriter.notifyMatchFailure (op, " only f32 type is supported." );
1134
+ std::optional<VectorShape> shape = vectorShape (x);
1135
+
1136
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1137
+ auto bcast = [&](Value value) -> Value {
1138
+ return broadcast (builder, value, shape);
1139
+ };
1140
+
1141
+ Value trueValue = bcast (boolCst (builder, true ));
1142
+ Value zero = bcast (floatCst (builder, 0 .0f , et));
1143
+ Value one = bcast (floatCst (builder, 1 .0f , et));
1144
+ Value onehalf = bcast (floatCst (builder, 0 .5f , et));
1145
+ Value neg4 = bcast (floatCst (builder, -4 .0f , et));
1146
+ Value neg2 = bcast (floatCst (builder, -2 .0f , et));
1147
+ Value pos2 = bcast (floatCst (builder, 2 .0f , et));
1148
+ Value posInf = bcast (f32FromBits (builder, 0x7f800000u ));
1149
+ Value clampVal = bcast (floatCst (builder, 10 .0546875f , et));
1150
+
1151
+ // Get abs(x)
1152
+ Value isNegativeArg =
1153
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
1154
+ Value negArg = builder.create <arith::NegFOp>(x);
1155
+ Value a = builder.create <arith::SelectOp>(isNegativeArg, negArg, x);
1156
+ Value p = builder.create <arith::AddFOp>(a, pos2);
1157
+ Value r = builder.create <arith::DivFOp>(one, p);
1158
+ Value q = builder.create <math::FmaOp>(neg4, r, one);
1159
+ Value t = builder.create <math::FmaOp>(builder.create <arith::AddFOp>(q, one),
1160
+ neg2, a);
1161
+ Value e = builder.create <math::FmaOp>(builder.create <arith::NegFOp>(a), q, t);
1162
+ q = builder.create <math::FmaOp>(r, e, q);
1163
+
1164
+ p = bcast (floatCst (builder, -0x1 .a4a000p -12f , et)); // -4.01139259e-4
1165
+ Value c1 = bcast (floatCst (builder, -0x1 .42a260p-10f , et)); // -1.23075210e-3
1166
+ p = builder.create <math::FmaOp>(p, q, c1);
1167
+ Value c2 = bcast (floatCst (builder, 0x1 .585714p-10f , et)); // 1.31355342e-3
1168
+ p = builder.create <math::FmaOp>(p, q, c2);
1169
+ Value c3 = bcast (floatCst (builder, 0x1 .1adcc4p-07f , et)); // 8.63227434e-3
1170
+ p = builder.create <math::FmaOp>(p, q, c3);
1171
+ Value c4 = bcast (floatCst (builder, -0x1 .081b82p-07f , et)); // -8.05991981e-3
1172
+ p = builder.create <math::FmaOp>(p, q, c4);
1173
+ Value c5 = bcast (floatCst (builder, -0x1 .bc0b6ap -05f , et)); // -5.42046614e-2
1174
+ p = builder.create <math::FmaOp>(p, q, c5);
1175
+ Value c6 = bcast (floatCst (builder, 0x1 .4ffc46p-03f , et)); // 1.64055392e-1
1176
+ p = builder.create <math::FmaOp>(p, q, c6);
1177
+ Value c7 = bcast (floatCst (builder, -0x1 .540840p-03f , et)); // -1.66031361e-1
1178
+ p = builder.create <math::FmaOp>(p, q, c7);
1179
+ Value c8 = bcast (floatCst (builder, -0x1 .7bf616p-04f , et)); // -9.27639827e-2
1180
+ p = builder.create <math::FmaOp>(p, q, c8);
1181
+ Value c9 = bcast (floatCst (builder, 0x1 .1ba03ap-02f , et)); // 2.76978403e-1
1182
+ p = builder.create <math::FmaOp>(p, q, c9);
1183
+
1184
+ Value d = builder.create <math::FmaOp>(pos2, a, one);
1185
+ r = builder.create <arith::DivFOp>(one, d);
1186
+ q = builder.create <math::FmaOp>(p, r, r);
1187
+ e = builder.create <math::FmaOp>(
1188
+ builder.create <math::FmaOp>(q, builder.create <arith::NegFOp>(a), onehalf),
1189
+ pos2, builder.create <arith::SubFOp>(p, q));
1190
+ r = builder.create <math::FmaOp>(e, r, q);
1191
+
1192
+ Value s = builder.create <arith::MulFOp>(a, a);
1193
+ e = builder.create <math::ExpOp>(builder.create <arith::NegFOp>(s));
1194
+
1195
+ t = builder.create <math::FmaOp>(builder.create <arith::NegFOp>(a), a, s);
1196
+ r = builder.create <math::FmaOp>(
1197
+ r, e,
1198
+ builder.create <arith::MulFOp>(builder.create <arith::MulFOp>(r, e), t));
1199
+
1200
+ Value isNotLessThanInf = builder.create <arith::XOrIOp>(
1201
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
1202
+ trueValue);
1203
+ r = builder.create <arith::SelectOp>(isNotLessThanInf,
1204
+ builder.create <arith::AddFOp>(x, x), r);
1205
+ Value isGreaterThanClamp =
1206
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
1207
+ r = builder.create <arith::SelectOp>(isGreaterThanClamp, zero, r);
1208
+
1209
+ Value isNegative =
1210
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
1211
+ r = builder.create <arith::SelectOp>(
1212
+ isNegative, builder.create <arith::SubFOp>(pos2, r), r);
1213
+
1214
+ rewriter.replaceOp (op, r);
1215
+ return success ();
1216
+ }
1121
1217
// ----------------------------------------------------------------------------//
1122
1218
// Exp approximation.
1123
1219
// ----------------------------------------------------------------------------//
1124
-
1125
1220
namespace {
1126
-
1127
1221
Value clampWithNormals (ImplicitLocOpBuilder &builder,
1128
1222
const std::optional<VectorShape> shape, Value value,
1129
1223
float lowerBound, float upperBound) {
@@ -1667,6 +1761,11 @@ void mlir::populatePolynomialApproximateErfPattern(
1667
1761
patterns.add <ErfPolynomialApproximation>(patterns.getContext ());
1668
1762
}
1669
1763
1764
+ void mlir::populatePolynomialApproximateErfcPattern (
1765
+ RewritePatternSet &patterns) {
1766
+ patterns.add <ErfcPolynomialApproximation>(patterns.getContext ());
1767
+ }
1768
+
1670
1769
void mlir::populateMathPolynomialApproximationPatterns (
1671
1770
RewritePatternSet &patterns,
1672
1771
const MathPolynomialApproximationOptions &options) {
@@ -1680,13 +1779,14 @@ void mlir::populateMathPolynomialApproximationPatterns(
1680
1779
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1681
1780
patterns.getContext ());
1682
1781
1683
- patterns
1684
- .add <AtanApproximation, Atan2Approximation, TanhApproximation,
1685
- LogApproximation, Log2Approximation, Log1pApproximation,
1686
- ErfPolynomialApproximation, AsinPolynomialApproximation,
1687
- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688
- CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1689
- SinAndCosApproximation<false , math::CosOp>>(patterns.getContext ());
1782
+ patterns.add <AtanApproximation, Atan2Approximation, TanhApproximation,
1783
+ LogApproximation, Log2Approximation, Log1pApproximation,
1784
+ ErfPolynomialApproximation, ErfcPolynomialApproximation,
1785
+ AsinPolynomialApproximation, AcosPolynomialApproximation,
1786
+ ExpApproximation, ExpM1Approximation, CbrtApproximation,
1787
+ SinAndCosApproximation<true , math::SinOp>,
1788
+ SinAndCosApproximation<false , math::CosOp>>(
1789
+ patterns.getContext ());
1690
1790
if (options.enableAvx2 ) {
1691
1791
patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1692
1792
patterns.getContext ());
0 commit comments