[mlir][linalg][nfc] Split hoisting tests into dedicated test functions #145234

Merged on Jun 25, 2025 (8 commits)
266 changes: 203 additions & 63 deletions mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,76 +1,210 @@
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s

// CHECK-LABEL: func @hoist_vector_transfer_pairs(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
func.func @hoist_vector_transfer_pairs(
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
%memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
///----------------------------------------------------------------------------------------
/// Tests for vector.transfer_read + vector.transfer_write pairs
///
/// * Nested in double loops
// * Indices depend on induction variables
///----------------------------------------------------------------------------------------

// CHECK-LABEL: func @mem_use_outside
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
Comment on lines +11 to +14
Contributor

I don't think you really care about the types, because they are shown right on the next line; IMO the purpose of using a regex is to avoid capturing the types.

Contributor Author

This REGEX helps when you have adjacent index variables. With a more generic REGEX, we'd need this (all index vars on one line):

// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index, %[[UB:[a-zA-Z0-9]+]]: index, %[[STEP:[a-zA-Z0-9]+]]: index)

I think you are right that we could skip types without losing anything, but we seem to always include them 🤷🏻‍♂️ (at least in this file). I find them helpful TBH, but am also open to new trends :)
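
To make the point about adjacent index variables concrete, here is a rough sketch that uses Python's `re` module as a stand-in for FileCheck. It is an approximation only: FileCheck has its own regex engine and `CHECK-SAME` semantics, the helper `match_in_sequence` is hypothetical, and the `%argN` names are merely what `mlir-opt` might print. Each pattern is searched for after the end of the previous match, roughly like consecutive `CHECK-SAME` lines.

```python
import re

# Hypothetical: the signature as mlir-opt might print it.
SIGNATURE = (
    "func.func @mem_use_outside(%arg0: memref<?x?xf32>, "
    "%arg1: index, %arg2: index, %arg3: index)"
)


def match_in_sequence(patterns, text):
    """Search for each pattern after the end of the previous match.

    Returns the list of captured names, or None if any pattern fails.
    This only approximates how consecutive CHECK-SAME directives behave.
    """
    pos = 0
    captures = []
    for pattern in patterns:
        m = re.compile(pattern).search(text, pos)
        if m is None:
            return None
        captures.extend(m.groups())
        pos = m.end()
    return captures


# Restricted character class: a capture cannot run past the SSA name.
strict = [
    r"%([a-zA-Z0-9]+): memref<\?x\?xf32>,",
    r"%([a-zA-Z0-9]+): index,",
    r"%([a-zA-Z0-9]+): index,",
    r"%([a-zA-Z0-9]+): index\)",
]

# Generic `.*`, one directive per argument: the second pattern greedily
# swallows "arg1: index, %arg2", so the third has nothing left to match.
greedy = [
    r"%(.*): memref<\?x\?xf32>,",
    r"%(.*): index,",
    r"%(.*): index,",
    r"%(.*): index\)",
]

# Generic `.*`, all index arguments in one directive: the regex engine can
# backtrack within the single pattern, so every capture lands on one name.
greedy_one_line = [
    r"%(.*): memref<\?x\?xf32>,",
    r"%(.*): index, %(.*): index, %(.*): index\)",
]

print(match_in_sequence(strict, SIGNATURE))           # ['arg0', 'arg1', 'arg2', 'arg3']
print(match_in_sequence(greedy, SIGNATURE))           # None
print(match_in_sequence(greedy_one_line, SIGNATURE))  # ['arg0', 'arg1', 'arg2', 'arg3']
```

In this sketch, the per-line form only works with the restricted character class, which matches the trade-off described above.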

Contributor

I was going to send a patch to help clean it up, but I'm surprised that you find them helpful. To me, it is case by case. Types are helpful when a type converter or dialect conversion is involved, but not that helpful in other cases. Including types does no harm. However, it can be annoying when people update complicated types, e.g., memref types, and I think that has happened before.

Contributor Author

If the MLIR name is descriptive (e.g. %mem) and the LIT name matches the MLIR name 1:1 (e.g. %mem vs MEM) then the type doesn't really help, does it? But with "enigmatic" names (e.g. %arg0) it helps to identify the right variable.

In this file the names are quite self-descriptive. If you want to remove types from this file then I will happily review + approve.

func.func @mem_use_outside(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
%pad = arith.constant 0.0 : f32

// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
// CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: scf.yield %[[USE]] : vector<1xf32>
// CHECK: }
Contributor

In most cases, we don't need to check }, because the core checks below use captured variables.

Contributor Author

Agreed, though in this case it helps to highlight where the loop ends (and that the xfer Ops have indeed been hoisted).

I don't feel strongly about this, just sharing my rationale. If that's OK, I'll keep it (also for consistency with the existing tests).

Contributor

I think it is okay; I just want to point out that it is not necessary, and it depends on whether you want to have fewer checks.

// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
// CHECK: }
scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
%read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
}
}
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func @mem_use_inside_outer_loop
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
func.func @mem_use_inside_outer_loop(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
%pad = arith.constant 0.0 : f32

// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
// CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: scf.yield %[[USE]] : vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
// CHECK: }
scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
%read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
}
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

///----------------------------------------------------------------------------------------
/// Tests for vector.transfer_read + vector.transfer_write pairs
///
/// * Nested in double loops
// * Indices are constant
///----------------------------------------------------------------------------------------

// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_write
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
func.func @negative_mem_use_inside_inner_loop_before_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%pad = arith.constant 0.0 : f32

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: }
// CHECK: }
scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
%read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
}
return
}

// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
// CHECK: "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
// CHECK: scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_after_write
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
func.func @negative_mem_use_inside_inner_loop_after_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
// CHECK: }
// CHECK: }
scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
%r3 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u3 = "val_use"(%r3) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u3, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
}
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_read
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
func.func @negative_mem_use_inside_inner_loop_before_read(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32

// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
// CHECK: "val_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
// CHECK: "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
// CHECK: scf.yield {{.*}} : vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
%r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
%r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
%r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
%r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
"some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
%r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
%r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
"some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
%u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
%u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
%u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
%u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
"some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
%read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
"unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
}
"unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
return
}

@@ -86,6 +220,12 @@ module attributes {transform.with_named_sequence} {

// -----

///----------------------------------------------------------------------------------------
/// Other tests
///
/// TODO: Document
///----------------------------------------------------------------------------------------

// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,