Commit ac21c45

Author: Peiming Liu
Commit message: fix bugs
1 parent 8634c59 commit ac21c45

File tree: 3 files changed, +18 -16 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 2 additions & 2 deletions
@@ -1103,7 +1103,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
   for (Level l = startLvl; l < lvlRank; l++) {
     AffineExpr lvlExpr = lvlExprs[l];
     if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
-      env.emitter().genDenseAffineAddress(
+      env.emitter().locateLvlAtAffineAddress(
           builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
     else
       return; // break on first non-dense non-constant level
@@ -1152,7 +1152,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
   Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
   Location loc = env.op().getLoc();
   for (auto [tidLvl, exp] : affineTidLvls) {
-    env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);
+    env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
   }

   // Until now, we have entered every <tid, lvl> pair in {cond, extra,

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 14 additions & 12 deletions
@@ -603,11 +603,16 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   return l;
 }

-void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
-                                        TensorLevel tidLvl,
-                                        AffineExpr lvlExpr) {
+void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
+                                           TensorLevel tidLvl,
+                                           AffineExpr lvlExpr) {
   auto [tid, lvl] = unpackTensorLevel(tidLvl);
+
+  const SparseIterator *parent =
+      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
   auto &it = getCurIterator(tid, lvl);
+  it.genInit(builder, loc, parent);
+
   assert(it.kind == IterKind::kTrivial && it.randomAccessible());
   Value lvlCrd = genAffine(builder, loc, lvlExpr);
   it.locate(builder, loc, lvlCrd);
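
Note on the hunk above: the renamed routine now initializes the level's iterator against its parent level (it.genInit) before it.locate is called, which appears to be the bug this commit fixes. A minimal standalone C++ analogue of that init-before-locate ordering (the LevelIterator type below is hypothetical and is not the MLIR SparseIterator API):

// Hypothetical analogue only; the real code drives SparseIterator via an OpBuilder.
#include <cassert>
#include <cstdio>
#include <optional>

struct LevelIterator {
  const LevelIterator *parent = nullptr; // the level above this one, if any
  bool initialized = false;
  std::optional<int> crd;                // coordinate the iterator is located at

  // Analogue of SparseIterator::genInit: bind this level to its parent level.
  void init(const LevelIterator *p) {
    parent = p;
    initialized = true;
  }

  // Analogue of SparseIterator::locate: only meaningful after init().
  void locate(int coordinate) {
    assert(initialized && "locate() called before init()");
    crd = coordinate;
  }
};

int main() {
  LevelIterator lvl0, lvl1;
  lvl0.init(/*parent=*/nullptr); // level 0 has no parent
  lvl1.init(&lvl0);              // level 1 is anchored at level 0
  lvl1.locate(3);                // locating at an affine address is now safe
  std::printf("located at coordinate %d\n", *lvl1.crd);
  return 0;
}
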
@@ -710,9 +715,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   // However, that would result in a rather elaborate forest of yield
   // instructions during code generation. Moreover, performing the induction
   // after the if-statements more closely resembles code generated by TACO.
-  unsigned o = 0;
   SmallVector<Value> operands;
-  unsigned delta = 0;
   ValueRange whileRes = whileOp.getResults();

   for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
@@ -722,7 +725,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       Value cmp = CMPI(eq, it.getCrd(), iv);
       it.forwardIf(builder, loc, cmp);
       operands.append(it.getItVals().begin(), it.getItVals().end());
-      o += it.getItVals().size();
       // const Value newPos = whileOp->getResult(o++);
       // Following loops continue iteration from the break point of the
       // current while loop.
@@ -738,20 +740,20 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   // Reduction value from users.
   for (auto &i : reduc) {
     operands.push_back(i);
-    // In place update reduction variable.
-    i = whileOp->getResult(o++);
+    // Update user reduction variables.
+    i = whileRes.front();
+    whileRes = whileRes.drop_front();
   }

   // An (optional) universal index.
-  if (operands.size() + delta < whileOp.getNumResults()) {
-    assert(operands.size() + delta + 1 == whileOp.getNumResults());
+  if (operands.size() < whileOp.getNumResults()) {
+    assert(operands.size() + 1 == whileOp.getNumResults());
     // The last one is the universial index.
     operands.push_back(ADDI(iv, one));
     // update the loop starting point of current loop sequence
-    loopSeqStack.back().first = whileOp->getResult(o++);
+    loopSeqStack.back().first = whileOp->getResults().back();
   }

-  assert(o == operands.size() + delta);
   if (!operands.empty())
     YIELD(operands);
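
In exitWhileLoop, the change drops the hand-maintained result indices `o` and `delta`: user reductions are now updated by peeling results off the front of the `whileRes` range, and the optional universal index is taken from the back of the results, so there is no running offset to keep in sync (which is what the removed assert(o == operands.size() + delta) used to check). A minimal standalone C++20 sketch of that consumption pattern, using std::span as a stand-in for mlir::ValueRange (the names and values below are illustrative only, not from the commit):

#include <cassert>
#include <cstdio>
#include <span>
#include <vector>

int main() {
  // Stand-in for whileOp.getResults(): reduction results followed by an
  // optional universal index as the last result.
  std::vector<int> whileResults = {10, 20, 30, 99};
  std::vector<int> reduc = {0, 0, 0}; // user reduction variables

  std::span<const int> whileRes(whileResults); // plays the role of ValueRange

  // Update user reduction variables from the front of the result range.
  for (int &r : reduc) {
    r = whileRes.front();
    whileRes = whileRes.subspan(1); // analogue of ValueRange::drop_front()
  }

  // The (optional) universal index, if present, is the last result.
  if (reduc.size() < whileResults.size()) {
    assert(reduc.size() + 1 == whileResults.size());
    std::printf("universal index = %d\n", whileResults.back());
  }

  for (int r : reduc)
    std::printf("reduction value = %d\n", r);
  return 0;
}
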

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 2 additions & 2 deletions
@@ -126,8 +126,8 @@ class LoopEmitter {

   /// Emits the address for a dense level based on the value evaluated by the
   /// provided affine expression.
-  void genDenseAffineAddress(OpBuilder &builder, Location loc,
-                             TensorLevel tidLvl, AffineExpr lvlExpr);
+  void locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
+                                TensorLevel tidLvl, AffineExpr lvlExpr);

   // TODO: Get rid of `lvls` in the argument list? Track the level we
   // are currently at internally. Then it would be enterNextLvlForTensor.
