diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index b4cb640108bae..9b23681dba6a8 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -1,12 +1,17 @@ // RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s -func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ +//----------------------------------------------------------------------------- +// 1. vector.transfer_read +//----------------------------------------------------------------------------- + +func.func @contiguous_inner_most(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ %c0 = arith.constant 0 : index %cst = arith.constant 0.0 : f32 %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32> return %0 : vector<1x8x1xf32> } -// CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> + +// CHECK: func @contiguous_inner_most(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> // CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]] // CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>> // CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]] @@ -14,15 +19,61 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8, // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]] // CHECK: return %[[RESULT]] +// Same as the top example within this split, but with the inner vector +// dim scalable. Note that this example only makes sense when "8 = [8]" (i.e. +// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute. + +func.func @contiguous_inner_most_scalable_inner_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x[8]x1xf32>{ + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x[8]x1xf32> + return %0 : vector<1x[8]x1xf32> +} + +// CHECK: func @contiguous_inner_most_scalable_inner_dim(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> +// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]] +// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>> +// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]] +// CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x[8]xf32> +// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]] +// CHECK: return %[[RESULT]] + +// Same as the top example within this split, but the trailing unit dim was +// replaced with a dyn dim - not supported + +func.func @non_unit_trailing_dim(%in: memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32> + return %0 : vector<1x8x1xf32> +} + +// CHECK-LABEL: func @non_unit_trailing_dim +// CHECK-NOT: memref.subview +// CHECK-NOT: vector.shape_cast + +// Same as the top example within this split, but with a scalable unit dim in +// the output vector - not supported + +func.func @negative_scalable_unit_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x[1]xf32>{ + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x[1]xf32> + return %0 : vector<1x8x[1]xf32> +} +// CHECK-LABEL: func @negative_scalable_unit_dim +// CHECK-NOT: memref.subview +// CHECK-NOT: vector.shape_cast + // ----- -func.func @contiguous_outer_dyn_inner_most_view(%a: index, %b: index, %memref: memref) -> vector<8x1xf32> { +func.func @contiguous_outer_dyn_inner_most(%a: index, %b: index, %memref: memref) -> vector<8x1xf32> { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f32 %v = vector.transfer_read %memref[%a, %b, %c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<8x1xf32> return %v : vector<8x1xf32> } -// CHECK: func.func @contiguous_outer_dyn_inner_most_view( +// CHECK: func.func @contiguous_outer_dyn_inner_most( // CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]] // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] @@ -103,6 +154,10 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) -> // ----- +//----------------------------------------------------------------------------- +// 2. vector.transfer_write +//----------------------------------------------------------------------------- + func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] @@ -177,21 +232,6 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o // ----- -func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>) -> vector<[4]x1xf32> { - %c0 = arith.constant 0 : index - %pad = arith.constant 0.0 : f32 - %0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32> - return %0 : vector<[4]x1xf32> -} -// CHECK: func.func @leading_scalable_dimension_transfer_read -// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>> -// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32> -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32> -// CHECK: return %[[CAST]] - -// ----- - // Negative test: [1] (scalable 1) is _not_ a unit dimension. func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>) -> vector<4x[1]xf32> { %c0 = arith.constant 0 : index