[mlir][sparse] setup SparseIterator to help generating code to traverse a sparse tensor level. #78345

Merged · 16 commits · Jan 24, 2024
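The patch threads a new SparseIterator abstraction through the sparse loop emitter, so that client code asks the emitter for positions and coordinates (getValPosits, getCoord) instead of digging into its internal per-tensor arrays, as the diffs below show. As a purely illustrative sketch (all names hypothetical, not the actual mlir::sparse_tensor::SparseIterator interface), a level iterator bundles the state needed to walk one storage level:

    // Hypothetical sketch of an iterator over one sparse tensor level;
    // not the actual mlir::sparse_tensor::SparseIterator interface.
    #include <cstdint>

    struct LevelIterator {
      uint64_t pos, posHi;    // half-open position range [pos, posHi)
      const uint64_t *crdBuf; // coordinate buffer of this level

      bool done() const { return pos >= posHi; }
      // Coordinate at the current position (cf. getCoord in the diff).
      uint64_t coord() const { return crdBuf[pos]; }
      // Position indexing the values buffer at the last level
      // (cf. getValPosits in the diff).
      uint64_t valPosit() const { return pos; }
      void next() { ++pos; }
    };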
@@ -1126,7 +1126,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
}

Value vals = loopEmitter.getValBuffer()[0];
- Value pos = loopEmitter.getPosits()[0].back();
+ Value pos = loopEmitter.getValPosits(0);
// Load the value from a sparse tensor using the position index;
// load the value from a dense tensor using its coordinates.
Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
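The two load forms above differ only in how the values buffer is indexed: sparse storage keeps values in a flat memref indexed by the last level's position, while a dense operand is loaded at its full coordinate list. A hedged sketch of that distinction (the helper name loadTensorValue is hypothetical):

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/IR/Builders.h"
    using namespace mlir;

    // Hypothetical helper: a sparse operand is loaded through a single
    // flat position, a dense operand through all level coordinates.
    static Value loadTensorValue(OpBuilder &b, Location loc, bool isSparse,
                                 Value vals, Value pos, ValueRange coords) {
      if (isSparse)
        return b.create<memref::LoadOp>(loc, vals, pos);  // flat position
      return b.create<memref::LoadOp>(loc, vals, coords); // dense coords
    }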
@@ -1148,17 +1148,17 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
rewriter.eraseOp(srcBlock->getTerminator());

- // Inline body.
- if (!reducValue.empty()) {
-   rewriter.mergeBlocks(srcBlock, rewriter.getBlock(), args);
- } else {
-   // This is annoying, since scf.for inserts an implicit yield op when
-   // there is no reduction variable upon creation; in this case we need
-   // to merge the block *before* the yield op.
-   rewriter.inlineBlockBefore(srcBlock, &*rewriter.getInsertionPoint(),
-                              args);
- }
+ Operation &last = rewriter.getBlock()->back();
+ if (llvm::isa<scf::YieldOp>(last)) {
+   // Because `scf.for` inserts an implicit yield op when there is no
+   // reduction variable upon creation, we reset the insertion point such
+   // that the block is inlined *before* the yield op.
+   rewriter.setInsertionPoint(&last);
+ }
+
+ rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
+                            rewriter.getInsertionPoint(), args);
+ rewriter.setInsertionPointToEnd(rewriter.getBlock());
for (Level l = 0; l < lvlRank; l++) {
// Link the reduction chain. Note that the loop emitter updates the
// reducValue in place.
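The rewritten inlining logic above relies on one scf.for detail: when the loop carries no iter_args, the builder materializes the body with an implicit scf.yield terminator, so the source block must be inlined in front of that terminator rather than appended at the end. A minimal sketch of the pattern, under the same API the patch uses:

    #include "mlir/Dialect/SCF/IR/SCF.h"
    #include "mlir/IR/PatternMatch.h"
    using namespace mlir;

    // Sketch: inline `srcBlock` into a freshly built scf.for body. With
    // no iter_args the body already ends in an implicit scf.yield, so
    // the insertion point is moved in front of the terminator first.
    static void inlineIntoForBody(RewriterBase &rewriter, scf::ForOp forOp,
                                  Block *srcBlock, ValueRange args) {
      rewriter.setInsertionPointToStart(forOp.getBody());
      Operation &last = forOp.getBody()->back();
      if (llvm::isa<scf::YieldOp>(last))
        rewriter.setInsertionPoint(&last); // stay before the yield
      rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
                                 rewriter.getInsertionPoint(), args);
    }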
28 changes: 18 additions & 10 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -354,7 +354,7 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
const auto stt = getSparseTensorType(t->get());
if (stt.hasEncoding()) {
// For sparse tensors we only push the last-level's position onto `args`.
- const auto pos = env.emitter().getPosits()[tid].back();
+ const auto pos = env.emitter().getValPosits(tid);
assert(pos);
args.push_back(pos);
} else {
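Judging from its call sites in this diff, the new getValPosits(tid) accessor simply packages what callers previously spelled as getPosits()[tid].back(): the last level's position for tensor tid, which is what indexes the value memref. A purely hypothetical sketch of such an accessor (the real LoopEmitter implementation is not part of this diff and is presumably backed by the new SparseIterator machinery):

    // Hypothetical sketch only, inferred from the call sites above.
    Value LoopEmitter::getValPosits(TensorId tid) const {
      // The position at the innermost (last) level indexes the flat
      // values buffer of tensor `tid`.
      return getPosits()[tid].back();
    }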
@@ -815,8 +815,7 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
// Construct while-loop with a parameter for each index.
return env.emitter().enterCoIterationOverTensorsAtLvls(
- builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
- /*genDedup=*/true, needsUniv);
+ builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
});
assert(loop);
return loop;
@@ -894,7 +893,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
if (isCompressedLT(lt) || isSingletonLT(lt) ||
isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
assert(lvl.has_value());
- const Value crd = env.emitter().getCoords()[tid][*lvl];
+ const Value crd = env.emitter().getCoord(tid, *lvl);
const Value lvar = env.getLoopVar(curr);
clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
crd, lvar);
@@ -1032,10 +1031,14 @@ static bool getAllTidLvlsInLatPoints(
});

if (isDenseLT(env.lt(outTid, curr))) {
- // Note that we generate dense indices of the output tensor
- // unconditionally, since they may not appear in the lattice, but may be
- // needed for linearized env.
- callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
+ auto stt = getSparseTensorType(env.op().getOutputs().front());
+ // Note that we generate dense indices of the output tensor
+ // unconditionally, since they may not appear in the lattice, but may be
+ // needed for linearized env.
+ // TODO: we should avoid introducing corner cases for all-dense sparse
+ // tensors.
+ if (stt.hasEncoding() && stt.isAllDense())
+   callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
}

if (numloopCond == 0) {
@@ -1064,6 +1067,11 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,

SmallVector<TensorLevel> tidLvls;
getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+ // TODO: remove this! The same tensor level might be added multiple
+ // times due to the special handling for all-dense "sparse" output
+ // tensors (see L1038).
+ if (llvm::find(tidLvls, tl) != tidLvls.end())
+   return;
tidLvls.emplace_back(tl);
});
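The linear llvm::find probe above is cheap for the handful of levels a loop sequence touches; an ordered-set container such as llvm::SetVector would be a more self-documenting alternative (a design option, not what the patch does):

    #include "llvm/ADT/SetVector.h"

    // Alternative sketch: SetVector preserves insertion order while
    // rejecting duplicates, subsuming the explicit llvm::find probe.
    llvm::SetVector<TensorLevel> tidLvlSet;
    getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
      tidLvlSet.insert(tl); // no-op if `tl` was already recorded
    });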

@@ -1096,7 +1104,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
for (Level l = startLvl; l < lvlRank; l++) {
AffineExpr lvlExpr = lvlExprs[l];
if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
- env.emitter().genDenseAffineAddress(
+ env.emitter().locateLvlAtAffineAddress(
builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
else
return; // break on first non-dense non-constant level
@@ -1145,7 +1153,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
Location loc = env.op().getLoc();
for (auto [tidLvl, exp] : affineTidLvls) {
- env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);
+ env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
}

// Until now, we have entered every <tid, lvl> pair in {cond, extra,