@@ -394,6 +394,10 @@ class DedupIterator : public SparseIterator {
394
394
const SparseTensorLevel &stl;
395
395
};
396
396
397
+ //
398
+ // A filter iterator wrapped from another iterator. The filter iterator update
399
+ // the wrapped iterator *in-place*.
400
+ //
397
401
class FilterIterator : public SparseIterator {
398
402
// Coorindate translation between crd loaded from the wrap iterator and the
399
403
// filter iterator.
@@ -411,6 +415,8 @@ class FilterIterator : public SparseIterator {
411
415
Value genShouldFilter (OpBuilder &b, Location l);
412
416
413
417
public:
418
+ // TODO: avoid unnessary check when offset == 0 and/or when stride == 1 and/or
419
+ // when crd always < size.
414
420
FilterIterator (std::unique_ptr<SparseIterator> &&wrap, Value offset,
415
421
Value stride, Value size)
416
422
: SparseIterator(IterKind::kFilter , *wrap), offset(offset),
@@ -548,9 +554,10 @@ class NonEmptySubSectIterator : public SparseIterator {
548
554
return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
549
555
}
550
556
551
- ValueRange genSubSectTraverseTillRoot (OpBuilder &b, Location l,
552
- ValueRange reduc,
553
- TraverseBuilder builder) const ;
557
+ // Generate code that inflate the current subsection tree till the current
558
+ // level such that every leaf node is visited.
559
+ ValueRange inflateSubSectTree (OpBuilder &b, Location l, ValueRange reduc,
560
+ TraverseBuilder builder) const ;
554
561
555
562
bool randomAccessible () const override {
556
563
return delegate->randomAccessible ();
@@ -861,24 +868,35 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
861
868
assert (!randomAccessible ());
862
869
// Generates
863
870
//
864
- // wrap ++ ;
865
- // while !it.end() && !legit(*it)
871
+ // bool isFirst = true ;
872
+ // while !it.end() && ( !legit(*it) || isFirst )
866
873
// wrap ++;
867
- wrap->forward (b, l);
874
+ // isFirst = false;
875
+ //
876
+ // We do not hoist the first `wrap++` outside the loop but use a `isFirst`
877
+ // flag here because `wrap++` might have a complex implementation (e.g., to
878
+ // forward a subsection).
879
+ Value isFirst = constantI1 (b, l, true );
880
+
881
+ SmallVector<Value> whileArgs (getItVals ().begin (), getItVals ().end ());
882
+ whileArgs.push_back (isFirst);
883
+
868
884
auto whileOp = b.create <scf::WhileOp>(
869
- l, getItVals ( ).getTypes (), getItVals () ,
885
+ l, ValueRange (whileArgs ).getTypes (), whileArgs ,
870
886
/* beforeBuilder=*/
871
887
[this ](OpBuilder &b, Location l, ValueRange ivs) {
872
- linkNewScope (ivs);
888
+ ValueRange isFirst = linkNewScope (ivs);
889
+ assert (isFirst.size () == 1 );
873
890
ValueRange cont =
874
891
genWhenInBound (b, l, *wrap, C_FALSE,
875
- [this ](OpBuilder &b, Location l,
876
- Value wrapCrd) -> scf::ValueVector {
892
+ [this , isFirst ](OpBuilder &b, Location l,
893
+ Value wrapCrd) -> scf::ValueVector {
877
894
// crd < size && !legit();
878
895
Value notLegit =
879
896
genCrdNotLegitPredicate (b, l, wrapCrd);
880
897
Value crd = fromWrapCrd (b, l, wrapCrd);
881
898
Value ret = ANDI (CMPI (ult, crd, size), notLegit);
899
+ ret = ORI (ret, isFirst.front ());
882
900
return {ret};
883
901
});
884
902
b.create <scf::ConditionOp>(l, cont.front (), ivs);
@@ -887,7 +905,9 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
887
905
[this ](OpBuilder &b, Location l, ValueRange ivs) {
888
906
linkNewScope (ivs);
889
907
wrap->forward (b, l);
890
- YIELD (getItVals ());
908
+ SmallVector<Value> yieldVals (getItVals ().begin (), getItVals ().end ());
909
+ yieldVals.push_back (constantI1 (b, l, false ));
910
+ YIELD (yieldVals);
891
911
});
892
912
893
913
b.setInsertionPointAfter (whileOp);
@@ -935,7 +955,7 @@ ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
935
955
return wrap.forward (b, l);
936
956
}
937
957
938
- ValueRange NonEmptySubSectIterator::genSubSectTraverseTillRoot (
958
+ ValueRange NonEmptySubSectIterator::inflateSubSectTree (
939
959
OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
940
960
// Set up the helper to help traverse a sparse subsection.
941
961
SubSectIterHelper helper (*this );
@@ -1009,29 +1029,30 @@ ValueRange NonEmptySubSectIterator::genSubSectTraverseTillRoot(
1009
1029
// Else, this is not the root, recurse until root.
1010
1030
auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1011
1031
assert (p->lvl + 1 == lvl);
1012
- return p->genSubSectTraverseTillRoot (b, l, reduc, visitDenseSubSect);
1032
+ return p->inflateSubSectTree (b, l, reduc, visitDenseSubSect);
1013
1033
}
1014
1034
1015
1035
void NonEmptySubSectIterator::genInit (OpBuilder &b, Location l,
1016
1036
const SparseIterator *) {
1017
1037
Value c0 = C_IDX (0 );
1018
1038
if (!isSubSectRoot ()) {
1019
1039
assert (parent->lvl + 1 == lvl);
1020
- // We can not call wrap->genInit() here to initialize the wrapped iterator,
1021
- // because the parent of the curent iterator is still unresolved.
1022
1040
if (randomAccessible ()) {
1041
+ // We can not call wrap->genInit() here to initialize the wrapped
1042
+ // iterator, because the parent of the curent iterator is still
1043
+ // unresolved.
1023
1044
seek ({/* minCrd=*/ c0, /* offset=*/ c0, /* notEnd=*/ C_TRUE});
1024
1045
return ;
1025
1046
}
1026
1047
1027
1048
auto *p = cast<NonEmptySubSectIterator>(parent);
1028
-
1029
1049
SmallVector<Value, 3 > reduc = {
1030
1050
C_IDX (-1 ), // minCrd (max signless integer)
1031
1051
c0, // tupleId
1032
1052
};
1033
1053
1034
- ValueRange result = p->genSubSectTraverseTillRoot (
1054
+ // Expand the subsection tree from the parent level to the current level.
1055
+ ValueRange result = p->inflateSubSectTree (
1035
1056
b, l, reduc,
1036
1057
[this ](OpBuilder &b, Location l, const SparseIterator *parent,
1037
1058
ValueRange reduc) -> scf::ValueVector {
@@ -1071,6 +1092,8 @@ void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
1071
1092
// to one node.
1072
1093
assert (isSubSectRoot ());
1073
1094
1095
+ // Initialize the position, the position marks the *lower bound* of the
1096
+ // subRange. The higher bound is determined by the size of the subsection.
1074
1097
delegate->genInit (b, l, parent);
1075
1098
if (randomAccessible ()) {
1076
1099
seek ({/* minCrd=*/ c0, /* offset=*/ c0, /* notEnd=*/ C_TRUE});
@@ -1251,19 +1274,45 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1251
1274
return std::make_unique<FilterIterator>(std::move (sit), offset, stride, size);
1252
1275
}
1253
1276
1277
+ template <typename IterType>
1278
+ static const SparseIterator *tryUnwrapFilter (const SparseIterator *it) {
1279
+ auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1280
+ if (filter && llvm::isa<IterType>(filter->wrap .get ())) {
1281
+ return filter->wrap .get ();
1282
+ }
1283
+ return it;
1284
+ }
1285
+ template <typename IterType>
1286
+ static const IterType *unwrapFilter (const SparseIterator *it) {
1287
+ auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1288
+ if (filter) {
1289
+ return llvm::cast<IterType>(filter->wrap .get ());
1290
+ }
1291
+ return llvm::cast<IterType>(it);
1292
+ }
1293
+
1254
1294
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator (
1255
1295
OpBuilder &b, Location l, const SparseIterator *parent,
1256
1296
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
1257
- return std::make_unique<NonEmptySubSectIterator>(
1258
- b, l, parent, std::move (delegate), size, stride);
1297
+
1298
+ // Try unwrap the NonEmptySubSectIterator from a filter parent.
1299
+ parent = tryUnwrapFilter<NonEmptySubSectIterator>(parent);
1300
+ auto it = std::make_unique<NonEmptySubSectIterator>(
1301
+ b, l, parent, std::move (delegate), size, 1 );
1302
+
1303
+ if (stride != 1 )
1304
+ return std::make_unique<FilterIterator>(std::move (it), /* offset=*/ C_IDX (0 ),
1305
+ C_IDX (stride), /* size=*/ C_IDX (-1 ));
1306
+ return it;
1259
1307
}
1260
1308
1261
1309
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator (
1262
- const SparseIterator &subsectIter , const SparseIterator &parent,
1310
+ const SparseIterator &subSectIter , const SparseIterator &parent,
1263
1311
std::unique_ptr<SparseIterator> &&wrap, Value size, unsigned stride) {
1264
- return std::make_unique<SubSectIterator>(
1265
- llvm::cast<NonEmptySubSectIterator>(subsectIter), parent, std::move (wrap),
1266
- size, stride);
1312
+ // This must be a subsection iterator or a filtered subsection iterator.
1313
+ auto &subSect = *unwrapFilter<NonEmptySubSectIterator>(&subSectIter);
1314
+ return std::make_unique<SubSectIterator>(subSect, parent, std::move (wrap),
1315
+ size, stride);
1267
1316
}
1268
1317
1269
1318
#undef CMPI
0 commit comments