99
1010#include " lib/Dialect/Secret/IR/SecretPatterns.h"
1111#include " lib/Dialect/Secret/IR/SecretTypes.h"
12+ #include " lib/Utils/AttributeUtils.h"
1213#include " llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
1314#include " llvm/include/llvm/ADT/Sequence.h" // from @llvm-project
1415#include " llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
@@ -401,13 +402,15 @@ GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove,
401402 " Cannot remove a value that is not yielded" );
402403 }
403404
405+ SmallVector<Attribute> newResultAttrs;
404406 SmallVector<int , 4 > indicesToErase;
405407 for (unsigned int i = 0 ; i < getYieldOp ()->getNumOperands (); ++i) {
406408 if (std::find (yieldedValuesToRemove.begin (), yieldedValuesToRemove.end (),
407409 getYieldOp ()->getOperand (i)) != yieldedValuesToRemove.end ()) {
408410 indicesToErase.push_back (i);
409411 } else {
410412 remainingResults.push_back (getResult (i));
413+ newResultAttrs.push_back (getResultAttrDict (i));
411414 }
412415 }
413416
@@ -416,6 +419,11 @@ GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove,
416419 for (int i : llvm::reverse (indicesToErase)) {
417420 getYieldOp ().getValuesMutable ().erase (i);
418421 }
422+ // Update the result attr dictionary to remove the deleted results
423+ if (this ->getAllResultAttrsAttr ()) {
424+ this ->setResultAttrsAttr (
425+ ArrayAttr::get (this ->getContext (), newResultAttrs));
426+ }
419427
420428 auto newResultTypes = llvm::to_vector<4 >(
421429 llvm::map_range (yieldOp.getValues ().getTypes (), [](Type t) -> Type {
@@ -435,10 +443,13 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
435443 " Cannot remove an index that is out of range" );
436444 }
437445
446+ SmallVector<Attribute> newResultAttrs;
447+ newResultAttrs.reserve (getNumResults () - yieldedIndicesToRemove.size ());
438448 for (size_t i = 0 ; i < getYieldOp ()->getNumOperands (); ++i) {
439449 if (std::find (yieldedIndicesToRemove.begin (), yieldedIndicesToRemove.end (),
440450 i) == yieldedIndicesToRemove.end ()) {
441451 remainingResults.push_back (getResult (i));
452+ newResultAttrs.push_back (getResultAttrDict (i));
442453 }
443454 }
444455
@@ -447,6 +458,11 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
447458 for (int i : llvm::reverse (yieldedIndicesToRemove)) {
448459 getYieldOp ().getValuesMutable ().erase (i);
449460 }
461+ // Update the result attr dictionary to remove the deleted results
462+ if (this ->getAllResultAttrsAttr ()) {
463+ this ->setResultAttrsAttr (
464+ ArrayAttr::get (this ->getContext (), newResultAttrs));
465+ }
450466
451467 auto newResultTypes = llvm::to_vector<4 >(
452468 llvm::map_range (yieldOp.getValues ().getTypes (), [](Type t) -> Type {
@@ -526,25 +542,61 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation* opToExtract,
526542 auto * newOp = b.clone (*opToExtract, mp);
527543 YieldOp::create (b, loc, newOp->getResults ());
528544 });
529- LLVM_DEBUG ({
530- llvm::dbgs () << " After adding new single-op generic:\n " ;
531- newGeneric->getParentOp ()->dump ();
532- });
533-
545+ // Set the result attrs of the new generic to be the attrs of the opToExtract.
546+ for (auto attr : opToExtract->getAttrs ()) {
547+ auto attrName = attr.getName ();
548+ auto attrValue = attr.getValue ();
549+ if (!attrName.getValue ().contains (" ." ) || !attrValue) continue ;
550+ if (opToExtract->getNumResults () == 1 ) {
551+ setAttributeAssociatedWith (newGeneric.getResult (0 ), attrName, attrValue);
552+ } else if (auto arrayValueAttr = dyn_cast<ArrayAttr>(attrValue);
553+ arrayValueAttr.size () == opToExtract->getNumResults ()) {
554+ // If the opToExtract has multiple result, each result attr has an array
555+ // attr elements corresponding to each result.
556+ for (auto [i, value] : llvm::enumerate (arrayValueAttr)) {
557+ setAttributeAssociatedWith (newGeneric.getResult (i), attrName, value);
558+ }
559+ }
560+ }
534561 // Once the op is split off into a new generic op, we need to add new
535562 // operands to the old generic op, add new corresponding block arguments, and
536563 // replace all uses of the opToExtract's results with the created block
564+ // arguments. We also need to copy arg attrs corresponding to the new block
537565 // arguments.
566+ auto oldGenericArgAttrs = this ->getAllOperandAttrsAttr ();
567+ SmallVector<Attribute> newGenericArgAttrs =
568+ oldGenericArgAttrs ? llvm::to_vector (oldGenericArgAttrs)
569+ : SmallVector<Attribute>(this ->getNumOperands (),
570+ ::mlir::DictionaryAttr::get (
571+ this ->getContext (), {}));
572+ for (OpResult newOperand : newGeneric.getResults ()) {
573+ auto attrDict = newGeneric.getResultAttrDict (newOperand.getResultNumber ());
574+ newGenericArgAttrs.push_back (
575+ attrDict ? attrDict
576+ : ::mlir::DictionaryAttr::get (this ->getContext (), {}));
577+ }
538578 SmallVector<Value> oldGenericNewBlockArgs;
539579 rewriter.modifyOpInPlace (*this , [&]() {
540580 getInputsMutable ().append (newGeneric.getResults ());
541581 for (auto ty : opToExtract->getResultTypes ()) {
542582 BlockArgument arg = getBody ()->addArgument (ty, opToExtract->getLoc ());
543583 oldGenericNewBlockArgs.push_back (arg);
544584 }
585+ if (!llvm::all_of (newGenericArgAttrs, [](Attribute attr) {
586+ if (attr == nullptr ) return true ;
587+ auto dictAttr = dyn_cast<DictionaryAttr>(attr);
588+ return !dictAttr || dictAttr.empty ();
589+ })) {
590+ this ->setOperandAttrsAttr (
591+ ArrayAttr::get (this ->getContext (), newGenericArgAttrs));
592+ }
545593 });
546594 rewriter.replaceOp (opToExtract, oldGenericNewBlockArgs);
547595
596+ LLVM_DEBUG ({
597+ llvm::dbgs () << " After adding new single-op generic:\n " ;
598+ newGeneric->getParentOp ()->dump ();
599+ });
548600 return newGeneric;
549601}
550602
0 commit comments