diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 93e8cac6b84e9..893cedefc1ebd 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -524,18 +524,8 @@ bool TosaValidation::isValidElementType(Type type) { if (!isEnabledProfile(TosaProfileEnum::MainInference)) return false; return type.isF32() || type.isF16() || type.isBF16(); - } - if (auto intTy = dyn_cast(type)) { - if (intTy.isUnsigned()) { - switch (intTy.getWidth()) { - case 8: - case 16: - return true; - default: - return false; - } - } else { - // Signless - treated as signed. + } else if (auto intTy = dyn_cast(type)) { + if (intTy.isSignless()) { switch (intTy.getWidth()) { case 1: case 4: @@ -544,13 +534,10 @@ bool TosaValidation::isValidElementType(Type type) { case 32: case 48: return true; - default: - return false; } } - return false; } - return true; + return false; } void TosaValidation::runOnOperation() { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index e851019362958..529a16ca48c7e 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -143,6 +143,14 @@ func.func @test_const_f64(%arg0 : tensor<1xf64>) { // ----- +func.func @test_const_ui8(%arg0 : tensor<1xui8>) { + // expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui8' is not legal}} + %0 = "tosa.const"() {value = dense<0> : tensor<1xui8>} : () -> tensor<1xui8> + return +} + +// ----- + func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}} %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} :