@@ -48,37 +48,23 @@ static void debugLog(StringRef opName, ArrayRef<const LevelLattice*> operands,
4848 });
4949};
5050
51- LevelState transferForward (mgmt::ModReduceOp op,
52- ArrayRef<const LevelLattice*> operands) {
53- LevelState result = std::visit (
54- Overloaded{
55- [](MaxLevel) -> LevelState { return LevelState (Invalid{}); },
56- [](Uninit) -> LevelState { return LevelState (Invalid{}); },
57- [](Invalid) -> LevelState { return LevelState (Invalid{}); },
58- [](int val) -> LevelState { return LevelState (val + 1 ); },
59- },
60- operands[0 ]->getValue ().get ());
61- LLVM_DEBUG (debugLog (" mod_reduce" , operands, result));
62- return result;
63- }
64-
65- LevelState transferForward (mgmt::LevelReduceOp op,
51+ LevelState transferForward (ReducesLevelOpInterface op,
6652 ArrayRef<const LevelLattice*> operands) {
6753 LevelState result = std::visit (
6854 Overloaded{
6955 [](MaxLevel) -> LevelState { return LevelState (Invalid{}); },
7056 [](Uninit) -> LevelState { return LevelState (Invalid{}); },
7157 [](Invalid) -> LevelState { return LevelState (Invalid{}); },
7258 [&](int val) -> LevelState {
73- return LevelState (val + ( int ) op.getLevelToDrop ());
59+ return LevelState (val + op.getLevelsToDrop ());
7460 },
7561 },
7662 operands[0 ]->getValue ().get ());
77- LLVM_DEBUG (debugLog (" level_reduce " , operands, result));
63+ LLVM_DEBUG (debugLog (" ReduceLevelOpInterface " , operands, result));
7864 return result;
7965}
8066
81- LevelState transferForward (mgmt::LevelReduceMinOp op,
67+ LevelState transferForward (ReducesAllLevelsOpInterface op,
8268 ArrayRef<const LevelLattice*> operands) {
8369 LevelState result = std::visit (
8470 Overloaded{
@@ -90,11 +76,11 @@ LevelState transferForward(mgmt::LevelReduceMinOp op,
9076 [](int val) -> LevelState { return LevelState (MaxLevel{}); },
9177 },
9278 operands[0 ]->getValue ().get ());
93- LLVM_DEBUG (debugLog (" level_reduce_min " , operands, result));
79+ LLVM_DEBUG (debugLog (" ReduceAllLevelsOpInterface " , operands, result));
9480 return result;
9581}
9682
97- LevelState transferForward (mgmt::BootstrapOp op,
83+ LevelState transferForward (ResetsLevelOpInterface op,
9884 ArrayRef<const LevelLattice*> operands) {
9985 LevelState result = std::visit (
10086 Overloaded{
@@ -104,15 +90,18 @@ LevelState transferForward(mgmt::BootstrapOp op,
10490 [](int val) -> LevelState { return LevelState (0 ); },
10591 },
10692 operands[0 ]->getValue ().get ());
107- LLVM_DEBUG (debugLog (" bootstrap " , operands, result));
93+ LLVM_DEBUG (debugLog (" ResetsLevelOpInterface " , operands, result));
10894 return result;
10995}
11096
11197LevelState deriveResultLevel (Operation* op,
11298 ArrayRef<const LevelLattice*> operands) {
11399 return llvm::TypeSwitch<Operation&, LevelState>(*op)
114- .Case <mgmt::ModReduceOp, mgmt::LevelReduceOp, mgmt::BootstrapOp,
115- mgmt::LevelReduceMinOp>(
100+ .Case <ResetsLevelOpInterface>(
101+ [&](auto op) -> LevelState { return transferForward (op, operands); })
102+ .Case <ReducesAllLevelsOpInterface>(
103+ [&](auto op) -> LevelState { return transferForward (op, operands); })
104+ .Case <ReducesLevelOpInterface>(
116105 [&](auto op) -> LevelState { return transferForward (op, operands); })
117106 .Default ([&](auto & op) -> LevelState {
118107 LevelState result;
0 commit comments