Skip to content

Commit 59faad1

Browse files
Tai78641 authored and jorickert committed
[mlir][tosa] Change zero points of convolution ops to required inputs (llvm#127679)
This patch changes the input_zp and weight_zp operands of the convolution operators to be required inputs, in order to align with the TOSA Spec 1.0. The convolution operators affected are: CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D.

Signed-off-by: Tai Ly <[email protected]>
1 parent 8e382b9 commit 59faad1

File tree

15 files changed

+528
-1441
lines changed

15 files changed

+528
-1441
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 0 additions & 106 deletions
Original file line number | Diff line number | Diff line change
@@ -161,112 +161,6 @@ namespace tosa {
161161
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
162162
Type srcElemType, int64_t zp = 0);
163163

164-
// Get zero point value from the attribute argument.
165-
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
166-
167-
// Verify if zero point falls into valid range.
168-
template <typename T>
169-
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
170-
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
171-
!std::is_same_v<T, DepthwiseConv2DOp> &&
172-
!std::is_same_v<T, TransposeConv2DOp>) {
173-
return failure();
174-
}
175-
176-
if (!zpElemType.isIntOrFloat())
177-
return failure();
178-
179-
if (!zpElemType.isInteger(8) && zp != 0)
180-
return failure();
181-
182-
if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
183-
return failure();
184-
185-
if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
186-
return failure();
187-
188-
return success();
189-
}
190-
191-
// Helper type trait to determine if an operation is a tosa convolution.
192-
template <typename Op>
193-
struct IsTosaConv : std::false_type {};
194-
195-
template <>
196-
struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
197-
template <>
198-
struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
199-
template <>
200-
struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
201-
template <>
202-
struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
203-
204-
template <typename Op>
205-
constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
206-
207-
// Helper struct to hold the zero points of a TOSA convolution operation as
208-
// named 64-bit integer fields.
209-
struct ConvZpPair {
210-
ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
211-
: inputZp(inputZp), weightZp(weightZp) {}
212-
std::int64_t inputZp;
213-
std::int64_t weightZp;
214-
};
215-
216-
// Helper function which attempts to extract the zero points from a TOSA
217-
// convolution by matching them against defining ops which should be tosa.const
218-
// operations.
219-
//
220-
// There are three possible results:
221-
// 1. Failed to extract the zero-points i.e. they should exist and don't or they
222-
// do exist but are invalid.
223-
// 2. Succeeded in extracting zero-points.
224-
// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
225-
// convolution.
226-
using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
227-
template <typename TosaConvOp>
228-
std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
229-
extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
230-
// Strictly speaking the base TOSA spec requires that for non int8 types
231-
// zero points must be zero. However, in the dialect these operands are
232-
// optional and only required for int8. They have no semantic meaning for
233-
// non-quantized types and can therefore be safely ignored. This is case 3.
234-
if (auto opElementTY =
235-
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
236-
!opElementTY.isInteger(8))
237-
return FailOrMaybeZP(std::nullopt);
238-
239-
// Now we know we should have a zero point check it is valid.
240-
if (!op.getInputZp())
241-
return rewriter.notifyMatchFailure(op, "missing input zero point");
242-
243-
// Helper to extract the zero point by matching its definition against a
244-
// constant.
245-
auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
246-
ElementsAttr zpAttr;
247-
if (!matchPattern(zpValue, m_Constant(&zpAttr)))
248-
return std::nullopt;
249-
250-
int64_t zp;
251-
if (tosa::getZeroPoint(zpAttr, zp).failed())
252-
return std::nullopt;
253-
254-
return std::make_optional(zp);
255-
};
256-
257-
auto maybeInputZp = extractZeroPoint(op.getInputZp());
258-
if (!maybeInputZp)
259-
return rewriter.notifyMatchFailure(op, "unable to extract input zp");
260-
261-
if (!op.getWeightZp())
262-
return rewriter.notifyMatchFailure(op, "missing weight zero point");
263-
264-
auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
265-
if (!maybeWeightZp)
266-
return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
267-
268-
return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
269-
}
270164
} // namespace tosa
271165
} // namespace mlir
272166

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 44 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -105,8 +105,9 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
105105
Tosa_Tensor4D:$input,
106106
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
107107
Tosa_Tensor1D:$bias,
108-
Optional<Tosa_ScalarTensor>:$input_zp,
109-
Optional<Tosa_ScalarTensor>:$weight_zp,
108+
Tosa_ScalarTensor:$input_zp,
109+
Tosa_ScalarTensor:$weight_zp,
110+
110111
Tosa_IntArrayAttr4:$pad,
111112
Tosa_IntArrayAttr2:$stride,
112113
Tosa_IntArrayAttr2:$dilation,
@@ -118,6 +119,14 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
118119
Tosa_Tensor4D:$output
119120
);
120121

