Skip to content

Commit 73df711

Browse files
author
Peiming Liu
committed
address comments.
1 parent f2f55d4 commit 73df711

File tree

5 files changed

+17
-17
lines changed

5 files changed

+17
-17
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,9 +1150,9 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
11501150

11511151
Operation &last = rewriter.getBlock()->back();
11521152
if (llvm::isa<scf::YieldOp>(last)) {
1153-
// scf.for inserts a implicit yield op when there is no reduction
1154-
// variable upon creation, in this case we need to merge the block
1155-
// *before* the yield op.
1153+
// Because `scf.for` inserts an implicit yield op when there is no
1154+
// reduction variable upon creation, we reset the insertion point such
1155+
// that the block is inlined before *before* the yield op.
11561156
rewriter.setInsertionPoint(&last);
11571157
}
11581158

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,9 +1032,9 @@ static bool getAllTidLvlsInLatPoints(
10321032

10331033
if (isDenseLT(env.lt(outTid, curr))) {
10341034
auto stt = getSparseTensorType(env.op().getOutputs().front());
1035-
// Note that we generate dense indices of the output tensor
1036-
// unconditionally, since they may not appear in the lattice, but may be
1037-
// needed for linearized env.
1035+
// Note that we generate dense indices of the output tensor unconditionally,
1036+
// since they may not appear in the lattice, but may be needed for
1037+
// linearized env.
10381038
// TODO: we should avoid introducing corner cases for all-dense sparse
10391039
// tensors.
10401040
if (stt.hasEncoding() && stt.isAllDense())
@@ -1067,8 +1067,9 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
10671067

10681068
SmallVector<TensorLevel> tidLvls;
10691069
getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
1070-
// TODO: remove this! Duplication can be introduced due to the speical
1071-
// handling for all-dense "sparse" output tensor.
1070+
// TODO: remove this! The same tensor level might be added for multiple
1071+
// times due to the special handling for all-dense "sparse" output tensor
1072+
// (see L1038).
10721073
if (llvm::find(tidLvls, tl) != tidLvls.end())
10731074
return;
10741075
tidLvls.emplace_back(tl);

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ class LoopEmitter {
408408
/// alive.
409409
std::vector<LoopInfo> loopStack;
410410

411-
// Loop Sequence Stack, stores the unversial index for the current loop
411+
// Loop Sequence Stack, stores the universal index for the current loop
412412
// sequence. and a list of tid level that the loop sequence traverse.
413413
std::vector<std::pair<Value, std::vector<TensorLevel>>> loopSeqStack;
414414
};

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ class SubSectIterator : public SparseIterator {
778778
} // namespace
779779

780780
//===----------------------------------------------------------------------===//
781-
// Complex SparseIterator derived classes impl.
781+
// SparseIterator derived classes implementation.
782782
//===----------------------------------------------------------------------===//
783783

784784
ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
@@ -819,7 +819,6 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
819819
},
820820
/*afterBuilder=*/
821821
[](OpBuilder &b, Location l, ValueRange ivs) {
822-
// pos ++
823822
Value nxPos = ADDI(ivs[0], C_IDX(1));
824823
YIELD(nxPos);
825824
});
@@ -830,11 +829,11 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
830829
Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
831830
Value wrapCrd) {
832831
Value crd = fromWrapCrd(b, l, wrapCrd);
833-
// not on stride
832+
// Test whether the coordinate is on stride.
834833
Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
835-
// wrapCrd < offset
834+
// Test wrapCrd < offset
836835
notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
837-
// crd >= length
836+
// Test crd >= length
838837
notlegit = ORI(CMPI(uge, crd, size), notlegit);
839838
return notlegit;
840839
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ class SparseTensorLevel {
2929
/// the given position `p` that the immediate parent level is current at.
3030
/// Returns a pair of values for *posLo* and *loopHi* respectively.
3131
///
32-
/// For dense level, the *posLo* is the linearized position at beginning,
32+
/// For a dense level, the *posLo* is the linearized position at beginning,
3333
/// while *loopHi* is the largest *coordinate*, it also implies that the
3434
/// smallest *coordinate* to start the loop is 0.
3535
///
36-
/// For sparse level, [posLo, loopHi) specifies the range of index pointer to
37-
/// load coordinate from the coordinate buffer.
36+
/// For a sparse level, [posLo, loopHi) specifies the range of index pointer
37+
/// to load coordinate from the coordinate buffer.
3838
///
3939
/// `bound` is only used when the level is `non-unique` and deduplication is
4040
/// required. It specifies the max upper bound of the non-unique segment.

0 commit comments

Comments
 (0)