@@ -288,11 +288,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
288
288
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr ();
289
289
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr ();
290
290
291
- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
292
- if (llvm::failed (failureOrMaybeZps))
293
- return failure ();
291
+ // Get and verify zero points.
292
+ int64_t inputZpVal;
293
+ int64_t weightZpVal;
294
+
295
+ if (op.getInputZeroPoint (inputZpVal).failed () ||
296
+ op.getWeightZeroPoint (weightZpVal).failed ())
297
+ return rewriter.notifyMatchFailure (
298
+ op, " bail out if zero points cannot statically be determined" );
299
+
300
+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
301
+ op.verifyWeightZeroPoint (weightZpVal).failed ())
302
+ return rewriter.notifyMatchFailure (
303
+ op, " zero point must be zero for non-int8 integer types" );
294
304
295
- auto maybeZps = failureOrMaybeZps. value ( );
305
+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
296
306
297
307
if (!weightTy.hasStaticShape () || !biasTy.hasStaticShape ())
298
308
return rewriter.notifyMatchFailure (
@@ -318,19 +328,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
318
328
319
329
// Apply padding as necessary.
320
330
TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
321
- if (maybeZps ) {
331
+ if (hasZp ) {
322
332
int64_t intMin =
323
333
APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
324
334
.getSExtValue ();
325
335
int64_t intMax =
326
336
APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
327
337
.getSExtValue ();
328
338
329
- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
339
+ if (inputZpVal < intMin || inputZpVal > intMax)
330
340
return rewriter.notifyMatchFailure (
331
341
op, " tosa.conv op quantization has zp outside of input range" );
332
342
333
- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
343
+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
334
344
}
335
345
336
346
llvm::SmallVector<int64_t > pad;
@@ -343,8 +353,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
343
353
// For 2D convolutions, we need to check if the target convolution op
344
354
// wants a HWCF kernel layout.
345
355
bool wantHwcf =
346
- maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
347
- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
356
+ hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
357
+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
348
358
if (wantHwcf) {
349
359
// Transpose the kernel to match dimension ordering of the linalg
350
360
// convolution operation.
@@ -405,9 +415,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
405
415
Value broadcastBias =
406
416
linalgBroadcastAndMaybeExtSI (rewriter, loc, bias, biasEmptyTensor);
407
417
408
- if (maybeZps ) {
409
- auto iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
410
- auto kZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
418
+ if (hasZp ) {
419
+ auto iZp = rewriter.getI32IntegerAttr (inputZpVal );
420
+ auto kZp = rewriter.getI32IntegerAttr (weightZpVal );
411
421
412
422
auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
413
423
auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
@@ -470,31 +480,40 @@ class DepthwiseConvConverter
470
480
/* inputSizeDims=*/ {1 , 2 },
471
481
/* kernelSizeDims=*/ {0 , 1 }, rewriter);
472
482
473
- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
474
- if (llvm::failed (failureOrMaybeZps))
475
- return failure ();
483
+ // Get and verify zero points.
484
+ int64_t inputZpVal;
485
+ int64_t weightZpVal;
486
+
487
+ if (op.getInputZeroPoint (inputZpVal).failed () ||
488
+ op.getWeightZeroPoint (weightZpVal).failed ())
489
+ return rewriter.notifyMatchFailure (
490
+ op, " bail out if zero points cannot statically be determined" );
476
491
477
- auto maybeZps = failureOrMaybeZps.value ();
492
+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
493
+ op.verifyWeightZeroPoint (weightZpVal).failed ())
494
+ return rewriter.notifyMatchFailure (
495
+ op, " zero point must be zero for non-int8 integer types" );
478
496
497
+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
479
498
auto weightShape = weightTy.getShape ();
480
499
auto resultShape = resultTy.getShape ();
481
500
482
501
// Apply padding as necessary.
483
502
TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
484
- if (maybeZps ) {
503
+ if (hasZp ) {
485
504
int64_t intMin =
486
505
APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
487
506
.getSExtValue ();
488
507
int64_t intMax =
489
508
APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
490
509
.getSExtValue ();
491
510
492
- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
511
+ if (inputZpVal < intMin || inputZpVal > intMax)
493
512
return rewriter.notifyMatchFailure (
494
513
op, " tosa.depthwise_conv op quantization has zp outside of input "
495
514
" range" );
496
515
497
- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
516
+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
498
517
}
499
518
500
519
llvm::SmallVector<int64_t > pad;
@@ -534,7 +553,7 @@ class DepthwiseConvConverter
534
553
indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
535
554
indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
536
555
537
- if (!maybeZps ) {
556
+ if (!hasZp ) {
538
557
Value conv = rewriter
539
558
.create <linalg::DepthwiseConv2DNhwcHwcmOp>(
540
559
loc, linalgConvTy, ValueRange{input, weight},
@@ -561,8 +580,8 @@ class DepthwiseConvConverter
561
580
.getResult (0 );
562
581
rewriter.replaceOp (op, result);
563
582
} else {
564
- IntegerAttr iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
565
- IntegerAttr wZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
583
+ IntegerAttr iZp = rewriter.getI32IntegerAttr (inputZpVal );
584
+ IntegerAttr wZp = rewriter.getI32IntegerAttr (weightZpVal );
566
585
auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
567
586
auto kZpVal = rewriter.create <arith::ConstantOp>(loc, wZp);
568
587
Value conv =
0 commit comments