122+
123+
let extraClassDeclaration = [{
124+
LogicalResult getInputZeroPoint(int64_t &zp);
125+
LogicalResult getWeightZeroPoint(int64_t &zp);
126+
LogicalResult verifyInputZeroPoint(int64_t zp);
127+
LogicalResult verifyWeightZeroPoint(int64_t zp);
128+
}];
129+
121130
let builders = [Tosa_ConvOpQuantInfoBuilder];
122131
let hasVerifier = 1;
123132
}
@@ -136,8 +145,9 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
136145
Tosa_Tensor5D:$input,
137146
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
138147
Tosa_Tensor1D:$bias,
139-
Optional<Tosa_ScalarTensor>:$input_zp,
140-
Optional<Tosa_ScalarTensor>:$weight_zp,
148+
Tosa_ScalarTensor:$input_zp,
149+
Tosa_ScalarTensor:$weight_zp,
150+
141151
Tosa_IntArrayAttr6:$pad,
142152
Tosa_IntArrayAttr3:$stride,
143153
Tosa_IntArrayAttr3:$dilation,
@@ -149,6 +159,14 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
149159
Tosa_Tensor5D:$output
150160
);
151161

162+
163+
let extraClassDeclaration = [{
164+
LogicalResult getInputZeroPoint(int64_t &zp);
165+
LogicalResult getWeightZeroPoint(int64_t &zp);
166+
LogicalResult verifyInputZeroPoint(int64_t zp);
167+
LogicalResult verifyWeightZeroPoint(int64_t zp);
168+
}];
169+
152170
let builders = [Tosa_ConvOpQuantInfoBuilder];
153171
let hasVerifier = 1;
154172
}
@@ -168,8 +186,9 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
168186
Tosa_Tensor4D:$input,
169187
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
170188
Tosa_Tensor1D:$bias,
171-
Optional<Tosa_ScalarTensor>:$input_zp,
172-
Optional<Tosa_ScalarTensor>:$weight_zp,
189+
Tosa_ScalarTensor:$input_zp,
190+
Tosa_ScalarTensor:$weight_zp,
191+
173192
Tosa_IntArrayAttr4:$pad,
174193
Tosa_IntArrayAttr2:$stride,
175194
Tosa_IntArrayAttr2:$dilation,
@@ -181,6 +200,14 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
181200
Tosa_Tensor4D:$output
182201
);
183202

203+
204+
let extraClassDeclaration = [{
205+
LogicalResult getInputZeroPoint(int64_t &zp);
206+
LogicalResult getWeightZeroPoint(int64_t &zp);
207+
LogicalResult verifyInputZeroPoint(int64_t zp);
208+
LogicalResult verifyWeightZeroPoint(int64_t zp);
209+
}];
210+
184211
let builders = [Tosa_ConvOpQuantInfoBuilder];
185212
let hasVerifier = 1;
186213
}
@@ -356,8 +383,9 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
356383
Tosa_Tensor4D:$input,
357384
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
358385
Tosa_Tensor1D:$bias,
359-
Optional<Tosa_ScalarTensor>:$input_zp,
360-
Optional<Tosa_ScalarTensor>:$weight_zp,
386+
Tosa_ScalarTensor:$input_zp,
387+
Tosa_ScalarTensor:$weight_zp,
388+
361389
Tosa_IntArrayAttr4:$out_pad,
362390
Tosa_IntArrayAttr2:$stride,
363391
Tosa_IntArrayAttr4:$out_shape,
@@ -369,6 +397,14 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
369397
Tosa_Tensor4D:$output
370398
);
371399

400+
401+
let extraClassDeclaration = [{
402+
LogicalResult getInputZeroPoint(int64_t &zp);
403+
LogicalResult getWeightZeroPoint(int64_t &zp);
404+
LogicalResult verifyInputZeroPoint(int64_t zp);
405+
LogicalResult verifyWeightZeroPoint(int64_t zp);
406+
}];
407+
372408
let builders = [Tosa_TransConvOpQuantInfoBuilder];
373409
let hasVerifier = 1;
374410
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 41 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -288,11 +288,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
288288
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
289289
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
290290

