Skip to content

Commit 927f386

Browse files
asraacopybara-github
authored andcommitted
fix: copy result attribtues when extracting single ops from secret.generic
When canonicalizing, secret.generic ops are hoisted when they are fully plaintext. But this doesn't copy the attributes of the op onto the new hoisted generic's result attrs, and then those aren't copied to the old generic's new input. This results in a segfault when running in debug mode, due to a segfault in printing secret.generic arg attributes (that expect the arg attr dict to have arg attrs for each i in the arg list). PiperOrigin-RevId: 889323354
1 parent 61d6205 commit 927f386

File tree

4 files changed

+106
-5
lines changed

4 files changed

+106
-5
lines changed

lib/Dialect/Secret/IR/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ cc_library(
102102
":types_inc_gen",
103103
"@heir//lib/Dialect:HEIRInterfaces",
104104
"@heir//lib/Kernel",
105+
"@heir//lib/Utils:AttributeUtils",
105106
"@llvm-project//llvm:Support",
106107
"@llvm-project//mlir:ControlFlowInterfaces",
107108
"@llvm-project//mlir:IR",

lib/Dialect/Secret/IR/SecretOps.cpp

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "lib/Dialect/Secret/IR/SecretPatterns.h"
1111
#include "lib/Dialect/Secret/IR/SecretTypes.h"
12+
#include "lib/Utils/AttributeUtils.h"
1213
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
1314
#include "llvm/include/llvm/ADT/Sequence.h" // from @llvm-project
1415
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
@@ -401,13 +402,15 @@ GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove,
401402
"Cannot remove a value that is not yielded");
402403
}
403404

405+
SmallVector<Attribute> newResultAttrs;
404406
SmallVector<int, 4> indicesToErase;
405407
for (unsigned int i = 0; i < getYieldOp()->getNumOperands(); ++i) {
406408
if (std::find(yieldedValuesToRemove.begin(), yieldedValuesToRemove.end(),
407409
getYieldOp()->getOperand(i)) != yieldedValuesToRemove.end()) {
408410
indicesToErase.push_back(i);
409411
} else {
410412
remainingResults.push_back(getResult(i));
413+
newResultAttrs.push_back(getResultAttrDict(i));
411414
}
412415
}
413416

@@ -416,6 +419,11 @@ GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove,
416419
for (int i : llvm::reverse(indicesToErase)) {
417420
getYieldOp().getValuesMutable().erase(i);
418421
}
422+
// Update the result attr dictionary to remove the deleted results
423+
if (this->getAllResultAttrsAttr()) {
424+
this->setResultAttrsAttr(
425+
ArrayAttr::get(this->getContext(), newResultAttrs));
426+
}
419427

