diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ee9ab61b670c4..ddc80063fd340 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5881,14 +5881,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   }
 
   // shape_cast(constant) -> constant
-  if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
-    return splatAttr.reshape(getType());
+  if (auto denseAttr =
+          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+    return denseAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
-  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
     return ub::PoisonAttr::get(getContext());
-  }
 
   return {};
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..a06a98ee1b93b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1219,11 +1219,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
 
 // -----
 
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
 // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
 // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
 // CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
   %cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
   %cst_1 = arith.constant dense<1> : vector<12x2xi32>
   %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1233,6 +1233,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
 
 // -----
 
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+// CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+  %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+  %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+  return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast's fold method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>) {
+  %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+  %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+  return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: shape_cast_poison
 // CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
 // CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
@@ -1549,7 +1579,7 @@ func.func @negative_store_to_load_tensor_memref(
     %arg0 : tensor<?x?xf32>,
     %arg1 : memref<?x?xf32>,
     %v0 : vector<4x2xf32>
-  ) -> vector<4x2xf32>
+  ) -> vector<4x2xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1636,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
 // CHECK: vector.transfer_read
 func.func @negative_store_to_load_tensor_broadcast_masked(
     %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-  -> vector<4x2x6xf32>
+  -> vector<4x2x6xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32