Skip to content

Commit e453449

Browse files
author
Peiming Liu
committed
pass all integration tests.
1 parent 75673b0 commit e453449

File tree

2 files changed

+75
-25
lines changed

2 files changed

+75
-25
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ class DedupIterator : public SparseIterator {
394394
const SparseTensorLevel &stl;
395395
};
396396

397+
//
398+
// A filter iterator wrapped from another iterator. The filter iterator update
399+
// the wrapped iterator *in-place*.
400+
//
397401
class FilterIterator : public SparseIterator {
398402
// Coorindate translation between crd loaded from the wrap iterator and the
399403
// filter iterator.
@@ -411,6 +415,8 @@ class FilterIterator : public SparseIterator {
411415
Value genShouldFilter(OpBuilder &b, Location l);
412416

413417
public:
418+
// TODO: avoid unnessary check when offset == 0 and/or when stride == 1 and/or
419+
// when crd always < size.
414420
FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
415421
Value stride, Value size)
416422
: SparseIterator(IterKind::kFilter, *wrap), offset(offset),
@@ -548,9 +554,10 @@ class NonEmptySubSectIterator : public SparseIterator {
548554
return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
549555
}
550556

551-
ValueRange genSubSectTraverseTillRoot(OpBuilder &b, Location l,
552-
ValueRange reduc,
553-
TraverseBuilder builder) const;
557+
// Generate code that inflate the current subsection tree till the current
558+
// level such that every leaf node is visited.
559+
ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
560+
TraverseBuilder builder) const;
554561

