diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index edf81bd7a8f39..b55eda69f99ec 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -229,8 +229,8 @@ struct ConstantCompositeOpPattern final if (!srcType || srcType.getNumElements() == 1) return failure(); - // arith.constant should only have vector or tenor types. - assert((isa(srcType))); + assert((isa(srcType) && + "arith.constant should only have vector or tensor types")); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) @@ -250,8 +250,9 @@ struct ConstantCompositeOpPattern final srcType.getElementType()); dstElementsAttr = dstElementsAttr.reshape(dstAttrType); } else { - // TODO: add support for large vectors. - return failure(); + dstAttrType = + VectorType::get(srcType.getNumElements(), srcType.getElementType()); + dstElementsAttr = dstElementsAttr.reshape(dstAttrType); } } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 2b79c8022b8e8..1ce7dff8ff0e4 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -330,6 +330,10 @@ convertVectorType(const spirv::TargetEnv &targetEnv, if (type.getRank() <= 1 && type.getNumElements() == 1) return convertScalarType(targetEnv, options, scalarType, storageClass); + // Linearize ND vectors + if (type.getRank() > 1) + type = VectorType::get(type.getNumElements(), scalarType); + if (!spirv::CompositeType::isValid(type)) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: not a valid composite type\n"); diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir index 0d92a8e676d85..551b036ba85e5 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -26,9 +26,9 @@ module attributes { #spirv.vce, #spirv.resource_limits<>> } { -func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) { +func.func @unsupported_2x2elem_vector(%arg0: vector<3x5xi32>) { // expected-error@+1 {{failed to legalize operation 'arith.muli'}} - %2 = arith.muli %arg0, %arg0: vector<2x2xi32> + %2 = arith.muli %arg0, %arg0: vector<3x5xi32> return } diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index ae47ae36ca51c..4a2ef1f0275c6 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -556,6 +556,8 @@ func.func @constant() { %9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> // CHECK: spirv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32> %10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> + %11 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> return }