Skip to content

Commit b54c326

Browse files
author
Peiming Liu
committed
minor cleanup
1 parent 9d27c8c commit b54c326

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -465,14 +465,12 @@ class NonEmptySubSectIterator : public SparseIterator {
465465
NonEmptySubSectIterator(OpBuilder &b, Location l,
466466
const SparseIterator *parent,
467467
std::unique_ptr<SparseIterator> &&delegate,
468-
Value subSectSz, unsigned stride)
468+
Value subSectSz)
469469
: SparseIterator(IterKind::kNonEmptySubSect, delegate->tid, delegate->lvl,
470470
/*itVals=*/subSectMeta),
471-
subSectSz(subSectSz), stride(stride), parent(parent),
472-
delegate(std::move(delegate)) {
473-
471+
parent(parent), delegate(std::move(delegate)),
472+
tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
474473
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
475-
assert(stride == 1);
476474
if (p == nullptr) {
477475
// Extract subsections along the root level.
478476
maxTupleCnt = C_IDX(1);
@@ -488,8 +486,6 @@ class NonEmptySubSectIterator : public SparseIterator {
488486
// We don't need an extra buffer to find subsections on dense levels.
489487
if (randomAccessible())
490488
return;
491-
// The number of values we need to store to serialize the wrapped iterator.
492-
tupleSz = this->delegate->serialize().size();
493489
subSectPosBuf = allocSubSectPosBuf(b, l);
494490
}
495491

@@ -574,7 +570,6 @@ class NonEmptySubSectIterator : public SparseIterator {
574570
}
575571

576572
Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
577-
assert(stride == 1);
578573
return SUBI(wrapCrd, getAbsOff());
579574
}
580575

@@ -598,18 +593,17 @@ class NonEmptySubSectIterator : public SparseIterator {
598593
Value getAbsOff() const { return subSectMeta[1]; }
599594
Value getNotEnd() const { return subSectMeta[2]; }
600595

596+
const SparseIterator *parent;
597+
std::unique_ptr<SparseIterator> delegate;
598+
601599
// Number of values required to serialize the wrapped iterator.
602-
unsigned tupleSz;
600+
const unsigned tupleSz;
603601
// Max number of tuples, and the actual number of tuple.
604602
Value maxTupleCnt, tupleCnt;
605603
// The memory used to cache the tuple serialized from the wrapped iterator.
606604
Value subSectPosBuf;
607605

608606
const Value subSectSz;
609-
const unsigned stride;
610-
611-
const SparseIterator *parent;
612-
std::unique_ptr<SparseIterator> delegate;
613607

614608
Value subSectMeta[3]; // minCrd, absolute offset, notEnd
615609
};
@@ -1189,8 +1183,6 @@ ValueRange NonEmptySubSectIterator::forward(OpBuilder &b, Location l) {
11891183
Value minAbsOff = ADDI(getAbsOff(), c1);
11901184
nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);
11911185

1192-
assert(stride == 1 && "Not yet implemented");
1193-
11941186
seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
11951187
// The coordinate should not exceeds the space upper bound.
11961188
Value crd = deref(b, l);
@@ -1286,7 +1278,7 @@ std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
12861278
// Try unwrap the NonEmptySubSectIterator from a filter parent.
12871279
parent = tryUnwrapFilter<NonEmptySubSectIterator>(parent);
12881280
auto it = std::make_unique<NonEmptySubSectIterator>(
1289-
b, l, parent, std::move(delegate), size, 1);
1281+
b, l, parent, std::move(delegate), size);
12901282

12911283
if (stride != 1)
12921284
return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class SparseIterator {
176176

177177
// Generate a conditional it.next() in the following form
178178
//
179-
// if (crd == it.crd)
179+
// if (cond)
180180
// yield it.next
181181
// else
182182
// yield it
@@ -185,7 +185,7 @@ class SparseIterator {
185185
// if it.next() is trivial to compute, we can use a select operation instead.
186186
// E.g.,
187187
//
188-
// it = select crd == it.crd ? it+1 : it
188+
// it = select cond ? it+1 : it
189189
virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
190190

191191
// Locate the iterator to the position specified by *crd*, this can only

0 commit comments

Comments
 (0)