420428
auto newResultTypes = llvm::to_vector<4>(
421429
llvm::map_range(yieldOp.getValues().getTypes(), [](Type t) -> Type {
@@ -435,10 +443,13 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
435443
"Cannot remove an index that is out of range");
436444
}
437445

446+
SmallVector<Attribute> newResultAttrs;
447+
newResultAttrs.reserve(getNumResults() - yieldedIndicesToRemove.size());
438448
for (size_t i = 0; i < getYieldOp()->getNumOperands(); ++i) {
439449
if (std::find(yieldedIndicesToRemove.begin(), yieldedIndicesToRemove.end(),
440450
i) == yieldedIndicesToRemove.end()) {
441451
remainingResults.push_back(getResult(i));
452+
newResultAttrs.push_back(getResultAttrDict(i));
442453
}
443454
}
444455

@@ -447,6 +458,11 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
447458
for (int i : llvm::reverse(yieldedIndicesToRemove)) {
448459
getYieldOp().getValuesMutable().erase(i);
449460
}
461+
// Update the result attr dictionary to remove the deleted results
462+
if (this->getAllResultAttrsAttr()) {
463+
this->setResultAttrsAttr(
464+
ArrayAttr::get(this->getContext(), newResultAttrs));
465+
}
450466

451467
auto newResultTypes = llvm::to_vector<4>(
452468
llvm::map_range(yieldOp.getValues().getTypes(), [](Type t) -> Type {
@@ -526,25 +542,61 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation* opToExtract,
526542
auto* newOp = b.clone(*opToExtract, mp);
527543
YieldOp::create(b, loc, newOp->getResults());
528544
});
529-
LLVM_DEBUG({
530-
llvm::dbgs() << "After adding new single-op generic:\n";
531-
newGeneric->getParentOp()->dump();
532-
});
533-
545+
// Set the result attrs of the new generic to be the attrs of the opToExtract.
546+
for (auto attr : opToExtract->getAttrs()) {
547+
auto attrName = attr.getName();
548+
auto attrValue = attr.getValue();
549+
if (!attrName.getValue().contains(".") || !attrValue) continue;
550+
if (opToExtract->getNumResults() == 1) {
551+
setAttributeAssociatedWith(newGeneric.getResult(0), attrName, attrValue);
552+
} else if (auto arrayValueAttr = dyn_cast<ArrayAttr>(attrValue);
553+
arrayValueAttr.size() == opToExtract->getNumResults()) {
554+
// If the opToExtract has multiple result, each result attr has an array
555+
// attr elements corresponding to each result.
556+
for (auto [i, value] : llvm::enumerate(arrayValueAttr)) {
557+
setAttributeAssociatedWith(newGeneric.getResult(i), attrName, value);
558+
}
559+
}
560+
}
534561
// Once the op is split off into a new generic op, we need to add new
535562
// operands to the old generic op, add new corresponding block arguments, and
536563
// replace all uses of the opToExtract's results with the created block
564+
// arguments. We also need to copy arg attrs corresponding to the new block
537565
// arguments.
566+
auto oldGenericArgAttrs = this->getAllOperandAttrsAttr();
567+
SmallVector<Attribute> newGenericArgAttrs =
568+
oldGenericArgAttrs ? llvm::to_vector(oldGenericArgAttrs)
569+
: SmallVector<Attribute>(this->getNumOperands(),
570+
::mlir::DictionaryAttr::get(
571+
this->getContext(), {}));
572+
for (OpResult newOperand : newGeneric.getResults()) {
573+
auto attrDict = newGeneric.getResultAttrDict(newOperand.getResultNumber());
574+
newGenericArgAttrs.push_back(
575+
attrDict ? attrDict
576+
: ::mlir::DictionaryAttr::get(this->getContext(), {}));
577+
}
538578
SmallVector<Value> oldGenericNewBlockArgs;
539579
rewriter.modifyOpInPlace(*this, [&]() {
540580
getInputsMutable().append(newGeneric.getResults());
541581
for (auto ty : opToExtract->getResultTypes()) {
542582
BlockArgument arg = getBody()->addArgument(ty, opToExtract->getLoc());
543583
oldGenericNewBlockArgs.push_back(arg);
544584
}
585+
if (!llvm::all_of(newGenericArgAttrs, [](Attribute attr) {
586+
if (attr == nullptr) return true;
587+
auto dictAttr = dyn_cast<DictionaryAttr>(attr);
588+
return !dictAttr || dictAttr.empty();
589+
})) {
590+
this->setOperandAttrsAttr(
591+
ArrayAttr::get(this->getContext(), newGenericArgAttrs));
592+
}
545593
});
546594
rewriter.replaceOp(opToExtract, oldGenericNewBlockArgs);
547595

596+
LLVM_DEBUG({
597+
llvm::dbgs() << "After adding new single-op generic:\n";
598+
newGeneric->getParentOp()->dump();
599+
});
548600
return newGeneric;
549601
}
550602

lib/Dialect/Secret/IR/SecretPatterns.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ LogicalResult RemoveUnusedGenericArgs::matchAndRewrite(
178178
body->eraseArgument(i);
179179
op.getOperation()->eraseOperand(i);
180180
});
181+
auto attrs = op.getAllOperandAttrsAttr();
182+
if (attrs) {
183+
SmallVector<Attribute> attrList;
184+
for (auto [j, attr] : llvm::enumerate(attrs)) {
185+
if (j != i) {
186+
attrList.push_back(attr);
187+
}
188+
}
189+
op.setOperandAttrsAttr(ArrayAttr::get(op.getContext(), attrList));
190+
}
191+
181192
// Ensure the next iteration uses the right arg number
182193
--i;
183194
} else if (llvm::any_of(arg.getUsers(), [&](Operation* user) {
@@ -254,6 +265,16 @@ LogicalResult RemoveNonSecretGenericArgs::matchAndRewrite(
254265
body->eraseArgument(i);
255266
op.getOperation()->eraseOperand(i);
256267
});
268+
auto attrs = op.getAllOperandAttrsAttr();
269+
if (attrs) {
270+
SmallVector<Attribute> attrList;
271+
for (auto [j, attr] : llvm::enumerate(attrs)) {
272+
if (j != i) {
273+
attrList.push_back(attr);
274+
}
275+
}
276+
op.setOperandAttrsAttr(ArrayAttr::get(op.getContext(), attrList));
277+
}
257278
i--;
258279
}
259280
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: heir-opt --canonicalize %s | FileCheck %s
2+
3+
// This is a regression test specifically testing that the result attrs of the assign_layout that is hoisted out is preserved.
4+
5+
#kernel = #secret.kernel<name = "MatvecDiagonal", force = false>
6+
#layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = 0 and (-25i1 - 5i2 - i3 + slot) mod 128 = 0 and 0 <= i1 <= 3 and 0 <= i2 <= 4 and 0 <= i3 <= 1023 - 25i1 - 5i2 and i3 <= 4 and 0 <= slot <= 1023 and 1024*floor((-128 + 25i1 + 5i2 + i3)/1024) <= -1024 + 25i1 + 5i2 + i3 }">
7+
#layout1 = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-10i2 - i3 + slot) mod 128 = 0 and 0 <= i2 <= 9 and 0 <= i3 <= 9 and 0 <= slot <= 1023 }">
8+
#layout2 = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : exists (e0, e1, e2, e3, e4, e6: i1 = 0 and 128e6 = -10i2 - i3 + ct + slot - 20e0 - 2e1 and 0 <= i0 <= 3 and 0 <= i2 <= 1 and 0 <= i3 <= 1 and 0 <= ct <= 15 and 0 <= slot <= 1023 and 0 <= e0 <= 4 and 0 <= e1 <= 35 - 5i2 - 10e0 and e1 <= 4 and 0 <= e2 <= 1 and 16*floor((slot)/16) >= slot - 2e3 and slot - 4e2 - 2e3 + 2e4 <= 16*floor((slot)/16) <= 1 + slot - 4e2 - 2e3 + 2e4 and 16*floor((slot)/16) <= 3 + slot - 2e3 and -1 + 25i0 + 4slot + 5e0 + e1 - 2e2 - 8e3 <= 64*floor((slot)/16) <= 25i0 + 4slot + 5e0 + e1 - 2e2 - 8e3 and 50i0 + 7slot + 10e0 + 2e1 - 4e2 - 16e3 + 4e4 <= 112*floor((slot)/16) <= 1 + 50i0 + 7slot + 10e0 + 2e1 - 4e2 - 16e3 + 4e4) }">
9+
#layout3 = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : exists (e1, e2, e3: i0 = 0 and ct = 0 and 0 <= i1 <= 3 and 0 <= i2 <= 4 and 0 <= i3 <= 4 and 0 <= slot <= 99 and 128*floor((-1 - 25i1 - 5i2 - i3)/128) <= -29 - 25i1 - 5i2 - i3 and 0 <= e1 <= 4 and slot <= 2e2 <= 3 + slot and 124 + 25i1 + 5i2 + i3 + 25slot + 128*floor((-1 - 25i1 - 5i2 - i3)/128) - 5e1 <= 50e2 <= 128 + 25i1 + 5i2 + i3 + 25slot + 128*floor((-1 - 25i1 - 5i2 - i3)/128) - 5e1 and -1 - slot + 4e1 + 2e2 <= 2e3 <= -slot + 4e1 + 2e2 and -257 - 50i1 - 10i2 - 2i3 - 49slot - 256*floor((-1 - 25i1 - 5i2 - i3)/128) + 10e1 + 100e2 <= 10e3 <= -256 - 50i1 - 10i2 - 2i3 - 49slot - 256*floor((-1 - 25i1 - 5i2 - i3)/128) + 10e1 + 100e2) }">
10+
module {
11+
// CHECK: @conv2d_nchw
12+
// CHECK: tensor_ext.assign_layout
13+
// CHECK-SAME: tensor_ext.layout
14+
// CHECK: secret.generic
15+
// CHECK: return
16+
func.func @conv2d_nchw(%arg0: !secret.secret<tensor<4x1x2x2xf32>> {tensor_ext.layout = #layout2}) -> (!secret.secret<tensor<4x1x2x2xf32>> {tensor_ext.layout = #layout2}) {
17+
%cst = arith.constant dense<0.000000e+00> : tensor<1x4x5x5xf32>
18+
%cst_0 = arith.constant dense<2.500000e-01> : tensor<4x1x2x2xf32>
19+
%0 = secret.generic(%arg0: !secret.secret<tensor<4x1x2x2xf32>> {tensor_ext.layout = #layout2}) {
20+
^body(%input0: tensor<4x1x2x2xf32>):
21+
%1 = tensor_ext.assign_layout %cst_0 {layout = #layout2, tensor_ext.layout = #layout2} : tensor<4x1x2x2xf32>
22+
%2 = arith.addf %input0, %1 {tensor_ext.layout = #layout2} : tensor<4x1x2x2xf32>
23+
secret.yield %2 : tensor<4x1x2x2xf32>
24+
} -> (!secret.secret<tensor<4x1x2x2xf32>> {tensor_ext.layout = #layout2})
25+
return %0 : !secret.secret<tensor<4x1x2x2xf32>>
26+
}
27+
}

0 commit comments

Comments
 (0)