Skip to content

Commit c044695

Browse files
asraacopybara-github
authored andcommitted
fix: fixes lattigo in place transform by ensuring that storage values keep level state invariant
regression test for in-place issue #2635 PiperOrigin-RevId: 868309124
1 parent c180782 commit c044695

File tree

17 files changed

+363
-62
lines changed

17 files changed

+363
-62
lines changed

lib/Analysis/LevelAnalysis/LevelAnalysis.cpp

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,37 +48,23 @@ static void debugLog(StringRef opName, ArrayRef<const LevelLattice*> operands,
4848
});
4949
};
5050

51-
LevelState transferForward(mgmt::ModReduceOp op,
52-
ArrayRef<const LevelLattice*> operands) {
53-
LevelState result = std::visit(
54-
Overloaded{
55-
[](MaxLevel) -> LevelState { return LevelState(Invalid{}); },
56-
[](Uninit) -> LevelState { return LevelState(Invalid{}); },
57-
[](Invalid) -> LevelState { return LevelState(Invalid{}); },
58-
[](int val) -> LevelState { return LevelState(val + 1); },
59-
},
60-
operands[0]->getValue().get());
61-
LLVM_DEBUG(debugLog("mod_reduce", operands, result));
62-
return result;
63-
}
64-
65-
LevelState transferForward(mgmt::LevelReduceOp op,
51+
LevelState transferForward(ReducesLevelOpInterface op,
6652
ArrayRef<const LevelLattice*> operands) {
6753
LevelState result = std::visit(
6854
Overloaded{
6955
[](MaxLevel) -> LevelState { return LevelState(Invalid{}); },
7056
[](Uninit) -> LevelState { return LevelState(Invalid{}); },
7157
[](Invalid) -> LevelState { return LevelState(Invalid{}); },
7258
[&](int val) -> LevelState {
73-
return LevelState(val + (int)op.getLevelToDrop());
59+
return LevelState(val + op.getLevelsToDrop());
7460
},
7561
},
7662
operands[0]->getValue().get());
77-
LLVM_DEBUG(debugLog("level_reduce", operands, result));
63+
LLVM_DEBUG(debugLog("ReduceLevelOpInterface", operands, result));
7864
return result;
7965
}
8066

81-
LevelState transferForward(mgmt::LevelReduceMinOp op,
67+
LevelState transferForward(ReducesAllLevelsOpInterface op,
8268
ArrayRef<const LevelLattice*> operands) {
8369
LevelState result = std::visit(
8470
Overloaded{
@@ -90,11 +76,11 @@ LevelState transferForward(mgmt::LevelReduceMinOp op,
9076
[](int val) -> LevelState { return LevelState(MaxLevel{}); },
9177
},
9278
operands[0]->getValue().get());
93-
LLVM_DEBUG(debugLog("level_reduce_min", operands, result));
79+
LLVM_DEBUG(debugLog("ReduceAllLevelsOpInterface", operands, result));
9480
return result;
9581
}
9682

97-
LevelState transferForward(mgmt::BootstrapOp op,
83+
LevelState transferForward(ResetsLevelOpInterface op,
9884
ArrayRef<const LevelLattice*> operands) {
9985
LevelState result = std::visit(
10086
Overloaded{
@@ -104,15 +90,18 @@ LevelState transferForward(mgmt::BootstrapOp op,
10490
[](int val) -> LevelState { return LevelState(0); },
10591
},
10692
operands[0]->getValue().get());
107-
LLVM_DEBUG(debugLog("bootstrap", operands, result));
93+
LLVM_DEBUG(debugLog("ResetsLevelOpInterface", operands, result));
10894
return result;
10995
}
11096

11197
LevelState deriveResultLevel(Operation* op,
11298
ArrayRef<const LevelLattice*> operands) {
11399
return llvm::TypeSwitch<Operation&, LevelState>(*op)
114-
.Case<mgmt::ModReduceOp, mgmt::LevelReduceOp, mgmt::BootstrapOp,
115-
mgmt::LevelReduceMinOp>(
100+
.Case<ResetsLevelOpInterface>(
101+
[&](auto op) -> LevelState { return transferForward(op, operands); })
102+
.Case<ReducesAllLevelsOpInterface>(
103+
[&](auto op) -> LevelState { return transferForward(op, operands); })
104+
.Case<ReducesLevelOpInterface>(
116105
[&](auto op) -> LevelState { return transferForward(op, operands); })
117106
.Default([&](auto& op) -> LevelState {
118107
LevelState result;

lib/Dialect/HEIRInterfaces.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,41 @@ def ResetsMulDepthOpInterface : OpInterface<"ResetsMulDepthOpInterface"> {
3434
}];
3535
}
3636

37+
def ResetsLevelOpInterface : OpInterface<"ResetsLevelOpInterface"> {
38+
let cppNamespace = "::mlir::heir";
39+
let description = [{
40+
An interface that signals when an operation resets level
41+
among its results, such as a `mgmt.bootstrap`.
42+
}];
43+
}
44+
45+
def ReducesLevelOpInterface : OpInterface<"ReducesLevelOpInterface"> {
46+
let cppNamespace = "::mlir::heir";
47+
let description = [{
48+
An interface that signals when an operation reduces level
49+
among its results, such as a `mgmt.mod_reduce` or `ckks.rescale`.
50+
}];
51+
52+
let methods = [
53+
InterfaceMethod<
54+
/*desc=*/"Return the number of levels to reduce by.",
55+
/*retTy=*/"int",
56+
/*methodName=*/"getLevelsToDrop",
57+
/*args=*/(ins ),
58+
/*body=*/[{}],
59+
/*defaultBody=*/[{ return 1; }]
60+
>,
61+
];
62+
}
63+
64+
def ReducesAllLevelsOpInterface : OpInterface<"ReducesAllLevelsOpInterface"> {
65+
let cppNamespace = "::mlir::heir";
66+
let description = [{
67+
An interface that signals when an operation reduces all level
68+
among its results, such as a `mgmt.level_reduce_min`.
69+
}];
70+
}
71+
3772
def LUTOpInterface : OpInterface<"LUTOpInterface"> {
3873
let cppNamespace = "::mlir::heir";
3974
let description = [{

lib/Dialect/Lattigo/IR/LattigoBGVOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def Lattigo_BGVMulOp : Lattigo_BGVBinaryInPlaceOp<"mul", [IncreasesMulDepthOpInt
182182
}];
183183
}
184184

185-
class Lattigo_BGVUnaryOp<string mnemonic> :
186-
Lattigo_BGVOp<mnemonic> {
185+
class Lattigo_BGVUnaryOp<string mnemonic, list<Trait> traits = []> :
186+
Lattigo_BGVOp<mnemonic, traits> {
187187
let arguments = (ins
188188
Lattigo_BGVEvaluator:$evaluator,
189189
Lattigo_RLWECiphertext:$input
@@ -198,7 +198,7 @@ def Lattigo_BGVRelinearizeNewOp : Lattigo_BGVUnaryOp<"relinearize_new"> {
198198
}];
199199
}
200200

201-
def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new"> {
201+
def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new", [ReducesLevelOpInterface]> {
202202
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
203203
let description = [{
204204
This operation rescales a ciphertext value in the Lattigo BGV dialect.
@@ -231,8 +231,8 @@ def Lattigo_BGVRotateRowsNewOp : Lattigo_BGVUnaryOp<"rotate_rows_new"> {
231231
}];
232232
}
233233

234-
class Lattigo_BGVUnaryInPlaceOp<string mnemonic> :
235-
Lattigo_BGVOp<mnemonic, [InPlaceOpInterface]> {
234+
class Lattigo_BGVUnaryInPlaceOp<string mnemonic, list<Trait> traits = []> :
235+
Lattigo_BGVOp<mnemonic, traits # [InPlaceOpInterface]> {
236236
let arguments = (ins
237237
Lattigo_BGVEvaluator:$evaluator,
238238
Lattigo_RLWECiphertext:$input,
@@ -254,7 +254,7 @@ def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryInPlaceOp<"relinearize"> {
254254
}];
255255
}
256256

257-
def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale"> {
257+
def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> {
258258
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
259259
let description = [{
260260
This operation rescales a ciphertext value in the Lattigo BGV dialect.

lib/Dialect/Lattigo/IR/LattigoCKKSOps.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def Lattigo_CKKSMulOp : Lattigo_CKKSBinaryInPlaceOp<"mul", [IncreasesMulDepthOpI
215215
}];
216216
}
217217

218-
class Lattigo_CKKSUnaryOp<string mnemonic> :
219-
Lattigo_CKKSOp<mnemonic> {
218+
class Lattigo_CKKSUnaryOp<string mnemonic, list<Trait> traits = []> :
219+
Lattigo_CKKSOp<mnemonic, traits> {
220220
let arguments = (ins
221221
Lattigo_CKKSEvaluator:$evaluator,
222222
Lattigo_RLWECiphertext:$input
@@ -231,7 +231,7 @@ def Lattigo_CKKSRelinearizeNewOp : Lattigo_CKKSUnaryOp<"relinearize_new"> {
231231
}];
232232
}
233233

234-
def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new"> {
234+
def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new", [ReducesLevelOpInterface]> {
235235
let summary = "Rescale a ciphertext in the Lattigo CKKS dialect";
236236
let description = [{
237237
This operation rescales a ciphertext value in the Lattigo CKKS dialect.
@@ -257,8 +257,8 @@ def Lattigo_CKKSRotateNewOp : Lattigo_CKKSOp<"rotate_new"> {
257257
let results = (outs Lattigo_RLWECiphertext:$output);
258258
}
259259

260-
class Lattigo_CKKSUnaryInPlaceOp<string mnemonic> :
261-
Lattigo_CKKSOp<mnemonic, [InPlaceOpInterface]> {
260+
class Lattigo_CKKSUnaryInPlaceOp<string mnemonic, list<Trait> traits = []> :
261+
Lattigo_CKKSOp<mnemonic, traits # [InPlaceOpInterface]> {
262262
let arguments = (ins
263263
Lattigo_CKKSEvaluator:$evaluator,
264264
Lattigo_RLWECiphertext:$input,
@@ -280,7 +280,7 @@ def Lattigo_CKKSRelinearizeOp : Lattigo_CKKSUnaryInPlaceOp<"relinearize"> {
280280
}];
281281
}
282282

283-
def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale"> {
283+
def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> {
284284
let summary = "Rescale a ciphertext in the Lattigo CKKS dialect";
285285
let description = [{
286286
This operation rescales a ciphertext value in the Lattigo CKKS dialect.
@@ -314,7 +314,7 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate"> {
314314
let results = (outs Lattigo_RLWECiphertext:$output);
315315
}
316316

317-
def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> {
317+
def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap", [ResetsLevelOpInterface]> {
318318
let summary = "Bootstrap a ciphertext in the Lattigo CKKS dialect";
319319
let description = [{
320320
Bootstraps a ciphertext value in the Lattigo CKKS dialect.

lib/Dialect/Lattigo/IR/LattigoRLWEOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def Lattigo_RLWEDecryptOp : Lattigo_RLWEOp<"decrypt"> {
122122
let results = (outs Lattigo_RLWEPlaintext:$plaintext);
123123
}
124124

125-
def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new"> {
125+
def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new", [ReducesLevelOpInterface]> {
126126
let summary = "Drop level of a ciphertext";
127127
let arguments = (ins
128128
Lattigo_RLWEEvaluator:$evaluator,
@@ -132,7 +132,7 @@ def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new"> {
132132
let results = (outs Lattigo_RLWECiphertext:$output);
133133
}
134134

135-
def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InPlaceOpInterface]> {
135+
def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InPlaceOpInterface, ReducesLevelOpInterface]> {
136136
let summary = "Drop level of a ciphertext";
137137
let description = [{
138138
This operation drops the level of a ciphertext

0 commit comments

Comments
 (0)