555562
bool randomAccessible() const override {
556563
return delegate->randomAccessible();
@@ -861,24 +868,35 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
861868
assert(!randomAccessible());
862869
// Generates
863870
//
864-
// wrap ++;
865-
// while !it.end() && !legit(*it)
871+
// bool isFirst = true;
872+
// while !it.end() && (!legit(*it) || isFirst)
866873
// wrap ++;
867-
wrap->forward(b, l);
874+
// isFirst = false;
875+
//
876+
// We do not hoist the first `wrap++` outside the loop but use a `isFirst`
877+
// flag here because `wrap++` might have a complex implementation (e.g., to
878+
// forward a subsection).
879+
Value isFirst = constantI1(b, l, true);
880+
881+
SmallVector<Value> whileArgs(getItVals().begin(), getItVals().end());
882+
whileArgs.push_back(isFirst);
883+
868884
auto whileOp = b.create<scf::WhileOp>(
869-
l, getItVals().getTypes(), getItVals(),
885+
l, ValueRange(whileArgs).getTypes(), whileArgs,
870886
/*beforeBuilder=*/
871887
[this](OpBuilder &b, Location l, ValueRange ivs) {
872-
linkNewScope(ivs);
888+
ValueRange isFirst = linkNewScope(ivs);
889+
assert(isFirst.size() == 1);
873890
ValueRange cont =
874891
genWhenInBound(b, l, *wrap, C_FALSE,
875-
[this](OpBuilder &b, Location l,
876-
Value wrapCrd) -> scf::ValueVector {
892+
[this, isFirst](OpBuilder &b, Location l,
893+
Value wrapCrd) -> scf::ValueVector {
877894
// crd < size && !legit();
878895
Value notLegit =
879896
genCrdNotLegitPredicate(b, l, wrapCrd);
880897
Value crd = fromWrapCrd(b, l, wrapCrd);
881898
Value ret = ANDI(CMPI(ult, crd, size), notLegit);
899+
ret = ORI(ret, isFirst.front());
882900
return {ret};
883901
});
884902
b.create<scf::ConditionOp>(l, cont.front(), ivs);
@@ -887,7 +905,9 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
887905
[this](OpBuilder &b, Location l, ValueRange ivs) {
888906
linkNewScope(ivs);
889907
wrap->forward(b, l);
890-
YIELD(getItVals());
908+
SmallVector<Value> yieldVals(getItVals().begin(), getItVals().end());
909+
yieldVals.push_back(constantI1(b, l, false));
910+
YIELD(yieldVals);
891911
});
892912

893913
b.setInsertionPointAfter(whileOp);
@@ -935,7 +955,7 @@ ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
935955
return wrap.forward(b, l);
936956
}
937957

938-
ValueRange NonEmptySubSectIterator::genSubSectTraverseTillRoot(
958+
ValueRange NonEmptySubSectIterator::inflateSubSectTree(
939959
OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
940960
// Set up the helper to help traverse a sparse subsection.
941961
SubSectIterHelper helper(*this);
@@ -1009,29 +1029,30 @@ ValueRange NonEmptySubSectIterator::genSubSectTraverseTillRoot(
10091029
// Else, this is not the root, recurse until root.
10101030
auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
10111031
assert(p->lvl + 1 == lvl);
1012-
return p->genSubSectTraverseTillRoot(b, l, reduc, visitDenseSubSect);
1032+
return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
10131033
}
10141034

10151035
void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
10161036
const SparseIterator *) {
10171037
Value c0 = C_IDX(0);
10181038
if (!isSubSectRoot()) {
10191039
assert(parent->lvl + 1 == lvl);
1020-
// We can not call wrap->genInit() here to initialize the wrapped iterator,
1021-
// because the parent of the curent iterator is still unresolved.
10221040
if (randomAccessible()) {
1041+
// We can not call wrap->genInit() here to initialize the wrapped
1042+
// iterator, because the parent of the curent iterator is still
1043+
// unresolved.
10231044
seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
10241045
return;
10251046
}
10261047

10271048
auto *p = cast<NonEmptySubSectIterator>(parent);
1028-
10291049
SmallVector<Value, 3> reduc = {
10301050
C_IDX(-1), // minCrd (max signless integer)
10311051
c0, // tupleId
10321052
};
10331053

1034-
ValueRange result = p->genSubSectTraverseTillRoot(
1054+
// Expand the subsection tree from the parent level to the current level.
1055+
ValueRange result = p->inflateSubSectTree(
10351056
b, l, reduc,
10361057
[this](OpBuilder &b, Location l, const SparseIterator *parent,
10371058
ValueRange reduc) -> scf::ValueVector {
@@ -1071,6 +1092,8 @@ void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
10711092
// to one node.
10721093
assert(isSubSectRoot());
10731094

1095+
// Initialize the position, the position marks the *lower bound* of the
1096+
// subRange. The higher bound is determined by the size of the subsection.
10741097
delegate->genInit(b, l, parent);
10751098
if (randomAccessible()) {
10761099
seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
@@ -1251,19 +1274,45 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
12511274
return std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
12521275
}
12531276

1277+
template <typename IterType>
1278+
static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1279+
auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1280+
if (filter && llvm::isa<IterType>(filter->wrap.get())) {
1281+
return filter->wrap.get();
1282+
}
1283+
return it;
1284+
}
1285+
template <typename IterType>
1286+
static const IterType *unwrapFilter(const SparseIterator *it) {
1287+
auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1288+
if (filter) {
1289+
return llvm::cast<IterType>(filter->wrap.get());
1290+
}
1291+
return llvm::cast<IterType>(it);
1292+
}
1293+
12541294
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
12551295
OpBuilder &b, Location l, const SparseIterator *parent,
12561296
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
1257-
return std::make_unique<NonEmptySubSectIterator>(
1258-
b, l, parent, std::move(delegate), size, stride);
1297+
1298+
// Try unwrap the NonEmptySubSectIterator from a filter parent.
1299+
parent = tryUnwrapFilter<NonEmptySubSectIterator>(parent);
1300+
auto it = std::make_unique<NonEmptySubSectIterator>(
1301+
b, l, parent, std::move(delegate), size, 1);
1302+
1303+
if (stride != 1)
1304+
return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1305+
C_IDX(stride), /*size=*/C_IDX(-1));
1306+
return it;
12591307
}
12601308

12611309
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1262-
const SparseIterator &subsectIter, const SparseIterator &parent,
1310+
const SparseIterator &subSectIter, const SparseIterator &parent,
12631311
std::unique_ptr<SparseIterator> &&wrap, Value size, unsigned stride) {
1264-
return std::make_unique<SubSectIterator>(
1265-
llvm::cast<NonEmptySubSectIterator>(subsectIter), parent, std::move(wrap),
1266-
size, stride);
1312+
// This must be a subsection iterator or a filtered subsection iterator.
1313+
auto &subSect = *unwrapFilter<NonEmptySubSectIterator>(&subSectIter);
1314+
return std::make_unique<SubSectIterator>(subSect, parent, std::move(wrap),
1315+
size, stride);
12671316
}
12681317

12691318
#undef CMPI

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,12 @@ class SparseIterator {
114114
virtual Value upperBound(OpBuilder &b, Location l) const = 0;
115115

116116
// Serialize and deserialize the current status to/from a set of values. The
117-
// ValueRange should contain values that specifies the postion and loop bound.
117+
// ValueRange should contain values that specifies the current postion and
118+
// loop bound.
118119
//
119120
// Not every type of iterator supports the operations, e.g., non-empty
120121
// subsection iterator does not because the the number of non-empty
121-
// subsections can not be determined in advance.
122+
// subsections can not be determined easily.
122123
//
123124
// NOTE: All the values should have index type.
124125
virtual SmallVector<Value> serialize() const {

0 commit comments

Comments
 (0)