291-
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
292-
if (llvm::failed(failureOrMaybeZps))
293-
return failure();
291+
// Get and verify zero points.
292+
int64_t inputZpVal;
293+
int64_t weightZpVal;
294+
295+
if (op.getInputZeroPoint(inputZpVal).failed() ||
296+
op.getWeightZeroPoint(weightZpVal).failed())
297+
return rewriter.notifyMatchFailure(
298+
op, "bail out if zero points cannot statically be determined");
299+
300+
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
301+
op.verifyWeightZeroPoint(weightZpVal).failed())
302+
return rewriter.notifyMatchFailure(
303+
op, "zero point must be zero for non-int8 integer types");
294304

295-
auto maybeZps = failureOrMaybeZps.value();
305+
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
296306

297307
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
298308
return rewriter.notifyMatchFailure(
@@ -318,19 +328,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
318328

319329
// Apply padding as necessary.
320330
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
321-
if (maybeZps) {
331+
if (hasZp) {
322332
int64_t intMin =
323333
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
324334
.getSExtValue();
325335
int64_t intMax =
326336
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
327337
.getSExtValue();
328338

329-
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
339+
if (inputZpVal < intMin || inputZpVal > intMax)
330340
return rewriter.notifyMatchFailure(
331341
op, "tosa.conv op quantization has zp outside of input range");
332342

333-
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
343+
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
334344
}
335345

336346
llvm::SmallVector<int64_t> pad;
@@ -343,8 +353,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
343353
// For 2D convolutions, we need to check if the target convolution op
344354
// wants a HWCF kernel layout.
345355
bool wantHwcf =
346-
maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
347-
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
356+
hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
357+
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
348358
if (wantHwcf) {
349359
// Transpose the kernel to match dimension ordering of the linalg
350360
// convolution operation.
@@ -405,9 +415,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
405415
Value broadcastBias =
406416
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
407417

408-
if (maybeZps) {
409-
auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
410-
auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
418+
if (hasZp) {
419+
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
420+
auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
411421

412422
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
413423
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -470,31 +480,40 @@ class DepthwiseConvConverter
470480
/*inputSizeDims=*/{1, 2},
471481
/*kernelSizeDims=*/{0, 1}, rewriter);
472482

473-
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
474-
if (llvm::failed(failureOrMaybeZps))
475-
return failure();
483+
// Get and verify zero points.
484+
int64_t inputZpVal;
485+
int64_t weightZpVal;
486+
487+
if (op.getInputZeroPoint(inputZpVal).failed() ||
488+
op.getWeightZeroPoint(weightZpVal).failed())
489+
return rewriter.notifyMatchFailure(
490+
op, "bail out if zero points cannot statically be determined");
476491

477-
auto maybeZps = failureOrMaybeZps.value();
492+
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
493+
op.verifyWeightZeroPoint(weightZpVal).failed())
494+
return rewriter.notifyMatchFailure(
495+
op, "zero point must be zero for non-int8 integer types");
478496

497+
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
479498
auto weightShape = weightTy.getShape();
480499
auto resultShape = resultTy.getShape();
481500

482501
// Apply padding as necessary.
483502
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
484-
if (maybeZps) {
503+
if (hasZp) {
485504
int64_t intMin =
486505
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
487506
.getSExtValue();
488507
int64_t intMax =
489508
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
490509
.getSExtValue();
491510

492-
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
511+
if (inputZpVal < intMin || inputZpVal > intMax)
493512
return rewriter.notifyMatchFailure(
494513
op, "tosa.depthwise_conv op quantization has zp outside of input "
495514
"range");
496515

497-
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
516+
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
498517
}
499518

500519
llvm::SmallVector<int64_t> pad;
@@ -534,7 +553,7 @@ class DepthwiseConvConverter
534553
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
535554
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
536555

537-
if (!maybeZps) {
556+
if (!hasZp) {
538557
Value conv = rewriter
539558
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
540559
loc, linalgConvTy, ValueRange{input, weight},
@@ -561,8 +580,8 @@ class DepthwiseConvConverter
561580
.getResult(0);
562581
rewriter.replaceOp(op, result);
563582
} else {
564-
IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
565-
IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
583+
IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
584+
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
566585
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
567586
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
568587
Value conv =

0 commit comments

Comments
 (0)