[mlir][linalg][nfc] Split hoisting tests into dedicated test functions #145234
Changes from all commits
e430966
7606ec6
6dee0a6
2872b64
53dd43e
0107f4a
fd3df28
4ece1c9
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,76 +1,210 @@ | ||
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s | ||
|
||
// CHECK-LABEL: func @hoist_vector_transfer_pairs( | ||
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index, | ||
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index, | ||
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index, | ||
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index, | ||
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1 | ||
func.func @hoist_vector_transfer_pairs( | ||
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>, | ||
%memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>, | ||
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) { | ||
///---------------------------------------------------------------------------------------- | ||
/// Tests for vector.transfer_read + vector.transfer_write pairs | ||
/// | ||
/// * Nested in double loops | ||
// * Indices depend on induction variables | ||
///---------------------------------------------------------------------------------------- | ||
|
||
// CHECK-LABEL: func @mem_use_outside | ||
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) | ||
func.func @mem_use_outside(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) { | ||
%pad = arith.constant 0.0 : f32 | ||
|
||
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { | ||
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32> | ||
// CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) { | ||
// CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32> | ||
// CHECK: scf.yield %[[USE]] : vector<1xf32> | ||
// CHECK: } | ||
For most cases, we don't need to check

Agreed, though in this case it helps to highlight where the loop ends (and that the xfer Ops have indeed been hoisted). I don't feel strongly about this, just sharing my rationale. If that's OK, I'll keep it (also for consistency with the existing tests).

I think it is okay. I just want to point out that it is not necessary; it depends on whether you want to have fewer checks. |
||
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32> | ||
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> () | ||
// CHECK: } | ||
scf.for %i = %lb to %ub step %step { | ||
scf.for %j = %lb to %ub step %step { | ||
%read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32> | ||
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32> | ||
vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32> | ||
} | ||
} | ||
"mem_use"(%mem) : (memref<?x?xf32>) -> () | ||
return | ||
} | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||
%0 = transform.structured.match ops{["func.func"]} in %arg1 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.structured.hoist_redundant_vector_transfers %0 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.yield | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @mem_use_inside_outer_loop | ||
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) | ||
func.func @mem_use_inside_outer_loop(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) { | ||
%pad = arith.constant 0.0 : f32 | ||
|
||
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { | ||
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32> | ||
// CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) { | ||
// CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32> | ||
// CHECK: scf.yield %[[USE]] : vector<1xf32> | ||
// CHECK: } | ||
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32> | ||
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> () | ||
// CHECK: } | ||
scf.for %i = %lb to %ub step %step { | ||
scf.for %j = %lb to %ub step %step { | ||
%read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32> | ||
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32> | ||
vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32> | ||
} | ||
"mem_use"(%mem) : (memref<?x?xf32>) -> () | ||
} | ||
return | ||
} | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||
%0 = transform.structured.match ops{["func.func"]} in %arg1 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.structured.hoist_redundant_vector_transfers %0 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.yield | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
///---------------------------------------------------------------------------------------- | ||
/// Tests for vector.transfer_read + vector.transfer_write pairs | ||
/// | ||
/// * Nested in double loops | ||
// * Indices are constant | ||
///---------------------------------------------------------------------------------------- | ||
|
||
// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_write | ||
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) | ||
func.func @negative_mem_use_inside_inner_loop_before_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant 0.0 : f32 | ||
%pad = arith.constant 0.0 : f32 | ||
|
||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { | ||
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { | ||
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32> | ||
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32> | ||
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> () | ||
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32> | ||
// CHECK: } | ||
// CHECK: } | ||
scf.for %i = %lb to %ub step %step { | ||
scf.for %j = %lb to %ub step %step { | ||
%read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32> | ||
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32> | ||
"mem_use"(%mem) : (memref<?x?xf32>) -> () | ||
vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32> | ||
} | ||
} | ||
return | ||
} | ||
|
||
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32> | ||
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) { | ||
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32> | ||
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) { | ||
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32> | ||
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32> | ||
// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> () | ||
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32> | ||
// CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> | ||
// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> | ||
// CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32> | ||
// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32> | ||
// CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32> | ||
// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32> | ||
// CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32> | ||
// CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32> | ||
// CHECK: "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> () | ||
// CHECK: scf.yield {{.*}} : vector<1xf32>, vector<2xf32> | ||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||
%0 = transform.structured.match ops{["func.func"]} in %arg1 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.structured.hoist_redundant_vector_transfers %0 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.yield | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_after_write | ||
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) | ||
func.func @negative_mem_use_inside_inner_loop_after_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) { | ||
%c0 = arith.constant 0 : index | ||
%pad = arith.constant 0.0 : f32 | ||
|
||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { | ||
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { | ||
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32> | ||
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32> | ||
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32> | ||
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> () | ||
// CHECK: } | ||
// CHECK: } | ||
scf.for %i = %lb to %ub step %step { | ||
scf.for %j = %lb to %ub step %step { | ||
%r3 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32> | ||
%u3 = "val_use"(%r3) : (vector<1xf32>) -> vector<1xf32> | ||
vector.transfer_write %u3, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32> | ||
"mem_use"(%mem) : (memref<?x?xf32>) -> () | ||
} | ||
} | ||
return | ||
} | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||
%0 = transform.structured.match ops{["func.func"]} in %arg1 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.structured.hoist_redundant_vector_transfers %0 | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.yield | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_read | ||
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, | ||
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) | ||
func.func @negative_mem_use_inside_inner_loop_before_read(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) { | ||
%c0 = arith.constant 0 : index | ||
%pad = arith.constant 0.0 : f32 | ||
|
||
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { | ||
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { | ||
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> () | ||
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32> | ||
// CHECK: "val_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> | ||
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32> | ||
// CHECK: } | ||
// CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32> | ||
// CHECK: "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> () | ||
// CHECK: scf.yield {{.*}} : vector<1xf32> | ||
// CHECK: } | ||
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32> | ||
// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> () | ||
scf.for %i = %lb to %ub step %step { | ||
scf.for %j = %lb to %ub step %step { | ||
%r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32> | ||
%r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32> | ||
%r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32> | ||
%r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32> | ||
"some_crippling_use"(%memref4) : (memref<?x?xf32>) -> () | ||
%r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32> | ||
%r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32> | ||
"some_crippling_use"(%memref5) : (memref<?x?xf32>) -> () | ||
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> | ||
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> | ||
%u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32> | ||
%u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> | ||
%u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> | ||
%u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32> | ||
vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32> | ||
vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32> | ||
vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32> | ||
vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32> | ||
vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32> | ||
vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32> | ||
"some_crippling_use"(%memref3) : (memref<?x?xf32>) -> () | ||
"mem_use"(%mem) : (memref<?x?xf32>) -> () | ||
%read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32> | ||
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32> | ||
vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32> | ||
} | ||
"unrelated_use"(%memref0) : (memref<?x?xf32>) -> () | ||
} | ||
"unrelated_use"(%memref1) : (memref<?x?xf32>) -> () | ||
return | ||
} | ||
|
||
|
@@ -86,6 +220,12 @@ module attributes {transform.with_named_sequence} { | |
|
||
// ----- | ||
|
||
///---------------------------------------------------------------------------------------- | ||
/// Other tests | ||
/// | ||
/// TODO: Document | ||
///---------------------------------------------------------------------------------------- | ||
|
||
// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint( | ||
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>, | ||
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>, | ||
|
I think you don't really care about the types because they are shown right on the next line; the purpose of using a regex is to avoid capturing the types, IMO.

This REGEX helps when you have adjacent index variables. With a more generic REGEX, we'd need this (all index vars on one line):
I think that you are right that we could skip types without losing anything, but we seem to always include them 🤷🏻♂️ (at least in this file). I find them helpful TBH, but am also open to new trends :)

I was going to send a patch to help clean it up, but I'm surprised that you find them helpful. To me, it is case by case. Types are helpful when a type converter or dialect conversion is involved, but not that helpful in other cases. Including them does no harm. However, it could be annoying if people update complicated types, e.g., memref types, and I think that happened before.

If the MLIR name is descriptive (e.g. %mem) and the LIT name matches the MLIR name 1:1 (e.g. %mem vs MEM) then the type doesn't really help, does it? But with "enigmatic" names (e.g. %arg0) it helps to identify the right variable.
In this file the names are quite self-descriptive. If you want to remove types from this file then I will happily review + approve.