@@ -603,11 +603,16 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
603
603
return l;
604
604
}
605
605
606
- void LoopEmitter::genDenseAffineAddress (OpBuilder &builder, Location loc,
607
- TensorLevel tidLvl,
608
- AffineExpr lvlExpr) {
606
+ void LoopEmitter::locateLvlAtAffineAddress (OpBuilder &builder, Location loc,
607
+ TensorLevel tidLvl,
608
+ AffineExpr lvlExpr) {
609
609
auto [tid, lvl] = unpackTensorLevel (tidLvl);
610
+
611
+ const SparseIterator *parent =
612
+ lvl == 0 ? nullptr : iters[tid][lvl - 1 ].back ().get ();
610
613
auto &it = getCurIterator (tid, lvl);
614
+ it.genInit (builder, loc, parent);
615
+
611
616
assert (it.kind == IterKind::kTrivial && it.randomAccessible ());
612
617
Value lvlCrd = genAffine (builder, loc, lvlExpr);
613
618
it.locate (builder, loc, lvlCrd);
@@ -710,9 +715,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
710
715
// However, that would result in a rather elaborate forest of yield
711
716
// instructions during code generation. Moreover, performing the induction
712
717
// after the if-statements more closely resembles code generated by TACO.
713
- unsigned o = 0 ;
714
718
SmallVector<Value> operands;
715
- unsigned delta = 0 ;
716
719
ValueRange whileRes = whileOp.getResults ();
717
720
718
721
for (auto [tid, lvl] : unpackTensorLevelRange (loopInfo.tidLvls )) {
@@ -722,7 +725,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
722
725
Value cmp = CMPI (eq, it.getCrd (), iv);
723
726
it.forwardIf (builder, loc, cmp);
724
727
operands.append (it.getItVals ().begin (), it.getItVals ().end ());
725
- o += it.getItVals ().size ();
726
728
// const Value newPos = whileOp->getResult(o++);
727
729
// Following loops continue iteration from the break point of the
728
730
// current while loop.
@@ -738,20 +740,20 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
738
740
// Reduction value from users.
739
741
for (auto &i : reduc) {
740
742
operands.push_back (i);
741
- // In place update reduction variable.
742
- i = whileOp->getResult (o++);
743
+ // Update user reduction variables.
744
+ i = whileRes.front ();
745
+ whileRes = whileRes.drop_front ();
743
746
}
744
747
745
748
// An (optional) universal index.
746
- if (operands.size () + delta < whileOp.getNumResults ()) {
747
- assert (operands.size () + delta + 1 == whileOp.getNumResults ());
749
+ if (operands.size () < whileOp.getNumResults ()) {
750
+ assert (operands.size () + 1 == whileOp.getNumResults ());
748
751
// The last one is the universal index.
749
752
operands.push_back (ADDI (iv, one));
750
753
// update the loop starting point of current loop sequence
751
- loopSeqStack.back ().first = whileOp->getResult (o++ );
754
+ loopSeqStack.back ().first = whileOp->getResults (). back ( );
752
755
}
753
756
754
- assert (o == operands.size () + delta);
755
757
if (!operands.empty ())
756
758
YIELD (operands);
757
759
0 commit comments