diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index c36c1074f5780..751ae785bda6f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1322,6 +1322,25 @@ LogicalResult tosa::ConcatOp::verify() { << " on operands 0 and " << operandNum; } } + + // ERROR_IF(axis_sum != shape[axis]); + int64_t axisSum = 0; + for (const auto &input : inputList) { + const ShapeAdaptor inputShape(input.getType()); + if (inputShape.isDynamicDim(axis)) { + // make axisSum negative to indicate invalid value + axisSum = -1; + break; + } + axisSum += inputShape.getDimSize(axis); + } + const ShapeAdaptor outputShape(outType); + if (axisSum >= 0 && outputShape.hasRank() && + !outputShape.isDynamicDim(axis) && + axisSum != outputShape.getDimSize(axis)) + return emitOpError("requires sum of axis dimensions of input1 " + "equal to output axis dimension, got ") + << axisSum << " and " << outputShape.getDimSize(axis); } return success(); diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 269ed58fdc81c..b147c94fde9b0 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens // ----- -func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor { - // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}} - %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor - return %0 : tensor -} - -// ----- - -func.func @test_concat_zero_inputs() { - // expected-error@+1 {{'tosa.concat' op expect at least one input}} - %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32> -} - -// ----- - -func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}} - %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -// ----- - -func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}} - %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -// ----- - func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> { %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32> // expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}} diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index fb8726cba1853..262e6d4265ea6 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -319,3 +319,42 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf : (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32> return %0 : tensor<1x4x8x19x34xf32> } + +// ----- + +func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor { + // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}} + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor + return %0 : tensor +} + +// ----- + +func.func @test_concat_zero_inputs() { + // expected-error@+1 {{'tosa.concat' op expect at least one input}} + %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32> +} + +// ----- + +func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}} + %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}} + %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> { + // expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}} + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32> + return %0 : tensor<2x?xf32> +}