diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index b1b8b762d164d..1883cf1ceed55 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -1126,7 +1126,7 @@ struct ForeachRewriter : public OpRewritePattern { } Value vals = loopEmitter.getValBuffer()[0]; - Value pos = loopEmitter.getPosits()[0].back(); + Value pos = loopEmitter.getValPosits(0); // Loads the value from sparse tensor using position-index; // loads the value from dense tensor using coords. Value val = enc ? rewriter.create(loc, vals, pos) @@ -1148,17 +1148,17 @@ struct ForeachRewriter : public OpRewritePattern { SmallVector reducValue = srcBlock->getTerminator()->getOperands(); rewriter.eraseOp(srcBlock->getTerminator()); - // Inline body. - if (!reducValue.empty()) { - rewriter.mergeBlocks(srcBlock, rewriter.getBlock(), args); - } else { - // This is annoying, since scf.for inserts a implicit yield op when - // there is no reduction variable upon creation, in this case we need to - // merge the block *before* the yield op. - rewriter.inlineBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), - args); + Operation &last = rewriter.getBlock()->back(); + if (llvm::isa(last)) { + // Because `scf.for` inserts an implicit yield op when there is no + // reduction variable upon creation, we reset the insertion point such + // that the block is inlined before *before* the yield op. + rewriter.setInsertionPoint(&last); } + rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(), + rewriter.getInsertionPoint(), args); + rewriter.setInsertionPointToEnd(rewriter.getBlock()); for (Level l = 0; l < lvlRank; l++) { // Link the reduction chain. Note that loop emitter update the reducValue // in place. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index fec23d2a72347..5266ca7213bfc 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -354,7 +354,7 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, const auto stt = getSparseTensorType(t->get()); if (stt.hasEncoding()) { // For sparse tensors we only push the last-level's position onto `args`. - const auto pos = env.emitter().getPosits()[tid].back(); + const auto pos = env.emitter().getValPosits(tid); assert(pos); args.push_back(pos); } else { @@ -815,8 +815,7 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder, Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { // Construct while-loop with a parameter for each index. return env.emitter().enterCoIterationOverTensorsAtLvls( - builder, env.op().getLoc(), tidLvls, reduc, tryParallel, - /*genDedup=*/true, needsUniv); + builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv); }); assert(loop); return loop; @@ -894,7 +893,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) || is2OutOf4LT(lt)) { assert(lvl.has_value()); - const Value crd = env.emitter().getCoords()[tid][*lvl]; + const Value crd = env.emitter().getCoord(tid, *lvl); const Value lvar = env.getLoopVar(curr); clause = builder.create(loc, arith::CmpIPredicate::eq, crd, lvar); @@ -1032,10 +1031,14 @@ static bool getAllTidLvlsInLatPoints( }); if (isDenseLT(env.lt(outTid, curr))) { - // Note that we generate dense indices of the output tensor - // unconditionally, since they may not appear in the lattice, but may be - // needed for linearized env. - callback(env.makeTensorLevel(outTid, *outLvl), nullptr); + auto stt = getSparseTensorType(env.op().getOutputs().front()); + // Note that we generate dense indices of the output tensor unconditionally, + // since they may not appear in the lattice, but may be needed for + // linearized env. + // TODO: we should avoid introducing corner cases for all-dense sparse + // tensors. + if (stt.hasEncoding() && stt.isAllDense()) + callback(env.makeTensorLevel(outTid, *outLvl), nullptr); } if (numloopCond == 0) { @@ -1064,6 +1067,11 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, SmallVector tidLvls; getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) { + // TODO: remove this! The same tensor level might be added for multiple + // times due to the special handling for all-dense "sparse" output tensor + // (see L1038). + if (llvm::find(tidLvls, tl) != tidLvls.end()) + return; tidLvls.emplace_back(tl); }); @@ -1096,7 +1104,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env, for (Level l = startLvl; l < lvlRank; l++) { AffineExpr lvlExpr = lvlExprs[l]; if (enc.isDenseLvl(l) && isa(lvlExpr)) - env.emitter().genDenseAffineAddress( + env.emitter().locateLvlAtAffineAddress( builder, loc, env.makeTensorLevel(tid, l), lvlExpr); else return; // break on first non-dense non-constant level @@ -1145,7 +1153,7 @@ static std::pair startLoop(CodegenEnv &env, Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls); Location loc = env.op().getLoc(); for (auto [tidLvl, exp] : affineTidLvls) { - env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp); + env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp); } // Until now, we have entered every pair in {cond, extra, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index 3d8cc5222b828..0ce6a9efce1c8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp @@ -63,8 +63,6 @@ LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder, // specifies the range of the fragment, and pPtr specifies the index of the // corresponding fragment in the child level (i.e., a pointer to the sliced // position array). -static constexpr unsigned kSliceIterWidth = 3; - static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor, Level lvl) { auto enc = getSparseTensorEncoding(tensor.getType()); @@ -77,217 +75,10 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor, return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl)); } -/// Converts a coordinate relative to the slice to the coordinate relative -/// to the underlying tensor. -// FIXME: that description says "sliceCrd -> tensorCrd"; but the function -// name suggests it should be "tensorCrd -> sliceCrd". -static Value toSliceCrd(OpBuilder &builder, Location loc, Value crd, - Value offset, Value stride, Value tensor, Level lvl) { - // tensorCrd = sliceCrd * stride + offset - return ADDI(MULI(crd, stride), offset); -} - -/// Generates code to compute the *absolute* offset of the slice based on the -/// provide minimum coordinates in the slice. -/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduced the -/// expression, i,e, s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute* -/// offset is the offset computed relative to the initial tensors T. -/// -/// When isNonEmpty == true, the computed offset is meaningless and should not -/// be used during runtime, the method generates code to return 0 currently in -/// that case. -/// -/// offset = isNonEmpty && minCrd >= size ? minCrd - size + 1 : 0; -static Value offsetFromMinCoord(OpBuilder &builder, Location loc, Value minCrd, - Value size, Value isNonEmpty) { - Value geSize = CMPI(uge, minCrd, size); - Value pred = ANDI(isNonEmpty, geSize); - // Computes minCrd - size + 1 - Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size); - // This is the absolute offset related to the underly tensor. - return SELECT(pred, mms, C_IDX(0)); -} - -/// Converts a coordinate relative to the underlying tensor to the coordinate -/// relative to the slice, returns a extra reminder value -// FIXME: that description says "tensorCrd -> sliceCrd"; but the function -// name suggests it should be "sliceCrd -> tensorCrd". -static std::pair fromSliceCrd(OpBuilder &builder, Location loc, - Value crd, Value offset, - Value stride, Value tensor, - Level lvl) { - // sliceCrd = (tensorCrd - offset) / stride - crd = SUBI(crd, offset); - Value rem = REMUI(crd, stride); - crd = DIVUI(crd, stride); - return std::make_pair(crd, rem); -} - -// Generates a bool value for while loop condition that tries to iterate over a -// fully reduced level with affine index expression. -static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc, - const SparseTensorLevel &level, - Value crdHi, Value posit, Value posHi) { - Value inBound = CMPI(ult, posit, posHi); - auto ifOp = - builder.create(loc, builder.getI1Type(), inBound, true); - // if (inbound) - // yield coord < crdHi - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - Value crd = level.peekCrdAt(builder, loc, posit); - YIELD(CMPI(ult, crd, crdHi)); - // else - // yield false - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - YIELD(constantI1(builder, loc, false)); - - builder.setInsertionPointAfter(ifOp); - return ifOp.getResult(0); -} - -// Helper functions that load/store into the position buffer for slice-driven -// loops. -// The sliced pointer buffer is organized as: -// [[pLo0, pLo1, pLo2, ...], -// [pHi0, pHi1, pHi2, ...], -// [pNx0, pNx1, pNx2, ...]] -static Value allocSlicePosBuf(OpBuilder &builder, Location loc, - Value tupleCnt) { - Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth)); - // Additional two metadata {memSize, idx} at head. - return genAlloca(builder, loc, bufSz, builder.getIndexType()); -} - -// Gets and sets position values for slice-driven loops. -enum class SlicePosKind { kLo, kHi, kNext }; -static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf, - Value tupleIdx, SlicePosKind posKind) { - Value dim = builder.create(loc, posBuf, C_IDX(0)); - Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth)); - switch (posKind) { - case SlicePosKind::kLo: - return tupleIdx; - case SlicePosKind::kHi: - return ADDI(tupleIdx, tupleCnt); - case SlicePosKind::kNext: - return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2))); - } - llvm_unreachable("unexpected kind"); -} -static Value loadSlicePos(OpBuilder &builder, Location loc, Value sPosBuf, - Value tupleIdx, SlicePosKind posKind) { - return genIndexLoad(builder, loc, sPosBuf, - getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind)); -} -static void updateSlicePos(OpBuilder &builder, Location loc, Value sPosBuf, - Value pos, Value tupleIdx, SlicePosKind posKind) { - builder.create( - loc, pos, sPosBuf, - getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind)); -} - -std::pair -LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd, - TensorId tid, Level lvl) { - assert(isSparseSlices[tid]); - Value slice = tensors[tid]; - Value offset = sliceOffsets[tid][lvl]; - Value stride = sliceStrides[tid][lvl]; - auto enc = getSparseTensorEncoding(slice.getType()); - - const auto [newCrd, crdRem] = - fromSliceCrd(builder, loc, crd, offset, stride, slice, lvl); - - SmallVector conds; // at most 3 conditions - - // First, coord >= offset (skip the check if offset is known to be 0). - if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl); - !(staticOffset.has_value() && *staticOffset == 0)) { - auto geOffset = CMPI(uge, crd, offset); - conds.push_back(geOffset); - } - - // Second, coord_in_slice < length - auto ltLength = CMPI(ult, newCrd, lvlSizes[tid][lvl]); - conds.push_back(ltLength); - - // Third, rem == 0 (skip the check if stride is known to be 1). - if (auto staticStride = enc.getStaticLvlSliceStride(lvl); - !(staticStride.has_value() && *staticStride == 1)) { - auto fitStride = CMPI(eq, crdRem, C_IDX(0)); - conds.push_back(fitStride); - } - - // Must meet all condition to be a valid coordinate in slice. - auto pred = conds.front(); - for (auto cond : ValueRange(conds).drop_front()) - pred = ANDI(pred, cond); - - return {newCrd, pred}; -} - //===----------------------------------------------------------------------===// // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// -Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid, - Level lvl, Value crd) { - Value pos = lvl == 0 ? C_IDX(0) : posits[tid][lvl - 1]; - Value mul = MULI(highs[tid][lvl], pos); - if (isSparseSlices[tid]) - crd = toSliceCrd(builder, loc, crd, sliceOffsets[tid][lvl], - sliceStrides[tid][lvl], tensors[tid], lvl); - Value add = ADDI(mul, crd); - return add; -} - -Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc, - TensorId tid, Level lvl, Value pLo, - Value pHi) { - SparseTensorLevel &stl = *lvls[tid][lvl]; - const Value sameCrd = stl.peekCrdAt(builder, loc, pLo); - auto whileOp = builder.create( - loc, builder.getIndexType(), pLo, - /*beforeBuilder=*/ - [pHi, &stl, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) { - const auto pos = ivs[0]; - Value inBound = builder.create( - loc, arith::CmpIPredicate::ult, pos, pHi); - auto ifInBound = - builder.create(loc, builder.getI1Type(), inBound, true); - { - OpBuilder::InsertionGuard guard(builder); - // Load the next coordinates only when inbound (to avoid OOB - // accesses). - builder.setInsertionPointToStart(ifInBound.thenBlock()); - Value crd = stl.peekCrdAt(builder, loc, pos); - Value isSameCrd = builder.create( - loc, arith::CmpIPredicate::eq, crd, sameCrd); - YIELD(isSameCrd); - // Else, the position is out of bound, yield false to terminate the - // loop. - builder.setInsertionPointToStart(ifInBound.elseBlock()); - YIELD(constantI1(builder, loc, false)); - } - builder.create(loc, ifInBound.getResults()[0], ivs); - }, - /*afterBuilder=*/ - [](OpBuilder &builder, Location loc, ValueRange ivs) { - // pos ++ - Value nextPos = ADDI(ivs[0], C_IDX(1)); - YIELD(nextPos); - }); - // Return the segment high. - return whileOp.getResult(0); -} - -Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid, - Level lvl) { - const Value pos = posits[tid][lvl]; - const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos); - return crd; -} - LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput, bool isSparseOut, unsigned numLoops, DependentLvlGetter dimGetter) { @@ -308,17 +99,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, // tensors array (len == numManifestTensor). this->tensors.assign(ts.begin(), ts.end()); // Arrays with len == numTensor. - this->lvlTypes.assign(numTensors, std::vector()); - this->lvlSizes.assign(numTensors, std::vector()); - this->highs.assign(numTensors, std::vector()); - this->segHi.assign(numTensors, std::vector()); - this->posits.assign(numTensors, std::vector()); - this->coords.assign(numTensors, std::vector()); this->valBuffer.assign(numTensors, nullptr); this->lvls.resize(numTensors); - this->isSparseSlices.assign(numTensors, false); - this->sliceOffsets.assign(numTensors, std::vector()); - this->sliceStrides.assign(numTensors, std::vector()); + this->iters.resize(numTensors); // These zeros will be overwritten below, but we need to initialize // them to something since we'll need random-access assignment. @@ -328,13 +111,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, // Index-reduction related fields. this->dependentLvlMap.assign( numTensors, std::vector>>()); - this->slicePosBuffer.assign(numTensors, std::vector>()); - this->sliceTupleNxStartIdx.assign(numTensors, std::vector()); - this->sliceTupleFwdCnt.assign(numTensors, std::vector()); - this->trivialSlice.assign(numTensors, std::vector()); this->sliceMeta.assign( numTensors, std::vector>>()); - this->sliceStack.assign(numTensors, std::vector()); this->levelReducedDep.assign(numTensors, std::vector()); // Initialize nested types of `TensorId`-indexed fields. @@ -345,7 +123,6 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, // to the total number of loops (each level can potentially be mapped to // one of the loop being generated). lvlRank = numLoops; - lvlTypes[tid].assign(lvlRank, LevelType::Dense); } else { const Value t = tensors[tid]; // a scalar or 0-dimension tensors @@ -355,40 +132,17 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, auto rtp = getRankedTensorType(t); const SparseTensorType stt(rtp); lvlRank = stt.getLvlRank(); - - if (stt.hasEncoding()) { - const auto enc = stt.getEncoding(); - isSparseSlices[tid] = enc.isSlice(); - for (auto lvlTp : enc.getLvlTypes()) - lvlTypes[tid].push_back(lvlTp); - } else { - lvlTypes[tid].assign(lvlRank, LevelType::Dense); - } } - // Initialize using empty value. - lvlSizes[tid].assign(lvlRank, Value()); - highs[tid].assign(lvlRank, Value()); - segHi[tid].assign(lvlRank, Value()); - posits[tid].assign(lvlRank, Value()); - coords[tid].assign(lvlRank, Value()); lvls[tid].resize(lvlRank); - - sliceOffsets[tid].assign(lvlRank, Value()); - sliceStrides[tid].assign(lvlRank, Value()); + iters[tid].resize(lvlRank); + loopHighs.assign(numLoops, nullptr); // Slice-driven loops related initialization. levelReducedDep[tid].assign(lvlRank, 0); dependentLvlMap[tid].assign( lvlRank, std::vector>()); - slicePosBuffer[tid].assign(lvlRank, std::vector()); - sliceTupleNxStartIdx[tid].assign(lvlRank, Value()); - sliceTupleFwdCnt[tid].assign(lvlRank, Value()); - trivialSlice[tid].assign(lvlRank, false); sliceMeta[tid].assign(lvlRank, std::vector>()); - sliceStack[tid].emplace_back(/*minCrd=*/Value(), - /*offset=*/Value(), /*isNonEmpty*/ Value(), - /*posTupleNum=*/Value(), std::nullopt, 0); if (dimGetter && !isSynTensor(tid)) { for (Level l = 0; l < lvlRank; l++) { std::vector> deps = dimGetter(tid, l); @@ -401,21 +155,39 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, if (depends == 0) continue; sliceMeta[tid][l].reserve(depends); - // We need `depends - 1` slices to fully reduce the affine expression. - slicePosBuffer[tid][l].reserve(depends - 1); } } } } +std::unique_ptr +LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t, + Level l) { + auto it = makeSimpleIterator(*lvls[t][l]); + auto stt = getSparseTensorType(tensors[t]); + if (stt.hasEncoding() && stt.getEncoding().isSlice()) { + Value offset = genSliceOffset(builder, loc, tensors[t], l); + Value stride = genSliceStride(builder, loc, tensors[t], l); + auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride, + lvls[t][l]->size()); + return slicedIt; + } + return it; +} + void LoopEmitter::initializeLoopEmit( OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater, LoopEmitter::SynTensorBoundSetter synSetter) { - // For every synthetic tensor, set the high bound by calling the callback. - if (synSetter) - for (unsigned i = 0, e = highs[getSynTensorId()].size(); i < e; i++) - highs[getSynTensorId()][i] = synSetter(builder, loc, i); + if (synSetter) { + TensorId synId = getSynTensorId(); + for (unsigned i = 0, e = loopHighs.size(); i < e; i++) { + Value sz = loopHighs[i] = synSetter(builder, loc, i); + auto [stl, it] = makeSynLevelAndIterator(sz, synId, i); + lvls[synId][i] = std::move(stl); + iters[synId][i].emplace_back(std::move(it)); + } + } // For every manifest tensor: // * get the values buffer. @@ -448,14 +220,13 @@ void LoopEmitter::initializeLoopEmit( // Scan all levels of current tensor. for (Level l = 0; l < lvlRank; l++) { - lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, l); - // Find upper bound in current dimension. - highs[t][l] = lvlSizes[t][l] = lvlSzs[l]; - if (isSparseSlices[t]) { - sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l); - sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l); - } + lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l); + if (!dependentLvlMap[t][l].empty()) + continue; + + auto it = makeLevelIterator(builder, loc, t, l); + iters[t][l].emplace_back(std::move(it)); } // Perform the required bufferization. Dense inputs materialize @@ -491,11 +262,11 @@ void LoopEmitter::initializeLoopEmit( // some loop preparation from tensor iteration, but will also (undesirably) // hoist the code ouside if-conditions. } - - initSliceDriven(builder, loc); + // TODO: avoid treating subsection iterator as a special case. + initSubSectIterator(builder, loc); } -void LoopEmitter::initSliceDriven(OpBuilder &builder, Location loc) { +void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) { Value c0 = C_IDX(0); for (TensorId t = 0, e = tensors.size(); t < e; t++) { auto rtp = dyn_cast(tensors[t].getType()); @@ -516,81 +287,62 @@ void LoopEmitter::initSliceDriven(OpBuilder &builder, Location loc) { if (depRedOrder.empty()) continue; + std::sort(depRedOrder.begin(), depRedOrder.end(), [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); }); + SmallVector lastIter(tensors.size(), nullptr); for (auto [loop, t, lvl] : depRedOrder) { std::pair curDep = remDepStack[t][lvl].back(); assert(curDep.first == loop); - Value size = c0; - for (auto [loop, stride] : remDepStack[t][lvl]) { - // The synthetic tensor high defines the loop upper bound. - Value loopHi = highs[getSynTensorId()][loop]; - size = ADDI(size, MULI(loopHi, C_IDX(stride))); - } - sliceMeta[t][lvl].emplace_back(size, curDep.second); remDepStack[t][lvl].pop_back(); - // Generate caches required to fast compute next-non-empty slices with - // increasing offset for slice-base loop. - // We do not need cache for dense levels. - if (!remDepStack[t][lvl].empty() && !isDenseLT(lvls[t][lvl]->getLT())) { - Value cnt = C_IDX(1); - for (int preLvl = lvl - 1; preLvl >= 0; preLvl--) { - if (remDepStack[t][preLvl].empty()) - break; - assert(remDepStack[t][preLvl].size() == 1 && "Not implemented"); - auto [loop, stride] = remDepStack[t][preLvl].back(); - assert(stride == 1 && "Not yet implemented"); - // Accumlate the size required to cache the pLo for the slice. - // E.g., if we want to cache the pIdx for slice on the - // second level. We at most need a memref. - // - // NOTE: this is apparently an over-approximation when the previous - // level is compressed, and we can compute a precise memory size - // inside the loops. But that would also requires us to allocate/free - // memory in loops. - cnt = MULI(highs[getSynTensorId()][loop], cnt); + auto lvlIt = makeLevelIterator(builder, loc, t, lvl); + const SparseIterator *parent = lastIter[t]; + if (!parent && lvl > 0) { + if (dependentLvlMap[t][lvl - 1].empty()) { + parent = iters[t][lvl - 1].back().get(); } - slicePosBuffer[t][lvl].push_back(allocSlicePosBuf(builder, loc, cnt)); - } // else fully resolved. + } + + std::unique_ptr it; + if (!remDepStack[t][lvl].empty()) { + // Compute the subsection size. + Value size = c0; + for (auto [loop, stride] : remDepStack[t][lvl]) { + Value loopHi = loopHighs[loop]; + size = ADDI(size, MULI(loopHi, C_IDX(stride))); + } + it = makeNonEmptySubSectIterator(builder, loc, parent, std::move(lvlIt), + size, curDep.second); + } else { + Value size = loopHighs[loop]; + const SparseIterator &subSectIter = *iters[t][lvl].back(); + it = makeTraverseSubSectIterator(subSectIter, *parent, std::move(lvlIt), + size, curDep.second); + } + lastIter[t] = it.get(); + iters[t][lvl].emplace_back(std::move(it)); } } } -void LoopEmitter::categorizeLoopCondition( - ArrayRef tidLvls, SmallVectorImpl &dnConds, - SmallVectorImpl &spConds) { +void LoopEmitter::categorizeIterators( + ArrayRef tidLvls, SmallVectorImpl &raIters, + SmallVectorImpl &spIters) { // Finds out the tensor level that we should use to generate loops. Amongs all // the tensor levels, there is at most one sparse tensor level. for (auto [t, l] : unpackTensorLevelRange(tidLvls)) { - assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair - auto lvlType = lvlTypes[t][l]; - // Must be a recognizable LT. - assert(isDenseLT(lvlType) || isCompressedLT(lvlType) || - isLooseCompressedLT(lvlType) || isSingletonLT(lvlType) || - is2OutOf4LT(lvlType)); - - bool isSparse = !isDenseLT(lvlType); - bool isSlice = isSparseSlices[t]; - bool isAffine = !dependentLvlMap[t][l].empty(); - bool isUnRedu = false; - // TODO: Supports affine index expression on sparse tensor slices. - assert(!isSlice || !isAffine); - - // Whether the affine index expression has been fully reduced or not. - if (!dependentLvlMap[t][l].empty()) - isUnRedu = !depFullyReduced(t, l); - - auto &dstVec = isSparse ? spConds : dnConds; - dstVec.emplace_back( - makeTensorLevel(t, l), - makeLoopCondKind(isSparse, isSlice, isAffine, isUnRedu)); + SparseIterator *it = &getCurIterator(t, l); + if (it->randomAccessible()) + raIters.push_back(it); + else + spIters.push_back(it); } - std::stable_sort(spConds.begin(), spConds.end(), [](auto lhs, auto rhs) { + std::stable_sort(spIters.begin(), spIters.end(), [](auto lhs, auto rhs) { // AffineUnRed > Affine > Slice > Trivial - return static_cast(lhs.second) > static_cast(rhs.second); + return static_cast(lhs->kind) > static_cast(rhs->kind); }); } @@ -599,35 +351,24 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc, // TODO: sort assert(loopSeqStack.size() == loopStack.size()); // Prepares for all the tensors used in the current loop sequence. - std::vector> slicedTids; for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) { - if (!dependentLvlMap[tid][lvl].empty()) { - bool fullyRed = genSliceBegin(builder, loc, tid, lvl); - slicedTids.emplace_back(tid, lvl, fullyRed); - } else if (!isSynTensor(tid)) { - prepareLoopOverTensorAtLvl(builder, loc, tid, lvl); - } + levelReducedDep[tid][lvl]++; + prepareLoopOverTensorAtLvl(builder, loc, tid, lvl); } // Universal Index starts from 0. - loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids)); + loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec()); } void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) { assert(loopSeqStack.size() == loopStack.size() + 1); - const auto &slicedTids = loopSeqStack.back().second; - // Depending on whether the slice is resolved or not at current loop sequence, // end them in different ways. - for (auto [tid, lvl, res] : slicedTids) { - if (!res) { - // If this is a unresolved-slice-driven loop, pops out the slice. - assert(sliceStack[tid].back().slicedOnLvl == lvl); - sliceStack[tid].pop_back(); - } - } + for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second)) + levelReducedDep[tid][lvl]--; + loopSeqStack.pop_back(); } @@ -661,16 +402,15 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) { } std::pair LoopEmitter::emitForLoopOverTensorAtLvl( - OpBuilder &builder, Location loc, TensorId tid, Level lvl, Value lo, - Value hi, MutableArrayRef reduc, bool isParallel) { - bool isSparseCond = isCompressedLT(lvlTypes[tid][lvl]) || - isLooseCompressedLT(lvlTypes[tid][lvl]) || - is2OutOf4LT(lvlTypes[tid][lvl]) || - isSingletonLT(lvlTypes[tid][lvl]); + OpBuilder &builder, Location loc, SparseIterator &iter, + MutableArrayRef reduc, bool isParallel) { + // TODO: support dynamic slices. // Uses the first dimension here to build the loop bound (which is also the // biggest range). + Value step = C_IDX(1); + auto [lo, hi] = iter.genForCond(builder, loc); Operation *loop = nullptr; Value iv; if (isParallel) { @@ -703,255 +443,30 @@ std::pair LoopEmitter::emitForLoopOverTensorAtLvl( } assert(loop && iv); - Value crd; - if (isSparseCond) { - // For COO, the position is the same across consecutive levels. - /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. - posits[tid][lvl] = iv; - crd = genSparseCrd(builder, loc, tid, lvl); + Value crd = iv; + if (!iter.randomAccessible()) { + iter.linkNewScope(iv); + crd = iter.deref(builder, loc); } else { - // Dense tensor, the coordinate is the inducation variable. - crd = iv; - } - - if (isSparseSlices[tid] && isSparseCond) { - // For sparse level slices, we need to filter out invalid coordinates that - // are not included in the slice. - SmallVector types; - for (Value red : reduc) - types.push_back(red.getType()); - - auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, lvl); - bool hasReduc = !types.empty(); - scf::IfOp ifOp = builder.create(loc, types, pred, - /*else*/ hasReduc); - if (hasReduc) { - // scf.for (a) -> v - // %s = scf.if (a) -> v - // user-generated code. - // else - // yield a - // yield %s - YIELD(ifOp.getResults()); - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - // On mismatch. - YIELD(reduc); - } - // Set the insertion point to matched branch. - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - crd = trans; + iter.locate(builder, loc, iv); } - assert(crd); - coords[tid][lvl] = crd; return {loop, crd}; } -Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc, - ValueRange ivs, TensorLvlCond cond) { - auto [tid, lvl] = unpackTensorLevel(cond.first); - - switch (cond.second) { - case LoopCondKind::SparseCond: { - assert(ivs.size() == 1); - // We used the first level bound as the bound the collapsed set of levels. - return CMPI(ult, ivs.back(), highs[tid][lvl]); - } - case LoopCondKind::SparseSliceCond: { - assert(ivs.size() == 1); - return CMPI(ult, ivs.back(), highs[tid][lvl]); - } - case LoopCondKind::SparseAffineCond: { - assert(ivs.size() == 1); - - Value crdHi; // loop upper bound - { - OpBuilder::InsertionGuard guard(builder); - Operation *loop = builder.getInsertionBlock()->getParentOp(); - // crdHi is a loop invariant, hosit the computation outside the loop. - if (llvm::isa_and_nonnull(loop)) - builder.setInsertionPoint(loop); - auto [remSz, stride] = sliceMeta[tid][lvl].back(); - assert(stride == 1 && "Not yet implemented"); - crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, remSz); - } - assert(crdHi); - return genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl], crdHi, - ivs[0], highs[tid][lvl]); - } - case LoopCondKind::SparseAffineUnRedCond: { - assert(ivs.size() == 3); - return ivs.front(); // isNonEmpty - } - default: - llvm_unreachable("Unhandled LoopCondKind"); - } - llvm_unreachable("Unhandled LoopCondKind"); -} - -std::optional LoopEmitter::genWhileLoopBody(OpBuilder &builder, - Location loc, ValueRange ivs, - TensorLvlCond cond) { - auto [tid, lvl] = unpackTensorLevel(cond.first); - - switch (cond.second) { - case LoopCondKind::SparseCond: { - // Updates position. For collapsed COO, the position is the same across - // consecutive levels. - posits[tid][lvl] = ivs.back(); - - // Update coordinates. - coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl); - return std::nullopt; - } - case LoopCondKind::SparseSliceCond: { - assert(ivs.size() == 1); - posits[tid][lvl] = ivs.front(); - Value sCrd = genSparseCrd(builder, loc, tid, lvl); - // Converts the coordinate loaded from the actual sparse tensor to the - // coordinates in the sparse slice. - auto [dCrd, pred] = genSliceLegitPredicate(builder, loc, sCrd, tid, lvl); - coords[tid][lvl] = dCrd; - return pred; - } - case LoopCondKind::SparseAffineCond: { - assert(ivs.size() == 1); - // Coord is the relative offset related to its parents. - assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement"); - sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]); - // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1] - Value posit = ivs[0]; - // We need to substract the offset to get relative coordinates. - // TODO: Maybe assert relC >=0 during runtime in debug build? - Value absC = lvls[tid][lvl]->peekCrdAt(builder, loc, posit); - auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset); - posits[tid][lvl] = posit; - coords[tid][lvl] = relC; - return std::nullopt; - } - case LoopCondKind::SparseAffineUnRedCond: { - unsigned depth = sliceStack[tid].back().depth; - unsigned curStride = sliceMeta[tid][lvl][depth - 1].second; - assert(ivs.size() == 3); - - // Updates the current slice info - SliceInfo &sliceInfo = sliceStack[tid].back(); - sliceInfo.isNonEmpty = ivs[0]; - sliceInfo.minCrd = ivs[1]; - sliceInfo.offset = ivs[2]; - - // Crd (the value we used to coiterate) is the relative offset related to - // its parents, we can use the absolute offset here because when depth = 1, - // absOffset[lvl][depth - 1] always equals zero. - // TODO: Update crd =absOffset[lvl][depth] - absOffset[lvl][depth - 1] - assert(depth == 1 && "TODO: not yet implement"); - Value crd = sliceInfo.offset; - - Value onStride = constantI1(builder, loc, true); - if (curStride != 1) { - Value strideVal = C_IDX(curStride); - Value rem = REMUI(crd, strideVal); - crd = DIVUI(crd, strideVal); - onStride = CMPI(eq, rem, C_IDX(0)); - } - coords[tid][lvl] = crd; - // No extra check is needed before accessing the tensor level. - return onStride; - } - default: - llvm_unreachable("Unhandled LoopCondKind"); - } - llvm_unreachable("Unhandled LoopCondKind"); -} - -ValueRange LoopEmitter::genCheckedValue(OpBuilder &builder, Location loc, - Value pred, ValueRange curArgs, - TensorLvlCond cond) { - assert(isSparseCond(cond.second)); - auto [tid, lvl] = unpackTensorLevel(cond.first); - if (isAffineIdxUnRedCond(cond.second)) { - unsigned depth = sliceStack[tid].back().depth; - unsigned curStride = sliceMeta[tid][lvl][depth - 1].second; - if (curStride == 1) - return curArgs; - // Build - // if (onStride) { - // yield curSlice - // } else { - // yield nxSlice. - //} - assert(curArgs.size() == 3); - auto ifOp = builder.create(loc, curArgs.getTypes(), pred, true); - { - OpBuilder::InsertionGuard guard(builder); - // If not all slices are legit, yield the updated value. - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - YIELD(curArgs); - // If not all slices are legit, yield the updated value. - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - auto [nonEmpty, minCrd, offset] = - genSliceNextInduction(builder, loc, tid, lvl); - SmallVector nxSlice{nonEmpty, minCrd, offset}; - YIELD(nxSlice); - } - // If all slices are legit, start the user generated code. - return ifOp.getResults(); - } else { - // Currently only sparse slice condition need extra check. - assert(isSliceCond(cond.second) && isSparseCond(cond.second)); - assert(curArgs.size() == 1); - Value nextPos = ADDI(curArgs.front(), C_IDX(1)); - return SELECT(pred, curArgs.front(), nextPos)->getResults(); - } - llvm_unreachable("unhandled case"); -} - std::pair LoopEmitter::emitWhileLoopOverTensorsAtLvls( - OpBuilder &builder, Location loc, ArrayRef spConds, + OpBuilder &builder, Location loc, ArrayRef spIters, MutableArrayRef reduc, bool needsUniv) { // NOTE: the slice driven tensor-related reduction variable must // appear before normal tensors. - assert(!spConds.empty()); // The set of induction variables for the while loop. SmallVector ivs; - // Segment sizes for induction variables used for different kinds of loop - // conditions. - SmallVector opSegSize; // Construct the while-loop with a parameter for each coordinate. - for (auto [tl, cKind] : spConds) { - auto [tid, lvl] = unpackTensorLevel(tl); - const auto lvlTp = lvlTypes[tid][lvl]; - // Dense level are handled by the shared univeral index. - assert(!isDenseCond(cKind)); - // Must be a recognizable sparse level. - assert(isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp) || - isSingletonLT(lvlTp)); - (void)lvlTp; - - unsigned prevSz = ivs.size(); - if (isAffineIdxCond(cKind)) { - // TODO: Support view-based reshape on sparse levels with affine index - // expressions. - if (isAffineIdxUnRedCond(cKind)) { - SliceInfo &sliceInfo = sliceStack[tid].back(); - // The order matters! - ivs.push_back(sliceInfo.isNonEmpty); - ivs.push_back(sliceInfo.minCrd); - ivs.push_back(sliceInfo.offset); - } else { - ivs.push_back(posits[tid][lvl]); // loop lower bound (pos low). - } - // We reduced one more dependency after entering the loop. - levelReducedDep[tid][lvl]++; - } else { - assert(dependentLvlMap[tid][lvl].empty()); - const Value pos = posits[tid][lvl]; - ivs.push_back(pos); - } - opSegSize.push_back(ivs.size() - prevSz); + for (SparseIterator *it : spIters) { + ValueRange itVals = it->getItVals(); + ivs.append(itVals.begin(), itVals.end()); } // The position where user-supplied reduction variable starts. @@ -973,10 +488,11 @@ std::pair LoopEmitter::emitWhileLoopOverTensorsAtLvls( builder.setInsertionPointToStart(before); ValueRange bArgs = before->getArguments(); Value whileCond = nullptr; // bool values for loop condition. - for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) { - Value cv = genWhileLoopConditions(builder, loc, bArgs.take_front(segSz), c); - bArgs = bArgs.drop_front(segSz); - whileCond = !whileCond ? cv : ANDI(whileCond, cv); + + for (SparseIterator *it : spIters) { + auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs); + whileCond = !whileCond ? cond : ANDI(whileCond, cond); + bArgs = remArgs; } // The remaining block arguments are user-provided reduction values and an // optional universal index. Make sure their sizes match. @@ -990,49 +506,11 @@ std::pair LoopEmitter::emitWhileLoopOverTensorsAtLvls( // iterations, we maintains another array to hold the iteration arguments to // yield if the checks fails. SmallVector nextArgs(aArgs.begin(), aArgs.end()); - // A mutable alias for convenient slicing. - MutableArrayRef nextArgsRef = nextArgs; - Value extraPred = nullptr; - for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) { - ValueRange condArgs = aArgs.take_front(segSz); - auto pred = genWhileLoopBody(builder, loc, condArgs, c); - assert(pred.has_value() == isCondWithExtraCheck(c.second)); - if (pred.has_value()) { - // We need all extra checks to pass. - extraPred = extraPred == nullptr ? *pred : ANDI(*pred, extraPred); - ValueRange nxArgs = genCheckedValue(builder, loc, *pred, condArgs, c); - assert(nxArgs.size() == segSz); - // Update the value for cases when some check fails. - for (unsigned i = 0; i < segSz; i++) { - nextArgsRef[i] = nxArgs[i]; - } - } - aArgs = aArgs.drop_front(segSz); - nextArgsRef = nextArgsRef.drop_front(segSz); - } - - if (extraPred) { - auto ifOp = builder.create(loc, types, extraPred, /*else*/ true); - // Marks this special IfOp so that Sparsification does not finalizing it. - ifOp->setAttr(getLoopEmitterLoopAttrName(), - StringAttr::get(builder.getContext(), "slice")); - // Links the SSA chain outside the if statement. - YIELD(ifOp->getResults()); - // If not all slices are legit, yield the updated value. - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - YIELD(nextArgs); - - // If all slices are legit, start the user generated code. - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - } - - for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) { - // Generates segment high for non-unique level. - if (!isUniqueLT(lvlTypes[tid][lvl])) { - segHi[tid][lvl] = genSegmentHigh(builder, loc, tid, lvl, posits[tid][lvl], - highs[tid][lvl]); - } + for (SparseIterator *it : spIters) { + aArgs = it->linkNewScope(aArgs); + // Dereference the iterator to cache the coordinate. + it->deref(builder, loc); } // In-place update on reduction variable. @@ -1043,21 +521,15 @@ std::pair LoopEmitter::emitWhileLoopOverTensorsAtLvls( Value min; // Finds the minimum coordinate if (!needsUniv) { - for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) { - const auto lvlTp = lvlTypes[tid][lvl]; - if (isCompressedLT(lvlTp) || isSingletonLT(lvlTp) || - isLooseCompressedLT(lvlTp)) { - const auto crd = coords[tid][lvl]; - if (min) { - Value cmp = CMPI(ult, coords[tid][lvl], min); - min = SELECT(cmp, coords[tid][lvl], min); - } else { - min = crd; - } + for (SparseIterator *it : spIters) { + if (min) { + Value cmp = CMPI(ult, it->getCrd(), min); + min = SELECT(cmp, it->getCrd(), min); + } else { + min = it->getCrd(); } } } else { - assert(!min); // Otherwise, universal index is the minimal pos. min = whileOp.getAfterArguments().back(); } @@ -1065,307 +537,108 @@ std::pair LoopEmitter::emitWhileLoopOverTensorsAtLvls( return {whileOp, min}; } -bool LoopEmitter::shouldIteratedByForLoop(ArrayRef sparseConds, - bool genDedup) { - assert(llvm::all_of(sparseConds, - [](TensorLvlCond c) { return isSparseCond(c.second); })); - +bool LoopEmitter::shouldIteratedByForLoop(ArrayRef spIters) { // If we need to co-iterate over two sparse tensors, we need a while loop - if (sparseConds.size() > 1) + if (spIters.size() > 1) return false; - // We also need a while loop for levels with affine index expression and - // non-unique levels when deduplication is required. - if (sparseConds.size() == 1) { - auto [tid, lvl] = unpackTensorLevel(sparseConds.back().first); - return !isAffineIdxCond(sparseConds.back().second) && - !(genDedup && !isUniqueLT(lvlTypes[tid][lvl])); - } + if (spIters.size() == 1) + return spIters.front()->iteratableByFor(); return true; } Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( OpBuilder &builder, Location loc, ArrayRef tidLvls, - MutableArrayRef reduc, bool tryParallel, bool genDedup, - bool needsUniv) { -#ifndef NDEBUG - // Sanity checks. - assert(!tidLvls.empty()); - for (auto [t, l] : unpackTensorLevelRange(tidLvls)) { - assert(!coords[t][l] || // We cannot re-enter the same level - !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop - } -#endif + MutableArrayRef reduc, bool tryParallel, bool needsUniv) { + // TODO: support multiple return on parallel for? tryParallel = tryParallel && reduc.size() <= 1; - SmallVector spConds; - SmallVector dnConds; - categorizeLoopCondition(tidLvls, dnConds, spConds); + SmallVector raIters; + SmallVector spIters; + categorizeIterators(tidLvls, raIters, spIters); // Only when there is at least one sparse conditions, do we really need the // universal index. // TODO: Maybe we should instead requires merger to pass in a valid value at // the first place instead of adjusting it in LoopEmitter? - needsUniv = !spConds.empty() && needsUniv; + needsUniv = !spIters.empty() && needsUniv; // The TensorLevel used for loop conditions. // If there is any sparse level, we need to use the sparse condition. // If all levels are dense, we can pick arbitrary one (dense slice-driven loop // can be generated using a simple ForOp as well). Operation *l = nullptr; Value iv = nullptr; - SmallVector sliceDrivenInfo; - SmallVector trivialLvls; + SmallVector tls; // Generates loops differently depending on whether we need a slice-driven // loop or a simple level traversal loop. - if (shouldIteratedByForLoop(spConds, genDedup) && !needsUniv) { - assert(spConds.size() <= 1); - TensorLvlCond tlCond = spConds.empty() ? dnConds.front() : spConds.front(); - auto loopCondKind = tlCond.second; - auto [tid, lvl] = unpackTensorLevel(tlCond.first); - Value lo = isSparseCond(loopCondKind) - ? posits[tid][lvl] // current offset - : loopSeqStack.back().first; // universal index - Value hi = highs[tid][lvl]; - if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) { - bool unReduc = isAffineIdxUnRedCond(loopCondKind); - assert(unReduc == !depFullyReduced(tid, lvl)); - unsigned depth = sliceStack[tid].back().depth; - assert(depth >= 1); - // The *next* slice size after reducing the current index variable. - auto [nxSz, nxStride] = sliceMeta[tid][lvl][depth]; - // The *current* stride to reduce the current index variable. - // E.g., for 2 * i, stride = 2. - unsigned stride = sliceMeta[tid][lvl][depth - 1].second; - hi = nxSz; - if (unReduc) { - // Adjust for loop hi for dense slice-driven loop. - hi = SUBI(lvlSizes[tid][lvl], hi); - hi = ADDI(hi, C_IDX(1)); - hi = DIVUI(hi, C_IDX(stride)); - } else { - // TODO: dialuted convolution. - assert(nxStride == 1 && "Not yet implemented."); - } - } - std::tie(l, iv) = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, - reduc, tryParallel); - // For loop condition must be a trivial condition (levels without affine - // index expression). - trivialLvls.push_back(tlCond.first); + if (shouldIteratedByForLoop(spIters) && !needsUniv) { + assert(spIters.size() <= 1); + SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front(); + std::tie(l, iv) = + emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel); + tls.push_back(makeTensorLevel(it.tid, it.lvl)); } else { - for (auto [tl, cKind] : spConds) { - if (isAffineIdxCond(cKind)) { - auto [tid, lvl] = unpackTensorLevel(tl); - bool unReduc = isAffineIdxUnRedCond(cKind); - assert(unReduc == !depFullyReduced(tid, lvl)); - sliceDrivenInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc); - } else { - trivialLvls.push_back(tl); - } + for (auto *it : spIters) { + tls.push_back(makeTensorLevel(it->tid, it->lvl)); } + if (needsUniv) + for (auto *it : raIters) + tls.push_back(makeTensorLevel(it->tid, it->lvl)); + std::tie(l, iv) = - emitWhileLoopOverTensorsAtLvls(builder, loc, spConds, reduc, needsUniv); + emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv); } // Enter dense tensor levels. - enterTensorsAtDenseLvls(builder, loc, dnConds, iv, sliceDrivenInfo); - // NOTE: we can also prepare for next dim here in advance + for (SparseIterator *it : raIters) + it->locate(builder, loc, iv); + // NOTE: we can also prepare for next dim here in advance // Pushes the loop into stack. - loopStack.emplace_back(trivialLvls, sliceDrivenInfo, l, - builder.getInsertionBlock(), iv, loopTag); + loopStack.emplace_back(tidLvls, l, builder.getInsertionBlock(), iv, loopTag); return l; } -Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl( - OpBuilder &builder, Location loc, TensorId tid, Level lvl, - AffineExpr affine, MutableArrayRef reduc) { - assert(isValidLevel(tid, lvl)); - assert(!isa(affine) && !isDenseLT(lvlTypes[tid][lvl])); - // We can not re-enter the same level. - assert(!coords[tid][lvl]); - - // TODO: We should instead use a whileOp for filter loop to allow early - // break when exceeding (for ordered levels). - // TODO: There are many other potiential opportunities that we might apply in - // the future. E.g., we could use binary search to locate positions. - const Value step = C_IDX(1); - const Value pLo = posits[tid][lvl]; - const Value pHi = highs[tid][lvl]; - scf::ForOp forOp = builder.create(loc, pLo, pHi, step, reduc); - - // In-place update on the reduction variable vector. - assert(forOp.getNumRegionIterArgs() == reduc.size()); - for (int i = 0, e = reduc.size(); i < e; i++) - reduc[i] = forOp.getRegionIterArg(i); - - builder.setInsertionPointToStart(forOp.getBody()); - // The induction variable gives the position. - const Value pos = forOp.getInductionVar(); - posits[tid][lvl] = pos; - const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos); - coords[tid][lvl] = crd; - - // Generate an if-condition to filter out coordinates that are not - // equal to the result of the affine expression. - Value expected = genAffine(builder, loc, affine); - auto pred = CMPI(eq, crd, expected); - SmallVector types; - for (Value red : reduc) { - types.push_back(red.getType()); - } +void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc, + TensorLevel tidLvl, + AffineExpr lvlExpr) { + auto [tid, lvl] = unpackTensorLevel(tidLvl); - bool hasReduc = !types.empty(); - scf::IfOp ifOp = - builder.create(loc, types, pred, /*else*/ hasReduc); - if (hasReduc) { - // scf.for (a) -> v - // %s = scf.if (a) -> v - // user-generated code. - // else - // yield a - // yield %s - YIELD(ifOp.getResults()); - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - // On mismatch. - YIELD(reduc); - } - // Set the insert point to matched branch. - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // NOTE: we can also prepare for next lvl here in advance - // Push the loop into stack - loopStack.emplace_back(ArrayRef(makeTensorLevel(tid, lvl)), - ArrayRef(), forOp, - builder.getInsertionBlock(), coords[tid][lvl], - nullptr); - return forOp; -} + const SparseIterator *parent = + lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get(); + auto &it = getCurIterator(tid, lvl); + it.genInit(builder, loc, parent); -void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc, - TensorLevel tidLvl, - AffineExpr lvlExpr) { - auto [tid, lvl] = unpackTensorLevel(tidLvl); - assert(isDenseLT(lvlTypes[tid][lvl])); - // For dense levels, the vel-coordinate also serves as the position. + assert(it.kind == IterKind::kTrivial && it.randomAccessible()); Value lvlCrd = genAffine(builder, loc, lvlExpr); - posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd); + it.locate(builder, loc, lvlCrd); } void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid, Level lvl) { - assert(isValidLevel(tid, lvl)); - const auto lvlTp = lvlTypes[tid][lvl]; - - if (isDenseLT(lvlTp)) - return; - - const Value c0 = C_IDX(0); - const Value c1 = C_IDX(1); - // Either the first level, or the previous level has been set. - /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. - assert(lvl == 0 || posits[tid][lvl - 1]); - if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp) || - is2OutOf4LT(lvlTp)) { - - Value pos = lvl == 0 ? c0 : posits[tid][lvl - 1]; - std::tie(posits[tid][lvl], highs[tid][lvl]) = - lvls[tid][lvl]->peekRangeAt(builder, loc, pos); - return; - } - if (isSingletonLT(lvlTp)) { - // TODO: merge this as well when SparseTensorLevel support dedup. - const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1]; - posits[tid][lvl] = pLo; - - // If we are coiterating non-unique levels, then use pHi=segHi; - // otherwise use pHi=pLo+1. - // NOTE: Just because the level is non-unique, that does not - // guarantee that segHi is defined: because we only generate segHi - // whenever coiterating, in order to improve code quality for the - // non-coiterating cases. - const auto parentSegHi = segHi[tid][lvl - 1]; - highs[tid][lvl] = (!isUniqueLT(lvlTypes[tid][lvl - 1]) && parentSegHi) - ? parentSegHi - : ADDI(pLo, c1); - return; - } - llvm_unreachable("Unrecognized level-type!"); -} + // if this is the first level, there is no parent iterator for the current + // iterator. + // If the current iterator is a subsection-based iterator, the parent iterator + // is memorized by the iterator. + bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty(); -void LoopEmitter::enterTensorsAtDenseLvls( - OpBuilder &builder, Location loc, ArrayRef dnConds, Value iv, - SmallVectorImpl &sliceInfo) { - for (auto [dnTidLvl, denseLoopCond] : dnConds) { - auto [tid, lvl] = unpackTensorLevel(dnTidLvl); - assert(isDenseLT(lvlTypes[tid][lvl])); - - if (isAffineIdxCond(denseLoopCond)) { - // Pushes sliced levels to build correct LoopInfo. - bool unReduc = isAffineIdxUnRedCond(denseLoopCond); - SliceInfo &info = sliceStack[tid].back(); - // Pushes sliced dense loop info to tell LoopEmitter how to exit it. - sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc); - // FIXME: The offset and position iterator need to be adjusted when the - // slice is strided. - if (unReduc) { - assert(*info.slicedOnLvl == lvl); - unsigned depth = sliceStack[tid].back().depth; - assert(depth >= 1); - unsigned stride = sliceMeta[tid][lvl][depth - 1].second; - // Update the slice information as we enter the new loop. - info.minCrd = info.offset = MULI(iv, C_IDX(stride)); - info.isNonEmpty = constantI1(builder, loc, true); - } else { - posits[tid][lvl] = - genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv)); - Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl] - ? C_IDX(0) - : sliceTupleFwdCnt[tid][lvl - 1]; - Value sz = sliceMeta[tid][lvl].back().first; - Value mul = MULI(fwdCnt, sz); - sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv); - } - levelReducedDep[tid][lvl]++; - } else { - // Skips the synthetic tensor - if (isSynTensor(tid)) - continue; - // A dense level with trivial index expression. - assert(dependentLvlMap[tid][lvl].empty()); - auto enc = getSparseTensorEncoding(tensors[tid].getType()); - if (enc && !isSparseOutput(tid)) { - bool validPos = lvl == 0 || posits[tid][lvl - 1]; - if (!validPos) { - // We might not find the pos for the sparse output tensor as it is - // unconditionally required by the sparsification. - assert(isOutputTensor(tid)); - continue; - } - posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv); - // NOTE: we can also prepare for next lvl here in advance - } - } - } + const SparseIterator *parent = + hasParent ? nullptr : iters[tid][lvl - 1].back().get(); + auto &it = getCurIterator(tid, lvl); + it.genInit(builder, loc, parent); + + // Locates the randon accessible iterator to 0. + if (it.randomAccessible()) + it.locate(builder, loc, C_IDX(0)); } void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, MutableArrayRef reduc) { const LoopInfo &loopInfo = loopStack.back(); - for (auto [tid, lvl, reduced] : loopInfo.sliceDrivenInfo) { - if (!reduced) { - SliceInfo &info = sliceStack[tid].back(); - assert(isDenseLT(lvlTypes[tid][lvl])); - assert(*info.slicedOnLvl == lvl); - (void)reduced; - info.minCrd = info.offset = info.isNonEmpty = Value(); - } - levelReducedDep[tid][lvl]--; - } if (auto forOp = llvm::dyn_cast(loopInfo.loop)) { if (!reduc.empty()) { assert(reduc.size() == forOp.getNumResults()); @@ -1428,18 +701,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++) reduc[i] = parOp.getResult(i); } - - // Finished iterating a tensor, clean up - // We only do the clean up on for loop as while loops do not necessarily - // finish the iteration on a sparse tensor - for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) { - // Reset to null. - coords[tid][lvl] = Value(); - posits[tid][lvl] = Value(); - // Dense level, high is fixed. - if (!isDenseLT(lvlTypes[tid][lvl])) - highs[tid][lvl] = Value(); - } } void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc, @@ -1454,98 +715,45 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc, // However, that would result in a rather elaborate forest of yield // instructions during code generation. Moreover, performing the induction // after the if-statements more closely resembles code generated by TACO. - unsigned o = 0; SmallVector operands; - unsigned delta = 0; - for (auto [tid, lvl, resolved] : loopInfo.sliceDrivenInfo) { - // TODO: handle dense. - assert(isCompressedLT(lvlTypes[tid][lvl])); - levelReducedDep[tid][lvl]--; - if (!resolved) { - // TODO: support coiterating multiple slices - assert(loopInfo.sliceDrivenInfo.size() == 1); - auto [nxNonEmpty, nxMinCrd, nxAbsOffset] = - genSliceNextInduction(builder, loc, tid, lvl); - // Update while loop induction operands. - operands.push_back(nxNonEmpty); - operands.push_back(nxMinCrd); - operands.push_back(nxAbsOffset); - - // Update the slice stack. - SliceInfo &info = sliceStack[tid].back(); - info.isNonEmpty = whileOp.getResult(o++); - info.minCrd = whileOp.getResult(o++); - info.offset = whileOp.getResult(o++); - continue; - } - - Value forwarded = nullptr; - if (loopInfo.trivialTidLvls.empty() && - loopInfo.sliceDrivenInfo.size() == 1) { - // Forwards the position iterator. - operands.push_back(ADDI(posits[tid][lvl], one)); - forwarded = constantI1(builder, loc, true); - } else { - const Value pos = posits[tid][lvl]; - const Value nxPos = ADDI(posits[tid][lvl], one); - forwarded = CMPI(eq, coords[tid][lvl], iv); - operands.push_back(SELECT(forwarded, nxPos, pos)); - } - // The coordinate is invalid now. - coords[tid][lvl] = nullptr; - - // Update the position iterator as we exit the while loop. - posits[tid][lvl] = whileOp->getResult(o++); - }; - - for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) { - const auto lvlTp = lvlTypes[tid][lvl]; - if (isCompressedLT(lvlTp) || isSingletonLT(lvlTp) || - isLooseCompressedLT(lvlTp)) { - const Value crd = coords[tid][lvl]; - const Value pos = posits[tid][lvl]; - Value cmp = CMPI(eq, crd, iv); - // If the loop contains a coiteration with non-unique level, we fast - // forward all the duplicated coords by setting the position to the - // segment high. - Value add = - !isUniqueLT(lvlTypes[tid][lvl]) ? segHi[tid][lvl] : ADDI(pos, one); - - operands.push_back(SELECT(cmp, add, pos)); + ValueRange whileRes = whileOp.getResults(); + + for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) { + SparseIterator &it = getCurIterator(tid, lvl); + if (!it.randomAccessible()) { + // Forward the sparse iterator. + Value cmp = CMPI(eq, it.getCrd(), iv); + it.forwardIf(builder, loc, cmp); + operands.append(it.getItVals().begin(), it.getItVals().end()); + // const Value newPos = whileOp->getResult(o++); // Following loops continue iteration from the break point of the // current while loop. - const Value newPos = whileOp->getResult(o++); - // We need to define a new local variable for `tid` to avoid - // warnings about "captured structured bindings are a C++20 extension". - // FIXME(wrengr): define a helper function to capture this idiom! - const TensorId newTid = tid; - posits[newTid][lvl] = newPos; - - // The coordinate is invalid now. - coords[tid][lvl] = nullptr; - // The segment high is invalid now. - segHi[tid][lvl] = nullptr; - // highs remains unchanged. + whileRes = it.linkNewScope(whileRes); + } else { + // Make sure randomly accessible (dense) iterator is set to the right + // position according to the universal index. + Value uniIdx = whileOp.getResults().back(); + it.locate(builder, loc, uniIdx); } } // Reduction value from users. for (auto &i : reduc) { operands.push_back(i); - // In place update reduction variable. - i = whileOp->getResult(o++); + // Update user reduction variables. + i = whileRes.front(); + whileRes = whileRes.drop_front(); } // An (optional) universal index. - if (operands.size() + delta < whileOp.getNumResults()) { - assert(operands.size() + delta + 1 == whileOp.getNumResults()); + if (operands.size() < whileOp.getNumResults()) { + assert(operands.size() + 1 == whileOp.getNumResults()); // The last one is the universial index. operands.push_back(ADDI(iv, one)); // update the loop starting point of current loop sequence - loopSeqStack.back().first = whileOp->getResult(o++); + loopSeqStack.back().first = whileOp->getResults().back(); } - assert(o == operands.size() + delta); if (!operands.empty()) YIELD(operands); @@ -1578,651 +786,6 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc, loopStack.pop_back(); } -//===----------------------------------------------------------------------===// -// Slice-driven loop related methods. -//===----------------------------------------------------------------------===// - -unsigned LoopEmitter::remDepOnLevel(TensorId tid, Level lvl) const { - unsigned totalDependencies = dependentLvlMap[tid][lvl].size(); - if (totalDependencies != 0) { - assert(totalDependencies >= 2); - return totalDependencies - levelReducedDep[tid][lvl]; - } - return totalDependencies; -} - -const LoopEmitter::SliceInfo &LoopEmitter::getMostRecentSliceOnLvl(TensorId tid, - Level lvl) { - // Finds the most-recent slice using a reverse iteration. - for (auto it = sliceStack[tid].rbegin(), ie = sliceStack[tid].rend(); it < ie; - it++) { - if (it->slicedOnLvl == lvl) { // the level matched - return *it; - } - } - llvm_unreachable("Failed to find sliceInfo"); -} - -// Generates a while loop to iterate over a slice sparse level as follows. -// -// while(coords[loopLo] < offset + size) { -// body_builder -// loopLo ++; -// } -std::pair LoopEmitter::genSliceLvlTraverseLoop( - OpBuilder &builder, Location loc, Value posLo, Value posHi, Value offset, - Value size, TensorId tid, Level lvl, ValueRange userReduc, - LoopBodyBuilder bodyBuilder) { - Value c1 = C_IDX(1); - auto [sliceSz, stride] = sliceMeta[tid][lvl].back(); - assert(stride == 1 && "Not yet implemented"); - Value sliceHi = ADDI(offset, sliceSz); - - SmallVector reduc{posLo}; // loop lower bounds - const unsigned numMetaReduc = reduc.size(); - - // Append user required reduction value. - reduc.append(userReduc.begin(), userReduc.end()); - scf::WhileOp whileOp = builder.create( - loc, ValueRange(reduc).getTypes(), reduc, - /*beforeBuilder=*/ - [this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc, - ValueRange args) { - Value cond = genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl], - sliceHi, args[0], posHi); - // continue if not yet break nor out of bound. - builder.create(loc, cond, args); - }, - /*afterBuilder=*/ - [c1, numMetaReduc, bodyBuilder](OpBuilder &builder, Location loc, - ValueRange args) { - Value iv = args[0]; - TypeRange types = args.drop_front(numMetaReduc).getTypes(); - // The coordinate must be in bound as guaranteed by the loop - // condition. We generate a fake if operation here only to hide the - // extra loop induction variables maintained by us from users, which - // will be removed by later optimization pass. - auto ifOp = builder.create(loc, types, - constantI1(builder, loc, true), - /*withElseBlock=*/!types.empty()); - { - // 2 reduction variable maintained by us. - SmallVector ifRet = args.drop_front(numMetaReduc); - assert(ifRet.size() == args.size() - 1); - - OpBuilder::InsertionGuard guard(builder); - // If coord >= sliceHi. - if (!ifRet.empty()) { - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - YIELD(ifRet); - } - - // If coord < sliceHi. - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - // Delegates to users' callback. - bodyBuilder(builder, loc, iv, ifRet); - } - // Marks this special ifOp to avoid sparisification finalizing it. - ifOp->setAttr(getLoopEmitterLoopAttrName(), - StringAttr::get(builder.getContext(), "slice")); - // Insertion point restored to after ifOp. - SmallVector yields; - // Increase induction variable. - yields.push_back(ADDI(iv, c1)); - yields.append(ifOp.getResults().begin(), ifOp.getResults().end()); - YIELD(yields); - }); - - builder.setInsertionPointAfter(whileOp); - return std::make_pair(whileOp, whileOp.getResults().drop_front(numMetaReduc)); -} - -// Generates a loop nest that traverse all the unresolved levels in between. -// -// for(int i = 0; i < slicePos.size(); i+=2) { -// loopLo = slicePos[i]; -// loopHi = slicePos[i + 1]; -// -// // Then the same loop generated by genSliceLvlTraverse above. -// while (loopLo < loopHI) { -// if (pos[loopLo] < sliceHi) { -// bodyBuilder(); -// } else { -// break; -// } -// loopLo ++; -// } -// } -ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse( - OpBuilder &builder, Location loc, TensorId tid, - ArrayRef unResLvls, - std::optional> firstResLvl, ValueRange userReduc, - LoopBodyBuilder bodyBuilder) { - - Value c0 = C_IDX(0), c1 = C_IDX(1); - Value pos = c0; - OpBuilder::InsertPoint ip; - SmallVector innerArgs(userReduc.begin(), userReduc.end()); - scf::ForOp outerMost = nullptr; // the outermost loop. - - // Wraps body builder and inserts a extra counting instruction at the end. - auto wrapped = [bodyBuilder](OpBuilder &builder, Location loc, Value iv, - MutableArrayRef reduc) { - bodyBuilder(builder, loc, iv, reduc.drop_back()); - // Increments the counter. - reduc.back() = ADDI(reduc.back(), C_IDX(1)); - }; - - // FIXME: Need special handling when the previous unresolved slice is strided: - // We probably need to filter out coordinates that is not on stride. - if (firstResLvl.has_value()) { - // Overwrite position when the first level is fully resolved. - pos = posits[firstResLvl->first][firstResLvl->second]; - ip = builder.saveInsertionPoint(); - } else { - const SliceInfo &frontSlice = *unResLvls.back(); - Level firstLvl = *frontSlice.slicedOnLvl; - if (!lvlFullyResolved(tid, firstLvl)) { - if (isCompressedLT(lvlTypes[tid][firstLvl])) { - // An extra counter that tracks how many segments are there in the child - // compressed level. - innerArgs.push_back(c0); - // Overrides the user-provided builder. - bodyBuilder = wrapped; - unsigned depth = frontSlice.depth - 1; - Value offset = frontSlice.offset; - Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth]; - Value mSz = frontSlice.posTupleNum; - outerMost = builder.create( - loc, c0, mSz, c1, innerArgs, - [this, tid, firstLvl, offset, sPtrBuf, &ip, &pos, - &innerArgs](OpBuilder &builder, Location loc, Value iv, - ValueRange iterArgs) { - // generate traversal for each level. - Value loopLo = - loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kLo); - Value loopHi = - loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kHi); - // We need to remember the starting index for next level's - // position, because slice-driven loop breaks the level into - // non-consecutive segments. - updateSlicePos(builder, loc, sPtrBuf, iterArgs.back(), iv, - SlicePosKind::kNext); - - auto [size, stride] = sliceMeta[tid][firstLvl].back(); - assert(stride == 1 && "Not yet implemented"); - ValueRange itArgs = - genSliceLvlTraverseLoop( - builder, loc, loopLo, loopHi, offset, size, tid, firstLvl, - iterArgs, - [&](OpBuilder &builder, Location, Value iv, - MutableArrayRef reduc) { - ip = builder.saveInsertionPoint(); - pos = iv; - innerArgs.assign(reduc.begin(), reduc.end()); - }) - .second; - YIELD(itArgs); - }); - } else if (isDenseLT(lvlTypes[tid][firstLvl])) { - assert(firstLvl == 0); // This must be the first level. - Value lb = frontSlice.offset; - auto [sliceSz, stride] = - sliceMeta[tid][*frontSlice.slicedOnLvl][frontSlice.depth]; - assert(stride == 1 && "Not yet implemented"); - Value ub = ADDI(lb, sliceSz); - outerMost = builder.create( - loc, lb, ub, c1, innerArgs, - [&](OpBuilder &builder, Location loc, Value iv, - ValueRange iterArgs) { - ip = builder.saveInsertionPoint(); - pos = iv; - innerArgs.assign(iterArgs.begin(), iterArgs.end()); - }); - } - // We generated the loop for the first slice above, now remove it. - unResLvls = unResLvls.drop_back(); - } - } - // Reset the insertion point into the loop body. - builder.restoreInsertionPoint(ip); - if (!unResLvls.empty()) { - // Fills in dense slices levels in between. - SmallVector lbs, ubs, steps, lvlSzs; - for (const SliceInfo *slice : llvm::reverse(unResLvls)) { - Level sliceLvl = *slice->slicedOnLvl; - assert(isDenseLT(lvlTypes[tid][sliceLvl])); - Value offset = slice->offset; - auto [sliceSz, stride] = sliceMeta[tid][sliceLvl][slice->depth]; - assert(stride == 1 && "Not yet implemented"); - lbs.push_back(offset); - ubs.push_back(ADDI(offset, sliceSz)); - steps.push_back(c1); - lvlSzs.push_back(lvlSizes[tid][sliceLvl]); - } - auto denseNest = - scf::buildLoopNest(builder, loc, lbs, ubs, steps, innerArgs, - [&innerArgs, &lvlSzs, &pos, bodyBuilder]( - OpBuilder &builder, Location loc, ValueRange ivs, - ValueRange iterArgs) -> scf::ValueVector { - for (auto em : llvm::enumerate(ivs)) { - // Linearizes position: pos = (pos * lvlsize) + - // iv; - pos = MULI(pos, lvlSzs[em.index()]); - pos = ADDI(pos, em.value()); - } - innerArgs.assign(iterArgs.begin(), iterArgs.end()); - // Generates user request loop body. - bodyBuilder(builder, loc, pos, innerArgs); - return innerArgs; - }); - - if (!outerMost) { - // If the outermost loop has not been set, this is the outermost loop. - outerMost = denseNest.loops.front(); - } else { - // Otherwise we need to generate yield operations to link the SSA chain. - YIELD(denseNest.results); - } - } else { - assert(outerMost); - // Generates user request loop body. - bodyBuilder(builder, loc, pos, innerArgs); - YIELD(innerArgs); - } - assert(outerMost); - // Insert after current while operation. - builder.setInsertionPointAfter(outerMost); - return outerMost.getResults(); -} - -void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc, - TensorId tid, Level lvl) { - Value c0 = C_IDX(0), c1 = C_IDX(1); - if (isDenseLT(lvlTypes[tid][lvl])) { - // Dense slice begin is trivial. - sliceStack[tid].emplace_back(/*minCoord=*/c0, /*offset=*/c0, - /*nonEmpty=*/constantI1(builder, loc, true), - c0, lvl, /*depth=*/1); - return; - } - auto [nxSz, stride] = sliceMeta[tid][lvl][1]; - assert(stride == 1 && "Not yet implemented"); - Value sPtrBuf = slicePosBuffer[tid][lvl][0]; - const SparseTensorLevel &stl = *lvls[tid][lvl]; - - Value p = lvl == 0 ? c0 : posits[tid][lvl - 1]; - auto [pLo, pHi] = stl.peekRangeAt(builder, loc, p); - - // Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi] - updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo); - updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi); - // Slice over a resolved parent, we only need one pair of pos hi and lo to - // specify the current slice. - Value tupleNum = c1; - // This is an non empty tensor if pLo < pHi. - Value isNonEmpty = CMPI(ult, pLo, pHi); - // The minimal coord must be at the first on ordered level. - // FIXME: Technically we should load the coord only when the slice is - // nonempty. though we assume that even on empty sparse tensors, a non-empty - // ptr/idx buffer is allocated for each level so it would not cause OOB to - // avoid generating a ifOp here. - Value minCrd = stl.peekCrdAt(builder, loc, pLo); - - // FIXME: We need the relative offset related to the base slice. - Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty); - sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, tupleNum, lvl, - /*depth=*/1); -} - -// Fills in the slicePosBuffer before slice-driven loop begin. -// TODO: it can only handle all compressed tensors. -// -// // Loop generated by `genUnResolvedSliceTreeTraverse` -// for(int i = 0; i < slicePos.size(); i+=2) { -// loopLo = slicePos[i]; -// loopHi = slicePos[i + 1]; -// minCrd = max; -// while (loopLo < loopHi) { -// if (pos[loopLo] < sliceHi) { -// // bodyBuilder -// slicePos[tid].push_back(pos[loopLo]); -// slicePos[tid].push_back(pos[loopLo + 1]); -// minCrd = min(minCrd, crd[pos[loopLo]]); -// } else { -// break; -// } -// loopLo ++; -// } -// } -void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc, - TensorId tid, Level lvl) { - Value c0 = C_IDX(0); - unsigned depth = levelReducedDep[tid][lvl]; - // The remaining slice size after reduction. - Value remSz = sliceMeta[tid][lvl][depth + 1].first; - // Dense slice begin is trivial - if (isDenseLT(lvlTypes[tid][lvl])) { - sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), c0, - lvl, depth + 1); - return; - } - - assert(isCompressedLT(lvlTypes[tid][lvl])); - // Unhandled Cases: - // - // 1st, lvl = prevSlicedLvl, i.e., t[d0 + d1 + d2,...] (more than one - // variable need to be reduced on the same level). - // - // 2nd, lvl > prevSliceLvl + 1, i.e., t[..., d2, d3 + d4] (having a - // simple dim expression in between). - assert(lvl == *sliceStack[tid].back().slicedOnLvl + 1); - - SmallVector unResSlices; - std::optional> firstResLvl; - for (Level curLvl = lvl; curLvl >= 1; curLvl--) { - Level prevLvl = curLvl - 1; - if (lvlFullyResolved(tid, prevLvl)) { - firstResLvl = std::make_pair(tid, prevLvl); - break; - } - unResSlices.push_back(&getMostRecentSliceOnLvl(tid, prevLvl)); - if (!isDenseLT(lvlTypes[tid][prevLvl])) { - break; - } - } - - assert(!unResSlices.empty() && - !lvlFullyResolved(tid, *unResSlices.front()->slicedOnLvl)); - - Value sPtrBuf = slicePosBuffer[tid][lvl].back(); - SmallVector reduc = { - constantI1(builder, loc, false), // isNonEmpty - lvlSizes[tid][lvl], // minCoord - c0, // memSize - }; - - ValueRange result = genUnResolvedSliceTreeTraverse( - builder, loc, tid, unResSlices, firstResLvl, reduc, - [this, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv, - MutableArrayRef reduc) { - Value &nonEmpty = reduc[0]; - Value &minCrd = reduc[1]; - Value &curTupleCnt = reduc[2]; - - const SparseTensorLevel &stl = *lvls[tid][lvl]; - auto [sPLo, sPHi] = stl.peekRangeAt(builder, loc, iv); - - // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is - // one non-empty lvl, the slice is non-empty. - Value lvlNonEmpty = CMPI(ult, sPLo, sPHi); - nonEmpty = builder.create(loc, lvlNonEmpty, nonEmpty); - - // Update the minimum coordinate. - auto ifNonEmpty = builder.create(loc, builder.getIndexType(), - lvlNonEmpty, true); - { - // Generate Code as follows. - // - // if (nonEmpty) { - // minCrd = min(minCrd, crd[pos[pLo]]); - // } - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(ifNonEmpty.thenBlock()); - Value curC = stl.peekCrdAt(builder, loc, sPLo); - Value isSmaller = CMPI(ult, curC, minCrd); - Value newMin = SELECT(isSmaller, curC, minCrd); - YIELD(newMin); - builder.setInsertionPointToStart(ifNonEmpty.elseBlock()); - YIELD(minCrd); - } - minCrd = ifNonEmpty.getResult(0); - updateSlicePos(builder, loc, sPtrBuf, sPLo, curTupleCnt, - SlicePosKind::kLo); - updateSlicePos(builder, loc, sPtrBuf, sPHi, curTupleCnt, - SlicePosKind::kHi); - curTupleCnt = ADDI(curTupleCnt, C_IDX(1)); - }); - - Value isNonEmpty = result[0]; - Value minCrd = result[1]; - // Two metadata [memSize, idx]. - // FIXME: we need the relative offset related to the base slice. - Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty); - sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl, - depth + 1); -} - -bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, - Level lvl) { - Value curLvlIdx = C_IDX(0); - if (depFullyReduced(tid, lvl)) { - if (lvl == 0 || trivialSlice[tid][lvl]) { - sliceTupleNxStartIdx[tid][lvl] = C_IDX(0); - } else { - if (isDenseLT(lvlTypes[tid][lvl])) { - sliceTupleNxStartIdx[tid][lvl] = sliceTupleNxStartIdx[tid][lvl - 1]; - } else { - assert(isCompressedLT(lvlTypes[tid][lvl])); - curLvlIdx = ADDI(sliceTupleNxStartIdx[tid][lvl - 1], - sliceTupleFwdCnt[0][lvl - 1]); - sliceTupleNxStartIdx[tid][lvl] = - loadSlicePos(builder, loc, slicePosBuffer[tid][lvl].back(), - curLvlIdx, SlicePosKind::kNext); - } - } - if (isDenseLT(lvlTypes[tid][lvl])) - return true; - - Value sPosBuf = slicePosBuffer[tid][lvl].back(); - // If constraints on the tensor is fully resolved. We do not need to - // generates slice begin any more, instead we fall back to TACO-based - // algorithm to (co)iterates over the slice. - Value tupleIdx = curLvlIdx; - posits[tid][lvl] = - loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo); - highs[tid][lvl] = - loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kHi); - return true; - } - - // Only when the level is sorted, the next-non-empty slice can be computed - // efficiently. - const LevelType lvlType = lvlTypes[tid][lvl]; - assert(isOrderedLT(lvlType)); - if (isSingletonLT(lvlType)) { - llvm_unreachable("TODO: dense level should be easy to support, while " - "singleton level requires more efforts"); - } - - assert(!dependentLvlMap[tid][lvl].empty()); - assert(!sliceStack[tid].empty()); - - const SliceInfo &sliceInfo = sliceStack[tid].back(); - auto baseEnc = getSparseTensorEncoding(tensors[tid].getType()); - if (baseEnc.isSlice()) - llvm_unreachable("TODO: not yet implemented"); - - if (sliceInfo.isInitialTensor() || - (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) { - // First level or previous level has been full resolved. - trivialSlice[tid][lvl] = true; - genResolvedSliceBegin(builder, loc, tid, lvl); - } else { - // The previous level has not been full resolved. - trivialSlice[tid][lvl] = false; - genUnResolvedSliceBegin(builder, loc, tid, lvl); - } - return false; -} - -std::tuple -LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc, - TensorId tid, Level lvl) { - if (!isCompressedLT(lvlTypes[tid][lvl])) - llvm_unreachable("TODO"); - - // else generate code to compute next non empty slice. - Value c0 = C_IDX(0), c1 = C_IDX(1); - - SliceInfo &info = sliceStack[tid].back(); - assert(info.slicedOnLvl == lvl); - // - // We forward to the next non empty slice by - // if (minCrd > offset) { - // offset += 1 - // } else { - // minCrd = nextMinInSlice(); - // offset = minCrd - size + 1; - // } - // - // if (offset + size > parents.size) - // isNonEmpty = false; - // - Value absOffset = info.offset; - SmallVector reduc = {info.minCrd, info.isNonEmpty, absOffset}; - Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1]; - Value fastPathP = CMPI(ugt, info.minCrd, absOffset); - auto ifOp = builder.create(loc, ValueRange(reduc).getTypes(), - fastPathP, true); - { - OpBuilder::InsertionGuard guard(builder); - // Take the fast path - // if (minCrd > offset) { - // return offset += 1 - // } - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - reduc[2] = ADDI(absOffset, c1); - // Yield offset + 1. - YIELD(reduc); - - // else /*minCrd == offset*/ { - // for (i = 0; i < slicePos.size(); i+=kSliceIterWidth) { - // if (crd[pos[slicePos[i]]] == minCrd) { - // slicePos[i]++; - // } - // minCrd=min(minCrd, crd[pos[slicePos[i]]]); - // } - // offset = minCrd - size + 1; - // } - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - reduc[2] = absOffset; // restore value. - Value mSz = info.posTupleNum; // tuple number. - reduc[0] = lvlSizes[tid][lvl]; // next min coord - reduc[1] = constantI1(builder, loc, false); // isNonEmpty - auto loopArgs = static_cast(reduc).drop_back(); - auto forOp = scf::buildLoopNest( - builder, loc, c0, mSz, c1, loopArgs, - [this, tid, lvl, c1, sPtrBuf, - &info](OpBuilder &builder, Location loc, ValueRange ivs, - ValueRange iterArgs) -> scf::ValueVector { - Value curMinCrd = iterArgs[0]; - Value isNonEmpty = iterArgs[1]; - - Type idxTp = builder.getIndexType(); - Value pLo = loadSlicePos(builder, loc, sPtrBuf, ivs.front(), - SlicePosKind::kLo); - Value pHi = loadSlicePos(builder, loc, sPtrBuf, ivs.front(), - SlicePosKind::kHi); - // - // if (pLo < pHi) // Only loads when inbound. - // coord = load[pLo] - // if coord == minCrd - // pLo += 1 - // - // if (pLo < pHi) - // curMinCrd = min(curMinCrd, load[pLo]) - // - Value pred = CMPI(ult, pLo, pHi); - auto advPLo = builder.create(loc, idxTp, pred, true); - /* if pLo < pHi */ { - builder.setInsertionPointToStart(&advPLo.getThenRegion().front()); - // coord = load[pLo] - Value coord = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo); - Value pred = CMPI(eq, coord, info.minCrd); - auto ifEqual = builder.create(loc, idxTp, pred, true); - /* if coord == minCrd */ { - builder.setInsertionPointToStart( - &ifEqual.getThenRegion().front()); - Value newPlo = ADDI(pLo, c1); - // Updates the cache. - updateSlicePos(builder, loc, sPtrBuf, newPlo, ivs.front(), - SlicePosKind::kLo); - YIELD(newPlo); - } - /* else coord != minCrd */ { - builder.setInsertionPointToStart( - &ifEqual.getElseRegion().front()); - YIELD(pLo); - } - builder.setInsertionPointAfter(ifEqual); - YIELD(ifEqual.getResults()); - } - /* else pLo >= pHi */ { - builder.setInsertionPointToStart(&advPLo.getElseRegion().front()); - YIELD(pLo); - } - - builder.setInsertionPointAfter(advPLo); - pLo = advPLo.getResult(0); - Value lvlNonEmpty = CMPI(ult, pLo, pHi); - // Update minCrds - auto newMin = - builder.create(loc, idxTp, lvlNonEmpty, true); - builder.setInsertionPointToStart(&newMin.getThenRegion().front()); - YIELD(lvls[tid][lvl]->peekCrdAt(builder, loc, pLo)); - - builder.setInsertionPointToStart(&newMin.getElseRegion().front()); - YIELD(curMinCrd); - builder.setInsertionPointAfter(newMin); - - // isNonEmpty = isNonEmpty || lvlNonEmpty - isNonEmpty = - builder.create(loc, lvlNonEmpty, isNonEmpty); - curMinCrd = builder.create( - loc, CMPI(ult, newMin.getResult(0), curMinCrd), - newMin.getResult(0), curMinCrd); - return {curMinCrd, isNonEmpty}; - }); - - builder.setInsertionPointAfter(forOp.loops.front()); - // minOffset = minCrd + 1 >= size ? minCrd + 1 - size : c0 - Value tmp = ADDI(forOp.results.front(), c1); - auto [size, stride] = sliceMeta[tid][lvl][info.depth]; - assert(stride == 1 && "Not yet implemented"); - Value minOffset = SUBI(tmp, size); - Value p = CMPI(uge, tmp, size); - minOffset = SELECT(p, minOffset, c0); - - SmallVector yields; - yields.assign(forOp.results.begin(), forOp.results.end()); - yields.push_back(minOffset); - YIELD(yields); - } - - Value nextMinCrd = ifOp.getResults()[0]; - Value nextNonEmpty = ifOp.getResults()[1]; - - // The next offset should at least be offset + 1; - Value minOffset = ifOp.getResults()[2]; - Value nxOffset = ADDI(info.offset, c1); - Value maxPred = CMPI(ugt, minOffset, nxOffset); - Value nextAbsOffset = SELECT(maxPred, minOffset, nxOffset); - - auto [size, stride] = sliceMeta[tid][lvl][info.depth]; - assert(stride == 1 && "Not yet implemented"); - Value sliceUB = ADDI(nextAbsOffset, size); - - // FIXME: this only works if there is only one parent. - assert(info.depth - 1 == 0); - // nextNonEmpty = nextNonEmpty && slice upper bound <= parent upperbound. - nextNonEmpty = ANDI(nextNonEmpty, CMPI(ule, sliceUB, lvlSizes[tid][lvl])); - - // FIXME: compute relative offset. - assert(info.depth - 1 == 0); - return std::make_tuple(nextNonEmpty, nextMinCrd, nextAbsOffset); -} - #undef CMPI #undef C_IDX #undef YIELD diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h index 450678924c138..d0f447d926f71 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h @@ -124,21 +124,10 @@ class LoopEmitter { /// Exits the current loop sequence, this will reset universal index to 0. void exitCurrentLoopSeq(OpBuilder &builder, Location loc); - /// Enters a loop that tries to locate a coordinates in a sparse level based - /// on the value evaluated by the provided affine expression. - /// DEPRECATED: affine index expression should be handled by index reduction - /// loop, filter loop-based solution is slow. - Operation *enterFilterLoopOverTensorAtLvl(OpBuilder &builder, Location loc, - TensorId tid, Level lvl, - AffineExpr affine, - MutableArrayRef reduc = {}); - /// Emits the address for a dense level based on the value evaluated by the /// provided affine expression. - /// DEPRECATED: affine index expression should be handled by index reduction - /// loop, filter loop-based solution is slow. - void genDenseAffineAddress(OpBuilder &builder, Location loc, - TensorLevel tidLvl, AffineExpr lvlExpr); + void locateLvlAtAffineAddress(OpBuilder &builder, Location loc, + TensorLevel tidLvl, AffineExpr lvlExpr); // TODO: Get rid of `lvls` in the argument list? Track the level we // are currently at internally. Then it would be enterNextLvlForTensor. @@ -153,7 +142,7 @@ class LoopEmitter { Operation *enterCoIterationOverTensorsAtLvls( OpBuilder &builder, Location loc, ArrayRef tidLvls, MutableArrayRef reduc = {}, bool isParallel = false, - bool genDedup = false, bool needsUniv = false); + bool needsUniv = false); /// Generates code to exit the current loop (e.g., generates yields, forwards /// loop induction variables, etc). @@ -224,21 +213,16 @@ class LoopEmitter { }); } - template - auto unpackTensorLevelFromCondRange(ContainerTy &&c) const { - using EltTy = decltype(*c.begin()); - static_assert(std::is_same_v, TensorLvlCond>, - "Must be unpacking a TensorLvlCond range"); - return unpackTensorLevelRange( - llvm::make_first_range(std::forward(c))); - } - /// /// Getters. /// - const std::vector> &getPosits() const { return posits; }; - const std::vector> &getCoords() const { return coords; }; - const std::vector> &getHighs() const { return highs; }; + Value getValPosits(TensorId tid) const { + Value lastLvlPos = iters[tid].back().back()->getCurPosition().first; + return lastLvlPos; + }; + Value getCoord(TensorId tid, Level lvl) const { + return getCurIterator(tid, lvl).getCrd(); + }; const std::vector &getValBuffer() const { return valBuffer; }; constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() { @@ -250,22 +234,12 @@ class LoopEmitter { /// Structure definitions that hold different kinds of loops information. /// - // A tuple that stored the slice-driven loop information. - struct SliceLoopInfo final { - SliceLoopInfo(TensorId tid, Level lvl, bool reduced) - : tid(tid), lvl(lvl), reduced(reduced) {} - TensorId tid; - Level lvl; - bool reduced; - }; // LoopInfo stores information of a loop generated by LoopEmitter. E.g., // the set of tensors levels that the loop is iterating over. struct LoopInfo final { - LoopInfo(ArrayRef trivialTidLvls, - ArrayRef sliceDrivenInfo, Operation *loop, - Block *userBlock, Value iv, StringAttr loopTag) - : trivialTidLvls(trivialTidLvls), sliceDrivenInfo(sliceDrivenInfo), - loop(loop), userCodeBlock(userBlock), iv(iv) { + LoopInfo(ArrayRef tidLvls, Operation *loop, Block *userBlock, + Value iv, StringAttr loopTag) + : tidLvls(tidLvls), loop(loop), userCodeBlock(userBlock), iv(iv) { // Attached a special tag to loop emitter generated loop. if (loopTag) loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag); @@ -274,124 +248,15 @@ class LoopEmitter { // used as the condition for the generated loop. Extra information is // required for levels with non-tivial index expressions, which is // maintained by the sliceDrivenInfo array below. - const llvm::SmallVector trivialTidLvls; - // The set of , with *only* non-trivial index expressions, that - // are used as the condition for the generated loop. - const llvm::SmallVector sliceDrivenInfo; + const llvm::SmallVector tidLvls; const Operation *loop; // the loop operation Block *const userCodeBlock; // the block holding users' generated code. const Value iv; // the induction variable for the loop }; - // SliceInfo stores information of an extracted slice for slice-driven loop. - // E.g., the in-scope SSA values for the minimum coordinates and offset for - // the slice, etc. - struct SliceInfo final { - // Note that we do not need to create a actual sparse tensor slice but - // instead only need to maintain the metadata of the slice. - SliceInfo(Value minCrd, Value offset, Value isNonEmpty, Value posTupleNum, - std::optional slicedOnLvl, unsigned depth) - : minCrd(minCrd), offset(offset), isNonEmpty(isNonEmpty), - posTupleNum(posTupleNum), slicedOnLvl(slicedOnLvl), depth(depth) { - // TODO: use std::optional> - assert(!slicedOnLvl || minCrd); - } - - // Whether this is the tensor that has not yet been sliced. - bool isInitialTensor() const { return !slicedOnLvl.has_value(); } - - Value minCrd; // the minimum coordinate of the slice. - Value offset; // the *absolute* offset of the current slice. - Value isNonEmpty; // whether the slice is empty. - Value posTupleNum; // The number of position tuples used in the slice. - std::optional slicedOnLvl; // the level on which the slice is done - unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]). - }; - - /// - /// Enums for different kinds of loop conditions. - /// - - // The bit indicating whether the loop conditions is sparse. - static constexpr uint8_t kSparseCond = 1 << 3; - // The bit indicating whether the loop iterates over sparse tensor slices - // (i.e., with non-empty SliceDimAttr). - static constexpr uint8_t kSliceCond = 1 << 2; - // The bit indicating whether the loop iterates over tensor levels with - // non-trivial affine index reduction. - static constexpr uint8_t kAffineIdxCond = 1 << 1; - // The bit indicating whether the loop iterates over tensor levels with - // non-trivial affine index reduction, and it is not fully reduced. - static constexpr uint8_t kAffineIdxCondUnRed = 1 << 0; - - enum class LoopCondKind : uint8_t { - // Dense conditions. - DenseCond = 0, - DenseSliceCond = kSliceCond, - DenseAffineCond = kAffineIdxCond, - DenseAffineUnRedCond = kAffineIdxCond | kAffineIdxCondUnRed, - // Sparse Conditions. - SparseCond = kSparseCond, - SparseSliceCond = kSparseCond | kSliceCond, - SparseAffineCond = kSparseCond | kAffineIdxCond, - SparseAffineUnRedCond = kSparseCond | kAffineIdxCond | kAffineIdxCondUnRed, - }; - using TensorLvlCond = std::pair; - - /// Sparse or dense loop condition. - static bool isSparseCond(LoopCondKind k) { - return static_cast(k) & kSparseCond; - } - static bool isDenseCond(LoopCondKind k) { return !isSparseCond(k); } - - /// Whether loops over sparse tensor slices or sparse tensors. - static bool isSliceCond(LoopCondKind k) { - return static_cast(k) & kSliceCond; - } - - /// Affine or trivial index expression loop condition. - static bool isAffineIdxCond(LoopCondKind k) { - return static_cast(k) & kAffineIdxCond; - } - static bool isTrivalIdxCond(LoopCondKind k) { return !isAffineIdxCond(k); } - - /// Whether the affine index expression is fully reduced. - static bool isAffineIdxUnRedCond(LoopCondKind k) { - return isAffineIdxCond(k) && static_cast(k) & kAffineIdxCondUnRed; - } - static bool isAffineIdxRedCond(LoopCondKind k) { - return isAffineIdxCond(k) && !isAffineIdxUnRedCond(k); - } - - // Whether the loop condition kind requires extra check inside the loop body. - // E.g., to iterate over sparse tensor slice, we need to check whether the - // current cooridnate is on the slice (e.g., due to stride) or not. - static bool isCondWithExtraCheck(LoopCondKind k) { - return isSparseCond(k) && (isSliceCond(k) || isAffineIdxUnRedCond(k)); - } - - static LoopCondKind makeLoopCondKind(bool isSparse, bool isSlice, - bool isAffine, bool isUnRedu) { - assert(!isUnRedu || isAffine); - uint8_t bits = 0; - bits = isSparse ? bits | kSparseCond : bits; - bits = isSlice ? bits | kSliceCond : bits; - bits = isAffine ? bits | kAffineIdxCond : bits; - bits = isUnRedu ? bits | kAffineIdxCondUnRed : bits; - LoopCondKind kind = static_cast(bits); - - // Sanity checks. - assert(isSparse == isSparseCond(kind)); - assert(isSlice == isSliceCond(kind)); - assert(isAffine == isAffineIdxCond(kind)); - assert(isUnRedu == isAffineIdxUnRedCond(kind)); - return kind; - } - - void categorizeLoopCondition(ArrayRef tidLvls, - SmallVectorImpl &dnConds, - SmallVectorImpl &spConds); - + void categorizeIterators(ArrayRef tidLvls, + SmallVectorImpl &raIters, + SmallVectorImpl &spIters); /// /// LoopEmitter internal helper functions. /// @@ -400,21 +265,7 @@ class LoopEmitter { MutableArrayRef)>; /// Whether the list of the sparse condition should be iterated by for loop. - bool shouldIteratedByForLoop(ArrayRef spConds, bool genDedup); - - /// Linearizes address for dense dimension (i.e., p = (i * d0) + j). - Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl, - Value iv); - - /// Generates the segment high for a non-unique level (to fast forward - /// duplicated coordinates). That is, it generates the code: - /// - /// crd = coordinates_tid_lvl[pos] - /// while (pos < pHi && coordinates_tid_lvl[pos] == crd) - /// pos++; - /// ; - Value genSegmentHigh(OpBuilder &builder, Location loc, TensorId tid, - Level lvl, Value pos, Value pHi); + bool shouldIteratedByForLoop(ArrayRef spIters); /// Generates instructions to compute the coordinate of tensors[tid][lvl] /// under the current loop context. The final argument is the @@ -423,13 +274,6 @@ class LoopEmitter { Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid, Level dstLvl); - /// Generates a predicate to determine whether the tranformed coordinates are - /// in the given slice. - /// Returns std::pair - std::pair genSliceLegitPredicate(OpBuilder &builder, - Location loc, Value crd, - TensorId tid, Level lvl); - bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); } bool isOutputTensor(TensorId tid) const { @@ -441,7 +285,7 @@ class LoopEmitter { } bool isValidLevel(TensorId tid, Level lvl) const { - return tid < lvlTypes.size() && lvl < lvlTypes[tid].size(); + return tid < lvls.size() && lvl < lvls[tid].size(); } /// Prepares loop for iterating over `tensor[lvl]`, under the assumption @@ -449,13 +293,6 @@ class LoopEmitter { void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid, Level lvl); - /// Enter dense tensor levels. Since the dense tensor condition could be - /// optimized from the loop condition, we need to compute the - /// positions/coordinates inside the loop body. - void enterTensorsAtDenseLvls(OpBuilder &builder, Location loc, - ArrayRef dnConds, Value iv, - SmallVectorImpl &sliceInfo); - /// Emits a for loop to iterate over a tensor level with the provided /// lower bound `lo` and upper bound `hi`. Apart from iterating just /// single tensor level, for loops can be used for slice-driven loop on @@ -463,9 +300,9 @@ class LoopEmitter { /// Returns a pair: the loop generated and the value for the induction /// variable. std::pair - emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid, - Level lvl, Value lo, Value hi, - MutableArrayRef reduc, bool isParallel); + emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc, + SparseIterator &iter, MutableArrayRef reduc, + bool isParallel); /// Emits a while loop to co-iterate over a list of sparse condition, or /// (complex) single sparse condition that can not be handled by for loop @@ -475,26 +312,9 @@ class LoopEmitter { /// iterated). std::pair emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc, - ArrayRef spConds, + ArrayRef iters, MutableArrayRef reduc, bool needsUniv); - /// Generates the while loop condition for the given tensor level condition. - Value genWhileLoopConditions(OpBuilder &builder, Location loc, ValueRange ivs, - TensorLvlCond cond); - - /// Generates the while loop body for the given tensor level condition. - std::optional genWhileLoopBody(OpBuilder &builder, Location loc, - ValueRange ivs, TensorLvlCond cond); - - /// Generates the values (to forward the loop) if the extra check failes. - /// E.g., to iterate over a sparse tensor slice, we need: - /// - /// pos = onSlice(curCrd) ? pos : pos + 1 - /// - /// to skip invalid coordinate that is included in the slice. - ValueRange genCheckedValue(OpBuilder &builder, Location loc, Value pred, - ValueRange curArg, TensorLvlCond cond); - /// Exits a for loop, returns the reduction results, e.g., /// For sequential for loops: /// %ret = for () { @@ -530,85 +350,23 @@ class LoopEmitter { // Slice-driven loop related methods. // - void initSliceDriven(OpBuilder &builder, Location loc); - - /// Retrieves the most recent slice on lvl. To reduce affine expression like - /// d0 + d1 + d2, we need two slices (one of size d1 + d2, and the other of - /// size d2). This methods returns the latter slice (of size d2). - const SliceInfo &getMostRecentSliceOnLvl(TensorId tid, Level lvl); + void initSubSectIterator(OpBuilder &builder, Location loc); - /// Similar to getMostRecentSliceOnLvl, but yields error when the most recent - /// slice is not the final slice needed to fully reduced the dependencies. - const SliceInfo &getFinalSliceOnLvl(TensorId tid, Level lvl) { - const SliceInfo &info = getMostRecentSliceOnLvl(tid, lvl); - assert(info.depth == dependentLvlMap[tid][lvl].size() - 1); - return info; - } - - /// Get the remaining number of constraints needed to fully *resolve* - /// dependent levels on tensor[tid]. - unsigned remDepOnLevel(TensorId tid, Level lvl) const; + /// Get the reduced number of contraints on tensor[tid][lvl]. + unsigned redDepOnLevel(TensorId tid, Level lvl) const { + return levelReducedDep[tid][lvl]; + }; - /// Whether the tid, lvl is fully *reduced*, i.e., the non-trivial index - /// expression has been reduced to a trivial one. - /// E.g., A[i + j] => A[i + 2] (j is reduced) - bool depFullyReduced(TensorId tid, Level lvl) const { - return remDepOnLevel(tid, lvl) == 1; - } + SparseIterator &getCurIterator(TensorId tid, Level lvl) const { + if (dependentLvlMap[tid][lvl].empty()) + return *iters[tid][lvl].back(); - /// Whether the tid, lvl is fully resolved, i.e., we entered the level already - /// (the index on that level is determined). - /// E.g., A[i + j] => A[2 + 3] (both i and j become invariants for inner - /// loops). - bool lvlFullyResolved(TensorId tid, Level lvl) const { - return remDepOnLevel(tid, lvl) == 0; + assert(redDepOnLevel(tid, lvl) >= 1); + return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1]; } - /// Generates a whileOp to iterate over a subset of coordinates on tid on lvl - /// using the pHi and pLo provided, the loop break on the first coordinate - /// that exceeds the slice boundary (i.e., coord >= slice.offset + - /// slice.size). - std::pair - genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo, - Value pHi, Value offset, Value size, TensorId tid, - Level lvl, ValueRange userReduc, - LoopBodyBuilder bodyBuilder); - - /// Generates a nested loop that iterates over tid on all the coordinates on - /// lvl. - ValueRange genUnResolvedSliceTreeTraverse( - OpBuilder &builder, Location loc, TensorId tid, - ArrayRef unResLvls, - std::optional> firstResLvl, - ValueRange userReduc, LoopBodyBuilder bodyBuilder); - - /// Generates code to get the first non-empty slice of tid on lvl, when all - /// the previous level before `lvl` are resolved (or lvl is the first level). - /// - /// This is the simple case because the previous level are resolved into a - /// single node in the storage tree. - void genResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid, - Level lvl); - - /// Generates code to get the first non-empty slice of tid on lvl, when - /// the previous levels before `lvl` are unresolved - /// - /// This is the complex case because the previous levels corresponding to a - /// range of nodes in the storage tree. - void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid, - Level lvl); - - /// Generates code to get the first non-empty slice of tid on lvl. - /// return true if has already been resolved. - bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl); - - /// Generates code to get the next non-empty slices of tid on lvl. - /// Returns a tuple of values for (see - /// SliceInfo) respectively. - std::tuple genSliceNextInduction(OpBuilder &builder, - Location loc, - TensorId tid, - Level lvl); + std::unique_ptr + makeLevelIterator(OpBuilder &builder, Location loc, TensorId tid, Level l); /// A optional string attribute that should be attached to the loop /// generated by loop emitter, it might help following passes to identify @@ -622,50 +380,19 @@ class LoopEmitter { // // Fields which have `numTensor` many entries. // - // TODO: switch to an AOS style to avoid any possible mismatches. - // /// Input and (optional) output tensors. std::vector tensors; - /// Level-types for each `(TensorId, Level)` pair. - std::vector> lvlTypes; - // Sparse iteration information for each `(TensorId, Level)` pair. - // These arrays are updated to remain current within the current loop. - std::vector> posits; - /// The collection of coordinates for a given element (one such - /// collection for each tensor). - std::vector> coords; - // The segment upper bound for non-uniques level after de-duplication. - std::vector> segHi; - std::vector> highs; - std::vector> lvlSizes; + std::vector loopHighs; std::vector>> lvls; + std::vector>>> iters; std::vector valBuffer; // to_value - // - // Slice-driven loops related fields. - // - - /// Whether the sparse input is a slice. - std::vector isSparseSlices; - /// Values related to slices. - std::vector> sliceOffsets; - std::vector> sliceStrides; - // Map from [tid, level] to a list of dependent [tidlevel, coefficient]. // See comments for `DependentLvlGetter`. std::vector>>> dependentLvlMap; - // The cached position buffer for the slices, they serve the same purpose as - // ptrBuffer for compressed dimensions. - // But they always starts with the first pidx pointing to coord > slice.offset - // to avoid iteration from the beginning. - std::vector>> slicePosBuffer; - std::vector> sliceTupleNxStartIdx; - std::vector> sliceTupleFwdCnt; - std::vector> trivialSlice; - // The (size, stride) for each conceptual slice used for index reduction // loops. std::vector>>> sliceMeta; @@ -673,9 +400,6 @@ class LoopEmitter { // The number of reduced dependencies on a tensor level so far. std::vector> levelReducedDep; - // sliceStack[tid] holds the generated slice stack on tid. - std::vector> sliceStack; - // // Fields which have at most `numLoops` many entries. // @@ -684,11 +408,9 @@ class LoopEmitter { /// alive. std::vector loopStack; - // Loop Sequence Stack, stores the unversial index for the current loop - // sequence. and a list of tids which was taken sliced. - // TODO: maybe we should have a LoopSeqInfo - std::vector>>> - loopSeqStack; + // Loop Sequence Stack, stores the universal index for the current loop + // sequence. and a list of tid level that the loop sequence traverse. + std::vector>> loopSeqStack; }; } // namespace sparse_tensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp index aea0910d980ab..22e65be8782fb 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp @@ -9,31 +9,36 @@ #include "SparseTensorLevel.h" #include "CodegenUtils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" using namespace mlir; using namespace mlir::sparse_tensor; using ValuePair = std::pair; +using ValueTuple = std::tuple; //===----------------------------------------------------------------------===// // File local helper functions/macros. //===----------------------------------------------------------------------===// #define CMPI(p, lhs, rhs) \ - (b.create(l, arith::CmpIPredicate::p, (lhs), (rhs))) + (b.create(l, arith::CmpIPredicate::p, (lhs), (rhs)) \ + .getResult()) +#define C_FALSE (constantI1(b, l, false)) +#define C_TRUE (constantI1(b, l, true)) #define C_IDX(v) (constantIndex(b, l, (v))) #define YIELD(vs) (b.create(l, (vs))) -#define ADDI(lhs, rhs) (b.create(l, (lhs), (rhs))) -#define ANDI(lhs, rhs) (b.create(l, (lhs), (rhs))) -#define SUBI(lhs, rhs) (b.create(l, (lhs), (rhs))) -#define MULI(lhs, rhs) (b.create(l, (lhs), (rhs))) -#define REMUI(lhs, rhs) (b.create(l, (lhs), (rhs))) -#define DIVUI(lhs, rhs) (b.create(l, (lhs), (rhs))) -#define SELECT(c, lhs, rhs) (b.create(l, (c), (lhs), (rhs))) - -static ValuePair constantRange(OpBuilder &b, Location l, Value lo, Value sz) { - return std::make_pair(lo, ADDI(lo, sz)); -} +#define ADDI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define ORI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define ANDI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define SUBI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define MULI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define MINUI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define REMUI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define DIVUI(lhs, rhs) (b.create(l, (lhs), (rhs)).getResult()) +#define SELECT(c, lhs, rhs) \ + (b.create(l, (c), (lhs), (rhs)).getResult()) //===----------------------------------------------------------------------===// // SparseTensorLevel derived classes. @@ -43,11 +48,12 @@ namespace { class SparseLevel : public SparseTensorLevel { public: - SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer) - : SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {} + SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, + Value crdBuffer) + : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {} - Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override { - return genIndexLoad(b, l, crdBuffer, pos); + Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override { + return genIndexLoad(b, l, crdBuffer, iv); } protected: @@ -56,10 +62,9 @@ class SparseLevel : public SparseTensorLevel { class DenseLevel : public SparseTensorLevel { public: - DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) { - // Dense level, loop upper bound equals to the level size. - loopHi = lvlSize; - } + DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded) + : SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize), + encoded(encoded) {} Value peekCrdAt(OpBuilder &, Location, Value pos) const override { return pos; @@ -68,14 +73,22 @@ class DenseLevel : public SparseTensorLevel { ValuePair peekRangeAt(OpBuilder &b, Location l, Value p, Value max) const override { assert(max == nullptr && "Dense level can not be non-unique."); - return constantRange(b, l, C_IDX(0), lvlSize); + if (encoded) { + Value posLo = MULI(p, lvlSize); + return {posLo, lvlSize}; + } + // No need to linearize the position for non-annotated tensors. + return {C_IDX(0), lvlSize}; } + + const bool encoded; }; class CompressedLevel : public SparseLevel { public: - CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer) - : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {} + CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, + Value posBuffer, Value crdBuffer) + : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {} ValuePair peekRangeAt(OpBuilder &b, Location l, Value p, Value max) const override { @@ -84,7 +97,7 @@ class CompressedLevel : public SparseLevel { Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1))); return {pLo, pHi}; } - llvm_unreachable("TODO: dedup not implemented"); + llvm_unreachable("compressed-nu should be the first non-unique level."); } private: @@ -93,15 +106,13 @@ class CompressedLevel : public SparseLevel { class LooseCompressedLevel : public SparseLevel { public: - LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, - Value crdBuffer) - : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {} + LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, + Value posBuffer, Value crdBuffer) + : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {} ValuePair peekRangeAt(OpBuilder &b, Location l, Value p, Value max) const override { - // Allows this? assert(max == nullptr && "loss compressed level can not be non-unique."); - p = MULI(p, C_IDX(2)); Value pLo = genIndexLoad(b, l, posBuffer, p); Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1))); @@ -114,68 +125,1176 @@ class LooseCompressedLevel : public SparseLevel { class SingletonLevel : public SparseLevel { public: - SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer) - : SparseLevel(lt, lvlSize, crdBuffer) {} + SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, + Value crdBuffer) + : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {} ValuePair peekRangeAt(OpBuilder &b, Location l, Value p, - Value max) const override { - if (max == nullptr) - return constantRange(b, l, p, C_IDX(1)); - llvm_unreachable("TODO: dedup not implemented"); + Value segHi) const override { + if (segHi == nullptr) + return {p, ADDI(p, C_IDX(1))}; + + // Use the segHi as the loop upper bound. + return {p, segHi}; } }; class TwoOutFourLevel : public SparseLevel { public: - TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer) - : SparseLevel(lt, lvlSize, crdBuffer) {} + TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize, + Value crdBuffer) + : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {} ValuePair peekRangeAt(OpBuilder &b, Location l, Value p, Value max) const override { - assert(max == nullptr && "2:4 level can not be non-unique."); - // Each 2:4 block has exactly two specified elements. - Value c2 = C_IDX(2); - return constantRange(b, l, MULI(p, c2), c2); + assert(max == nullptr && isUnique() && "2:4 level can not be non-unique."); + // Each 2:4 blk has exactly two specified elements. + Value posLo = MULI(p, C_IDX(2)); + return {posLo, ADDI(posLo, C_IDX(2))}; + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// File local helpers +//===----------------------------------------------------------------------===// + +static scf::ValueVector genWhenInBound( + OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet, + llvm::function_ref + builder) { + TypeRange ifRetTypes = elseRet.getTypes(); + auto ifOp = b.create(l, ifRetTypes, it.genNotEnd(b, l), true); + + b.setInsertionPointToStart(ifOp.thenBlock()); + Value crd = it.deref(b, l); + scf::ValueVector ret = builder(b, l, crd); + YIELD(ret); + + b.setInsertionPointToStart(ifOp.elseBlock()); + YIELD(elseRet); + + b.setInsertionPointAfter(ifOp); + return ifOp.getResults(); +} + +/// Generates code to compute the *absolute* offset of the slice based on the +/// provide minimum coordinates in the slice. +/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduced the +/// expression, i,e, s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute* +/// offset is the offset computed relative to the initial tensors T. +/// +/// When isNonEmpty == true, the computed offset is meaningless and should not +/// be used during runtime, the method generates code to return 0 currently in +/// that case. +/// +/// offset = minCrd >= size ? minCrd - size + 1 : 0; +static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd, + Value size) { + Value geSize = CMPI(uge, minCrd, size); + // Compute minCrd - size + 1. + Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size); + // This is the absolute offset related to the actual tensor. + return SELECT(geSize, mms, C_IDX(0)); +} + +//===----------------------------------------------------------------------===// +// SparseIterator derived classes. +//===----------------------------------------------------------------------===// + +namespace { + +class TrivialIterator : public SparseIterator { + Value getLoopLo(OpBuilder &b, Location l) const { + // Dense loop are traversed by coordinate, delinearize the position to get + // the coordinate. + if (randomAccessible()) + return SUBI(itPos, posLo); + return itPos; + } + +public: + TrivialIterator(const SparseTensorLevel &stl, + const IterKind kind = IterKind::kTrivial) + : SparseIterator(kind, stl.tid, stl.lvl, itPos), stl(stl) {} + + // For LLVM-style RTTI. + static bool classof(const SparseIterator *from) { + return from->kind == IterKind::kTrivial; + } + + bool randomAccessible() const override { return isDenseLT(stl.getLT()); }; + bool iteratableByFor() const override { return true; }; + Value upperBound(OpBuilder &b, Location l) const override { + return stl.size(); + }; + + SmallVector serialize() const override { + SmallVector ret; + ret.push_back(itPos); + if (randomAccessible()) { + // Loop high is implicit (defined by `upperBound()`) for random-access + // iterator, but we need to memorize posLo for linearization. + ret.push_back(posLo); + } else { + ret.push_back(posHi); + } + return ret; + }; + + void deserialize(ValueRange vs) override { + assert(vs.size() == 2); + seek(vs.front()); + if (randomAccessible()) + posLo = vs.back(); + else + posHi = vs.back(); + }; + + ValuePair getCurPosition() const override { return {itPos, nullptr}; } + + void genInit(OpBuilder &b, Location l, + const SparseIterator *parent) override { + Value pos = C_IDX(0); + Value hi = nullptr; + if (parent) + std::tie(pos, hi) = parent->getCurPosition(); + + std::tie(posLo, posHi) = stl.peekRangeAt(b, l, pos, hi); + // Seek to the lowest position. + seek(posLo); + } + + ValuePair genForCond(OpBuilder &b, Location l) override { + if (randomAccessible()) + return {deref(b, l), upperBound(b, l)}; + return std::make_pair(getLoopLo(b, l), posHi); + } + + Value genNotEnd(OpBuilder &b, Location l) override { + // We used the first level bound as the bound the collapsed set of levels. + return CMPI(ult, itPos, posHi); + } + + Value deref(OpBuilder &b, Location l) override { + if (randomAccessible()) { + updateCrd(SUBI(itPos, posLo)); + } else { + updateCrd(stl.peekCrdAt(b, l, itPos)); + } + return getCrd(); + }; + + ValueRange forward(OpBuilder &b, Location l) override { + seek(ADDI(itPos, C_IDX(1))); + return getItVals(); + } + + ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override { + Value curPos = getItVals().front(); + Value nxPos = forward(b, l).front(); + seek(SELECT(cond, nxPos, curPos)); + return getItVals(); + } + + void locate(OpBuilder &b, Location l, Value crd) override { + assert(randomAccessible()); + // Seek to the linearized position. + seek(ADDI(crd, posLo)); + updateCrd(crd); + } + + Value itPos; // the position that represent the iterator + + Value posLo, posHi; + const SparseTensorLevel &stl; +}; + +class DedupIterator : public SparseIterator { +private: + Value genSegmentHigh(OpBuilder &b, Location l, Value pos); + +public: + DedupIterator(const SparseTensorLevel &stl) + : SparseIterator(IterKind::kDedup, stl.tid, stl.lvl, posAndSegHi), + stl(stl) { + assert(!stl.isUnique()); + } + // For LLVM-style RTTI. + static bool classof(const SparseIterator *from) { + return from->kind == IterKind::kDedup; + } + + bool randomAccessible() const override { return false; }; + bool iteratableByFor() const override { return false; }; + Value upperBound(OpBuilder &b, Location l) const override { + return stl.size(); + }; + + ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; } + + void genInit(OpBuilder &b, Location l, + const SparseIterator *parent) override { + + Value pos = C_IDX(0); + Value hi = nullptr; + if (parent) + std::tie(pos, hi) = parent->getCurPosition(); + + Value posLo; + std::tie(posLo, posHi) = stl.peekRangeAt(b, l, pos, hi); + + seek({posLo, genSegmentHigh(b, l, posLo)}); + } + + SmallVector serialize() const override { + SmallVector ret; + ret.append(getItVals().begin(), getItVals().end()); + ret.push_back(posHi); + return ret; + }; + void deserialize(ValueRange vs) override { + assert(vs.size() == 3); + seek(vs.take_front(getItVals().size())); + posHi = vs.back(); + }; + + Value genNotEnd(OpBuilder &b, Location l) override { + return CMPI(ult, getPos(), posHi); + } + + Value deref(OpBuilder &b, Location l) override { + updateCrd(stl.peekCrdAt(b, l, getPos())); + return getCrd(); + }; + + ValueRange forward(OpBuilder &b, Location l) override { + Value nxPos = getSegHi(); // forward the position to the next segment. + seek({nxPos, genSegmentHigh(b, l, nxPos)}); + return getItVals(); + } + + Value getPos() const { return posAndSegHi[0]; } + Value getSegHi() const { return posAndSegHi[1]; } + + Value posHi; + Value posAndSegHi[2]; // position and segment high + const SparseTensorLevel &stl; +}; + +// +// A filter iterator wrapped from another iterator. The filter iterator update +// the wrapped iterator *in-place*. +// +class FilterIterator : public SparseIterator { + // Coorindate translation between crd loaded from the wrap iterator and the + // filter iterator. + Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const { + // crd = (wrapCrd - offset) / stride + return DIVUI(SUBI(wrapCrd, offset), stride); + } + Value toWrapCrd(OpBuilder &b, Location l, Value crd) const { + // wrapCrd = crd * stride + offset + return ADDI(MULI(crd, stride), offset); + } + + Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd); + + Value genShouldFilter(OpBuilder &b, Location l); + +public: + // TODO: avoid unnessary check when offset == 0 and/or when stride == 1 and/or + // when crd always < size. + FilterIterator(std::unique_ptr &&wrap, Value offset, + Value stride, Value size) + : SparseIterator(IterKind::kFilter, *wrap), offset(offset), + stride(stride), size(size), wrap(std::move(wrap)) {} + + // For LLVM-style RTTI. + static bool classof(const SparseIterator *from) { + return from->kind == IterKind::kFilter; + } + + bool randomAccessible() const override { return wrap->randomAccessible(); }; + bool iteratableByFor() const override { return randomAccessible(); }; + Value upperBound(OpBuilder &b, Location l) const override { return size; }; + + SmallVector serialize() const override { return wrap->serialize(); }; + void deserialize(ValueRange vs) override { wrap->deserialize(vs); }; + ValuePair getCurPosition() const override { return wrap->getCurPosition(); } + + void genInit(OpBuilder &b, Location l, + const SparseIterator *parent) override { + wrap->genInit(b, l, parent); + if (!randomAccessible()) { + // TODO: we can skip this when stride == 1 and offset == 0, we can also + // use binary search here. + forwardIf(b, l, genShouldFilter(b, l)); + } else { + // Else, locate to the slice.offset, which is the first coordinate + // included by the slice. + wrap->locate(b, l, offset); + } + } + + Value genNotEnd(OpBuilder &b, Location l) override; + + Value deref(OpBuilder &b, Location l) override { + updateCrd(fromWrapCrd(b, l, wrap->deref(b, l))); + return getCrd(); + } + + void locate(OpBuilder &b, Location l, Value crd) override { + assert(randomAccessible()); + wrap->locate(b, l, toWrapCrd(b, l, crd)); + updateCrd(crd); + } + + ValueRange forward(OpBuilder &b, Location l) override; + + const Value offset, stride, size; + std::unique_ptr wrap; +}; + +class NonEmptySubSectIterator : public SparseIterator { +public: + using TraverseBuilder = llvm::function_ref; + + NonEmptySubSectIterator(OpBuilder &b, Location l, + const SparseIterator *parent, + std::unique_ptr &&delegate, + Value subSectSz) + : SparseIterator(IterKind::kNonEmptySubSect, delegate->tid, delegate->lvl, + /*itVals=*/subSectMeta), + parent(parent), delegate(std::move(delegate)), + tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) { + auto *p = dyn_cast_or_null(parent); + if (p == nullptr) { + // Extract subsections along the root level. + maxTupleCnt = C_IDX(1); + } else if (p->lvl == lvl) { + // Extract subsections along the same level. + maxTupleCnt = p->maxTupleCnt; + assert(false && "Not implemented."); + } else { + // Extract subsections along the previous level. + assert(p->lvl + 1 == lvl); + maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz); + } + // We don't need an extra buffer to find subsections on dense levels. + if (randomAccessible()) + return; + subSectPosBuf = allocSubSectPosBuf(b, l); + } + + // For LLVM-style RTTI. + static bool classof(const SparseIterator *from) { + return from->kind == IterKind::kNonEmptySubSect; + } + + // The sliced pointer buffer is organized as: + // [[itVal0, itVal1, ..., pNx0], + // [itVal0, itVal1, ..., pNx0], + // ...] + Value allocSubSectPosBuf(OpBuilder &b, Location l) { + return b.create( + l, + MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()), + maxTupleCnt); + } + + void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId, + Value start) const { + b.create(l, start, subSectPosBuf, + ValueRange{tupleId, C_IDX(tupleSz)}); + } + + Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const { + return b.create(l, subSectPosBuf, + ValueRange{tupleId, C_IDX(tupleSz)}); + } + + void storeItVals(OpBuilder &b, Location l, Value tupleId, + ValueRange itVals) const { + assert(itVals.size() == tupleSz); + for (unsigned i = 0; i < tupleSz; i++) { + b.create(l, itVals[i], subSectPosBuf, + ValueRange{tupleId, C_IDX(i)}); + } + } + + SmallVector loadItVals(OpBuilder &b, Location l, Value tupleId) const { + SmallVector ret; + for (unsigned i = 0; i < tupleSz; i++) { + Value v = b.create(l, subSectPosBuf, + ValueRange{tupleId, C_IDX(i)}); + ret.push_back(v); + } + return ret; + } + + bool isSubSectRoot() const { + return !parent || !llvm::isa(parent); + } + + // Generate code that inflate the current subsection tree till the current + // level such that every leaf node is visited. + ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc, + TraverseBuilder builder) const; + + bool randomAccessible() const override { + return delegate->randomAccessible(); + }; + bool iteratableByFor() const override { return randomAccessible(); }; + Value upperBound(OpBuilder &b, Location l) const override { + auto *p = dyn_cast_or_null(parent); + Value parentUB = + p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l); + return ADDI(SUBI(parentUB, subSectSz), C_IDX(1)); + }; + + void genInit(OpBuilder &b, Location l, const SparseIterator *) override; + + void locate(OpBuilder &b, Location l, Value crd) override { + Value absOff = crd; + + if (isSubSectRoot()) + delegate->locate(b, l, absOff); + else + assert(parent->lvl + 1 == lvl); + + seek(ValueRange{absOff, absOff, C_TRUE}); + updateCrd(crd); + } + + Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const { + return SUBI(wrapCrd, getAbsOff()); + } + + Value genNotEnd(OpBuilder &b, Location l) override { return getNotEnd(); }; + + Value deref(OpBuilder &b, Location l) override { + // Use the relative offset to coiterate. + Value crd; + auto *p = dyn_cast_or_null(parent); + if (p && p->lvl == lvl) + crd = SUBI(getAbsOff(), p->getAbsOff()); + crd = getAbsOff(); + + updateCrd(crd); + return crd; + }; + + ValueRange forward(OpBuilder &b, Location l) override; + + Value getMinCrd() const { return subSectMeta[0]; } + Value getAbsOff() const { return subSectMeta[1]; } + Value getNotEnd() const { return subSectMeta[2]; } + + const SparseIterator *parent; + std::unique_ptr delegate; + + // Number of values required to serialize the wrapped iterator. + const unsigned tupleSz; + // Max number of tuples, and the actual number of tuple. + Value maxTupleCnt, tupleCnt; + // The memory used to cache the tuple serialized from the wrapped iterator. + Value subSectPosBuf; + + const Value subSectSz; + + Value subSectMeta[3]; // minCrd, absolute offset, notEnd +}; + +class SubSectIterator; + +// A wrapper that helps generating code to traverse a subsection, used +// by both `NonEmptySubSectIterator`and `SubSectIterator`. +struct SubSectIterHelper { + explicit SubSectIterHelper(const SubSectIterator &iter); + explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect); + + // Delegate methods. + void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId); + void locate(OpBuilder &b, Location l, Value crd); + Value genNotEnd(OpBuilder &b, Location l); + Value deref(OpBuilder &b, Location l); + ValueRange forward(OpBuilder &b, Location l); + + const NonEmptySubSectIterator &subSect; + SparseIterator &wrap; +}; + +class SubSectIterator : public SparseIterator { + // RAII to sync iterator values between the wrap the iterator and the + // SubSectIterator. + struct WrapItValSyncer { + explicit WrapItValSyncer(SubSectIterator &it) : it(it) { + if (!it.randomAccessible()) + it.wrap->seek(it.getItVals().drop_back()); + } + ~WrapItValSyncer() { + if (!it.randomAccessible()) { + ValueRange wrapItVals = it.wrap->getItVals(); + std::copy(wrapItVals.begin(), wrapItVals.end(), it.itVals.begin()); + } + } + SubSectIterator ⁢ + }; + +public: + SubSectIterator(const NonEmptySubSectIterator &subSect, + const SparseIterator &parent, + std::unique_ptr &&wrap, Value size, + unsigned stride) + : SparseIterator(IterKind::kSubSect, *wrap), itVals(), subSect(subSect), + wrap(std::move(wrap)), parent(parent), size(size), stride(stride), + helper(*this) { + assert(stride == 1 && "Not implemented."); + assert(subSect.tid == tid && subSect.lvl == lvl); + assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl); + + if (!randomAccessible()) { + // We maintain a extra counter to count the actually sparse coordinate + // included in the subsection. + unsigned itValSz = this->wrap->getItVals().size() + 1; + itVals.resize(itValSz, nullptr); + relinkItVals(itVals); + } + }; + + // For LLVM-style RTTI. + static bool classof(const SparseIterator *from) { + return from->kind == IterKind::kSubSect; + } + + bool randomAccessible() const override { return wrap->randomAccessible(); }; + bool iteratableByFor() const override { return randomAccessible(); }; + Value upperBound(OpBuilder &b, Location l) const override { return size; } + std::pair getCurPosition() const override { + return wrap->getCurPosition(); + }; + + Value getNxLvlTupleId(OpBuilder &b, Location l) const { + if (randomAccessible()) { + return ADDI(getCrd(), nxLvlTupleStart); + }; + return ADDI(itVals.back(), nxLvlTupleStart); + } + + void genInit(OpBuilder &b, Location l, const SparseIterator *) override { + WrapItValSyncer syncer(*this); + if (randomAccessible()) { + if (auto *p = llvm::dyn_cast(&parent)) { + assert(p->lvl + 1 == lvl); + wrap->genInit(b, l, p); + // Linearize the dense subsection index. + nxLvlTupleStart = MULI(size, p->getNxLvlTupleId(b, l)); + } else { + assert(subSect.lvl == lvl && subSect.isSubSectRoot()); + wrap->deserialize(subSect.delegate->serialize()); + nxLvlTupleStart = C_IDX(0); + } + return; + } + assert(!randomAccessible()); + assert(itVals.size() == wrap->getItVals().size() + 1); + // Extra counter that counts the number of actually visited coordinates in + // the sparse subsection. + itVals.back() = C_IDX(0); + Value tupleId; + if (auto *p = llvm::dyn_cast(&parent)) { + assert(p->lvl + 1 == lvl); + tupleId = p->getNxLvlTupleId(b, l); + } else { + assert(subSect.lvl == lvl && subSect.isSubSectRoot()); + tupleId = C_IDX(0); + } + nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId); + helper.deserializeFromTupleId(b, l, tupleId); + } + + void locate(OpBuilder &b, Location l, Value crd) override { + WrapItValSyncer syncer(*this); + helper.locate(b, l, crd); + updateCrd(crd); + } + + Value genNotEnd(OpBuilder &b, Location l) override { + WrapItValSyncer syncer(*this); + return helper.genNotEnd(b, l); } + + Value deref(OpBuilder &b, Location l) override { + WrapItValSyncer syncer(*this); + Value crd = helper.deref(b, l); + updateCrd(crd); + return crd; + }; + + ValueRange forward(OpBuilder &b, Location l) override { + { + WrapItValSyncer syncer(*this); + helper.forward(b, l); + } + assert(!randomAccessible()); + assert(itVals.size() == wrap->getItVals().size() + 1); + itVals.back() = ADDI(itVals.back(), C_IDX(1)); + return getItVals(); + }; + + SmallVector itVals; + Value nxLvlTupleStart; + + const NonEmptySubSectIterator &subSect; + std::unique_ptr wrap; + const SparseIterator &parent; + + Value size; + unsigned stride; + + SubSectIterHelper helper; }; } // namespace +//===----------------------------------------------------------------------===// +// SparseIterator derived classes implementation. +//===----------------------------------------------------------------------===// + +ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) { + auto ifOp = b.create(l, getItVals().getTypes(), cond, true); + // Generate else branch first, otherwise iterator values will be updated by + // `forward()`. + b.setInsertionPointToStart(ifOp.elseBlock()); + YIELD(getItVals()); + + b.setInsertionPointToStart(ifOp.thenBlock()); + YIELD(forward(b, l)); + + b.setInsertionPointAfter(ifOp); + seek(ifOp.getResults()); + return getItVals(); +} + +Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) { + auto whileOp = b.create( + l, pos.getType(), pos, + /*beforeBuilder=*/ + [this, pos](OpBuilder &b, Location l, ValueRange ivs) { + Value inBound = CMPI(ult, ivs.front(), posHi); + auto ifInBound = b.create(l, b.getI1Type(), inBound, true); + { + OpBuilder::InsertionGuard guard(b); + // If in bound, load the next coordinates and check duplication. + b.setInsertionPointToStart(ifInBound.thenBlock()); + Value headCrd = stl.peekCrdAt(b, l, pos); + Value tailCrd = stl.peekCrdAt(b, l, ivs.front()); + Value isDup = CMPI(eq, headCrd, tailCrd); + YIELD(isDup); + // Else, the position is out of bound, yield false. + b.setInsertionPointToStart(ifInBound.elseBlock()); + YIELD(constantI1(b, l, false)); + } + b.create(l, ifInBound.getResults()[0], ivs); + }, + /*afterBuilder=*/ + [](OpBuilder &b, Location l, ValueRange ivs) { + Value nxPos = ADDI(ivs[0], C_IDX(1)); + YIELD(nxPos); + }); + // Return the segment high. + return whileOp.getResult(0); +} + +Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l, + Value wrapCrd) { + Value crd = fromWrapCrd(b, l, wrapCrd); + // Test whether the coordinate is on stride. + Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd); + // Test wrapCrd < offset + notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit); + // Test crd >= length + notlegit = ORI(CMPI(uge, crd, size), notlegit); + return notlegit; +} + +Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) { + auto r = genWhenInBound( + b, l, *wrap, C_FALSE, + [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector { + Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd); + return {notLegit}; + }); + + assert(r.size() == 1); + return r.front(); +} + +Value FilterIterator::genNotEnd(OpBuilder &b, Location l) { + assert(!wrap->randomAccessible()); + auto r = genWhenInBound( + b, l, *wrap, C_FALSE, + [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector { + Value crd = fromWrapCrd(b, l, wrapCrd); + // crd < size + return {CMPI(ult, crd, size)}; + }); + assert(r.size() == 1); + return r.front(); +} + +ValueRange FilterIterator::forward(OpBuilder &b, Location l) { + assert(!randomAccessible()); + // Generates + // + // bool isFirst = true; + // while !it.end() && (!legit(*it) || isFirst) + // wrap ++; + // isFirst = false; + // + // We do not hoist the first `wrap++` outside the loop but use a `isFirst` + // flag here because `wrap++` might have a complex implementation (e.g., to + // forward a subsection). + Value isFirst = constantI1(b, l, true); + + SmallVector whileArgs(getItVals().begin(), getItVals().end()); + whileArgs.push_back(isFirst); + + auto whileOp = b.create( + l, ValueRange(whileArgs).getTypes(), whileArgs, + /*beforeBuilder=*/ + [this](OpBuilder &b, Location l, ValueRange ivs) { + ValueRange isFirst = linkNewScope(ivs); + assert(isFirst.size() == 1); + ValueRange cont = + genWhenInBound(b, l, *wrap, C_FALSE, + [this, isFirst](OpBuilder &b, Location l, + Value wrapCrd) -> scf::ValueVector { + // crd < size && !legit(); + Value notLegit = + genCrdNotLegitPredicate(b, l, wrapCrd); + Value crd = fromWrapCrd(b, l, wrapCrd); + Value ret = ANDI(CMPI(ult, crd, size), notLegit); + ret = ORI(ret, isFirst.front()); + return {ret}; + }); + b.create(l, cont.front(), ivs); + }, + /*afterBuilder=*/ + [this](OpBuilder &b, Location l, ValueRange ivs) { + linkNewScope(ivs); + wrap->forward(b, l); + SmallVector yieldVals(getItVals().begin(), getItVals().end()); + yieldVals.push_back(constantI1(b, l, false)); + YIELD(yieldVals); + }); + + b.setInsertionPointAfter(whileOp); + linkNewScope(whileOp.getResults()); + return getItVals(); +} + +SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect) + : subSect(subSect), wrap(*subSect.delegate) {} + +SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter) + : subSect(iter.subSect), wrap(*iter.wrap) {} + +void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l, + Value tupleId) { + assert(!subSect.randomAccessible()); + wrap.deserialize(subSect.loadItVals(b, l, tupleId)); +} + +void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) { + Value absCrd = ADDI(crd, subSect.getAbsOff()); + wrap.locate(b, l, absCrd); +} + +Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) { + assert(!wrap.randomAccessible()); + auto r = genWhenInBound( + b, l, wrap, C_FALSE, + [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector { + Value crd = SUBI(wrapCrd, subSect.getAbsOff()); + // crd < size + return {CMPI(ult, crd, subSect.subSectSz)}; + }); + assert(r.size() == 1); + return r.front(); +} + +Value SubSectIterHelper::deref(OpBuilder &b, Location l) { + Value wrapCrd = wrap.deref(b, l); + Value crd = subSect.toSubSectCrd(b, l, wrapCrd); + return crd; +} + +ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) { + return wrap.forward(b, l); +} + +ValueRange NonEmptySubSectIterator::inflateSubSectTree( + OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const { + // Set up the helper to help traverse a sparse subsection. + SubSectIterHelper helper(*this); + if (!randomAccessible()) { + // The subsection tree have been expanded till the level and cached, + // traverse all the leaves and expanded to the next level. + SmallVector iterArgs; + iterArgs.push_back(C_IDX(0)); + iterArgs.append(reduc.begin(), reduc.end()); + auto forEachLeaf = b.create( + l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs, + [&helper, &builder](OpBuilder &b, Location l, Value tupleId, + ValueRange iterArgs) { + // Deserialize the iterator at the cached position (tupleId). + helper.deserializeFromTupleId(b, l, tupleId); + + Value cnt = iterArgs.front(); + // Record the number of leaf nodes included in the subsection. + // The number indicates the starting tupleId for the next level that + // is corresponding to the current node. + helper.subSect.storeNxLvlStart(b, l, tupleId, cnt); + + SmallVector whileArgs(helper.wrap.getItVals()); + whileArgs.append(iterArgs.begin(), iterArgs.end()); + + auto whileOp = b.create( + l, ValueRange(whileArgs).getTypes(), whileArgs, + /*beforeBuilder=*/ + [&helper](OpBuilder &b, Location l, ValueRange ivs) { + helper.wrap.linkNewScope(ivs); + b.create(l, helper.genNotEnd(b, l), ivs); + }, + /*afterBuilder=*/ + [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) { + ValueRange remIter = helper.wrap.linkNewScope(ivs); + Value cnt = remIter.front(); + ValueRange userIter = remIter.drop_front(); + scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter); + + SmallVector nxIter = helper.forward(b, l); + nxIter.push_back(ADDI(cnt, C_IDX(1))); + nxIter.append(userNx.begin(), userNx.end()); + YIELD(nxIter); + }); + ValueRange res = helper.wrap.linkNewScope(whileOp.getResults()); + YIELD(res); + }); + return forEachLeaf.getResults().drop_front(); + } + + assert(randomAccessible()); + // Helper lambda that traverse the current dense subsection range. + auto visitDenseSubSect = [&, this](OpBuilder &b, Location l, + const SparseIterator *parent, + ValueRange reduc) { + assert(!parent || parent->lvl + 1 == lvl); + delegate->genInit(b, l, parent); + auto forOp = b.create( + l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc, + [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) { + helper.locate(b, l, crd); + scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs); + YIELD(nx); + }); + return forOp.getResults(); + }; + + if (isSubSectRoot()) { + return visitDenseSubSect(b, l, parent, reduc); + } + // Else, this is not the root, recurse until root. + auto *p = llvm::cast(parent); + assert(p->lvl + 1 == lvl); + return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect); +} + +void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l, + const SparseIterator *) { + Value c0 = C_IDX(0); + if (!isSubSectRoot()) { + assert(parent->lvl + 1 == lvl); + if (randomAccessible()) { + // We can not call wrap->genInit() here to initialize the wrapped + // iterator, because the parent of the curent iterator is still + // unresolved. + seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE}); + return; + } + + auto *p = cast(parent); + SmallVector reduc = { + C_IDX(-1), // minCrd (max signless integer) + c0, // tupleId + }; + + // Expand the subsection tree from the parent level to the current level. + ValueRange result = p->inflateSubSectTree( + b, l, reduc, + [this](OpBuilder &b, Location l, const SparseIterator *parent, + ValueRange reduc) -> scf::ValueVector { + assert(parent->lvl + 1 == lvl && reduc.size() == 2); + Value minCrd = reduc.front(); + Value tupleId = reduc.back(); + + // Initialize the subsection range. + SubSectIterHelper helper(*this); + helper.wrap.genInit(b, l, parent); + + // Update minCrd. + minCrd = genWhenInBound(b, l, helper.wrap, minCrd, + [minCrd](OpBuilder &b, Location l, + Value crd) -> scf::ValueVector { + Value min = MINUI(crd, minCrd); + return {min}; + }) + .front(); + + // Cache the sparse range. + storeItVals(b, l, tupleId, helper.wrap.serialize()); + tupleId = ADDI(tupleId, C_IDX(1)); + return {minCrd, tupleId}; + }); + assert(result.size() == 2); + tupleCnt = result.back(); + + Value minCrd = result.front(); + Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz); + Value notEnd = CMPI(ne, minCrd, C_IDX(-1)); + seek({minCrd, absOff, notEnd}); + return; + } + + // This is the root level of the subsection, which means that it is resolved + // to one node. + assert(isSubSectRoot()); + + // Initialize the position, the position marks the *lower bound* of the + // subRange. The higher bound is determined by the size of the subsection. + delegate->genInit(b, l, parent); + if (randomAccessible()) { + seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE}); + return; + } + + // Only have one root node. + tupleCnt = C_IDX(1); + // Cache the sparse range. + storeItVals(b, l, c0, delegate->serialize()); + SmallVector elseRet{c0, c0, /*notEnd=*/C_FALSE}; + auto meta = genWhenInBound( + b, l, *delegate, elseRet, + [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector { + Value offset = offsetFromMinCrd(b, l, crd, subSectSz); + return {crd, offset, C_TRUE}; + }); + + seek(meta); +} + +ValueRange NonEmptySubSectIterator::forward(OpBuilder &b, Location l) { + assert(!randomAccessible()); + Value c0 = C_IDX(0), c1 = C_IDX(1); + // Forward to the next non empty slice by generating + // + // if (minCrd > offset) { + // offset += 1 + // } else { + // minCrd = nextMinInSlice(); + // offset = minCrd - size + 1; + // } + // + // if (offset + size > parents.size) + // isNonEmpty = false; + Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff()); + auto ifOp = b.create(l, getItVals().getTypes(), fastPathP, true); + { + OpBuilder::InsertionGuard guard(b); + // Take the fast path + // if (minCrd > offset) + // offset += 1 + b.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value nxOffset = ADDI(getAbsOff(), c1); + YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()})); + + // else /*minCrd == offset*/ { + // for (i = 0; i < tupleCnt; i++) { + // wrap->deserialize(pos[i]); + // minCrd=min(minCrd, *wrap); + // } + // offset = minCrd - size + 1; + // } + b.setInsertionPointToStart(&ifOp.getElseRegion().front()); + SmallVector loopArgs{C_IDX(-1), // nextMinCrd + C_FALSE}; // isNotEnd + auto loopNest = scf::buildLoopNest( + b, l, c0, tupleCnt, c1, loopArgs, + [this](OpBuilder &b, Location l, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + Value tupleId = ivs.front(); + SubSectIterHelper helper(*this); + helper.deserializeFromTupleId(b, l, tupleId); + + return genWhenInBound( + b, l, *delegate, /*elseRet=*/iterArgs, + [this, iterArgs, tupleId](OpBuilder &b, Location l, + Value crd) -> scf::ValueVector { + // if coord == minCrd + // wrap->forward(); + Value isMin = CMPI(eq, crd, getMinCrd()); + delegate->forwardIf(b, l, isMin); + // Update the forwarded iterator values if needed. + auto ifIsMin = b.create(l, isMin, false); + b.setInsertionPointToStart(&ifIsMin.getThenRegion().front()); + storeItVals(b, l, tupleId, delegate->serialize()); + b.setInsertionPointAfter(ifIsMin); + // if (!wrap.end()) + // yield(min(nxMinCrd, *wrap), true) + Value nxMin = iterArgs[0]; + return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs, + [nxMin](OpBuilder &b, Location l, + Value crd) -> scf::ValueVector { + Value nx = b.create( + l, crd, nxMin); + return {nx, C_TRUE}; + }); + }); + }); + + scf::ForOp forOp = loopNest.loops.front(); + b.setInsertionPointAfter(forOp); + + Value nxMinCrd = forOp.getResult(0); + Value nxNotEnd = forOp.getResult(1); + Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz); + YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd})); + } + + Value nxMinCrd = ifOp.getResult(0); + Value nxAbsOff = ifOp.getResult(1); + Value nxNotEnd = ifOp.getResult(2); + + // We should at least forward the offset by one. + Value minAbsOff = ADDI(getAbsOff(), c1); + nxAbsOff = b.create(l, minAbsOff, nxAbsOff); + + seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}); + // The coordinate should not exceeds the space upper bound. + Value crd = deref(b, l); + nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l))); + + seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}); + return getItVals(); +} + +//===----------------------------------------------------------------------===// +// SparseIterator factory functions. +//===----------------------------------------------------------------------===// + std::unique_ptr -sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t, - Level l) { +sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t, + unsigned tid, Level lvl) { auto stt = getSparseTensorType(t); - LevelType lt = stt.getLvlType(l); - Value lvlSz = stt.hasEncoding() - ? builder.create(loc, t, l).getResult() - : builder.create(loc, t, l).getResult(); + LevelType lt = stt.getLvlType(lvl); + Value sz = stt.hasEncoding() ? b.create(l, t, lvl).getResult() + : b.create(l, t, lvl).getResult(); switch (*getLevelFormat(lt)) { case LevelFormat::Dense: - return std::make_unique(lvlSz); + return std::make_unique(tid, lvl, sz, stt.hasEncoding()); case LevelFormat::Compressed: { - Value posBuf = genToPositions(builder, loc, t, l); - Value crdBuf = genToCoordinates(builder, loc, t, l); - return std::make_unique(lt, lvlSz, posBuf, crdBuf); + Value pos = genToPositions(b, l, t, lvl); + Value crd = genToCoordinates(b, l, t, lvl); + return std::make_unique(tid, lvl, lt, sz, pos, crd); } case LevelFormat::LooseCompressed: { - Value posBuf = genToPositions(builder, loc, t, l); - Value crdBuf = genToCoordinates(builder, loc, t, l); - return std::make_unique(lt, lvlSz, posBuf, crdBuf); + Value pos = genToPositions(b, l, t, lvl); + Value crd = genToCoordinates(b, l, t, lvl); + return std::make_unique(tid, lvl, lt, sz, pos, crd); } case LevelFormat::Singleton: { - Value crdBuf = genToCoordinates(builder, loc, t, l); - return std::make_unique(lt, lvlSz, crdBuf); + Value crd = genToCoordinates(b, l, t, lvl); + return std::make_unique(tid, lvl, lt, sz, crd); } case LevelFormat::TwoOutOfFour: { - Value crdBuf = genToCoordinates(builder, loc, t, l); - return std::make_unique(lt, lvlSz, crdBuf); + Value crd = genToCoordinates(b, l, t, lvl); + return std::make_unique(tid, lvl, lt, sz, crd); } } llvm_unreachable("unrecognizable level format"); } +std::pair, std::unique_ptr> +sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl) { + auto stl = std::make_unique(tid, lvl, sz, /*encoded=*/false); + auto it = std::make_unique(*stl); + return std::make_pair(std::move(stl), std::move(it)); +} + +std::unique_ptr +sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl) { + if (!isUniqueLT(stl.getLT())) { + // We always dedupliate the non-unique level, but we should optimize it away + // if possible. + return std::make_unique(stl); + } + return std::make_unique(stl); +} + +std::unique_ptr +sparse_tensor::makeSlicedLevelIterator(std::unique_ptr &&sit, + Value offset, Value stride, Value size) { + + return std::make_unique(std::move(sit), offset, stride, size); +} + +template +static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) { + auto *filter = llvm::dyn_cast_or_null(it); + if (filter && llvm::isa(filter->wrap.get())) { + return filter->wrap.get(); + } + return it; +} +template +static const IterType *unwrapFilter(const SparseIterator *it) { + auto *filter = llvm::dyn_cast_or_null(it); + if (filter) { + return llvm::cast(filter->wrap.get()); + } + return llvm::cast(it); +} + +std::unique_ptr sparse_tensor::makeNonEmptySubSectIterator( + OpBuilder &b, Location l, const SparseIterator *parent, + std::unique_ptr &&delegate, Value size, unsigned stride) { + + // Try unwrap the NonEmptySubSectIterator from a filter parent. + parent = tryUnwrapFilter(parent); + auto it = std::make_unique( + b, l, parent, std::move(delegate), size); + + if (stride != 1) + return std::make_unique(std::move(it), /*offset=*/C_IDX(0), + C_IDX(stride), /*size=*/C_IDX(-1)); + return it; +} + +std::unique_ptr sparse_tensor::makeTraverseSubSectIterator( + const SparseIterator &subSectIter, const SparseIterator &parent, + std::unique_ptr &&wrap, Value size, unsigned stride) { + // This must be a subsection iterator or a filtered subsection iterator. + auto &subSect = *unwrapFilter(&subSectIter); + return std::make_unique(subSect, parent, std::move(wrap), + size, stride); +} + #undef CMPI #undef C_IDX #undef YIELD diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h index f5c29cda7c54f..08f7c6a747eb5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h @@ -14,6 +14,9 @@ namespace mlir { namespace sparse_tensor { +/// The base class for all types of sparse tensor levels. It provides interfaces +/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see +/// `peekCrdAt`). class SparseTensorLevel { SparseTensorLevel(SparseTensorLevel &&) = delete; SparseTensorLevel(const SparseTensorLevel &) = delete; @@ -21,42 +24,236 @@ class SparseTensorLevel { SparseTensorLevel &operator=(const SparseTensorLevel &) = delete; public: - SparseTensorLevel() : SparseTensorLevel(LevelType::Undef, nullptr){}; virtual ~SparseTensorLevel() = default; - virtual Value peekCrdAt(OpBuilder &b, Location l, Value p) const = 0; + virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0; /// Peeks the lower and upper bound to *fully* traverse the level with /// the given position `p` that the immediate parent level is current at. + /// Returns a pair of values for *posLo* and *loopHi* respectively. + /// + /// For a dense level, the *posLo* is the linearized position at beginning, + /// while *loopHi* is the largest *coordinate*, it also implies that the + /// smallest *coordinate* to start the loop is 0. + /// + /// For a sparse level, [posLo, loopHi) specifies the range of index pointer + /// to load coordinate from the coordinate buffer. + /// /// `bound` is only used when the level is `non-unique` and deduplication is /// required. It specifies the max upper bound of the non-unique segment. virtual std::pair peekRangeAt(OpBuilder &b, Location l, Value p, - Value bound = Value()) const = 0; + Value segHi = Value()) const = 0; + Level getLevel() const { return lvl; } LevelType getLT() const { return lt; } - Value getPos() const { return pos; } - Value getCrd() const { return crd; } - Value getLoopHi() const { return loopHi; } - Value getLoopLo() const { return loopLo; } + Value size() const { return lvlSize; } + + // + // Level properties + // + bool isUnique() const { return isUniqueLT(lt); } protected: - SparseTensorLevel(LevelType lt, Value lvlSize) - : lt(lt), lvlSize(lvlSize), pos(nullptr), crd(nullptr), loopHi(nullptr), - loopLo(nullptr){}; + SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize) + : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){}; +public: + const unsigned tid, lvl; const LevelType lt; const Value lvlSize; +}; + +enum class IterKind : uint8_t { + kTrivial, + kDedup, + kSubSect, + kNonEmptySubSect, + kFilter, +}; + +/// Helper class that generates loop conditions, etc, to traverse a +/// sparse tensor level. +class SparseIterator { + SparseIterator(SparseIterator &&) = delete; + SparseIterator(const SparseIterator &) = delete; + SparseIterator &operator=(SparseIterator &&) = delete; + SparseIterator &operator=(const SparseIterator &) = delete; + +protected: + SparseIterator(IterKind kind, unsigned tid, unsigned lvl, + MutableArrayRef itVals) + : kind(kind), tid(tid), lvl(lvl), crd(nullptr), itVals(itVals){}; + + SparseIterator(IterKind kind, const SparseIterator &wrap) + : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr), + itVals(wrap.itVals){}; + +public: + virtual ~SparseIterator() = default; + + Value getCrd() const { return crd; } + ValueRange getItVals() const { return itVals; }; + + // Sets the iterate to the specified position. + void seek(ValueRange vals) { + assert(vals.size() == itVals.size()); + std::copy(vals.begin(), vals.end(), itVals.begin()); + // Now that the iterator is re-positioned, the coordinate becomes invalid. + crd = nullptr; + } + + // + // Iterator properties. + // + + // Whether the iterator support random access (i.e., support look up by + // *coordinate*). A random access iterator must also traverses a dense space. + virtual bool randomAccessible() const = 0; + + // Whether the iterator can simply traversed by a for loop. + virtual bool iteratableByFor() const { return false; }; + + // Get the upper bound of the sparse space that the iterator might visited. A + // sparse space is a subset of a dense space [0, bound), this function returns + // *bound*. + virtual Value upperBound(OpBuilder &b, Location l) const = 0; + + // Serializes and deserializes the current status to/from a set of values. The + // ValueRange should contain values that specifies the current postion and + // loop bound. + // + // Not every type of iterator supports the operations, e.g., non-empty + // subsection iterator does not because the the number of non-empty + // subsections can not be determined easily. + // + // NOTE: All the values should have index type. + virtual SmallVector serialize() const { + llvm_unreachable("unsupported"); + }; + virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); }; -public: // TODO: make these values private upon feature complete. - Value pos; - Value crd; - Value loopHi; - Value loopLo; + // + // Core functions. + // + + // Gets the current position and the optional *position high* (for non-unique + // iterators), the value is essentially the number of sparse coordinate that + // the iterator is current visiting. It should be able to uniquely identify + // the sparse range for the next level. See SparseTensorLevel::peekRangeAt(); + // + // Not every type of iterator supports the operation, e.g., non-empty + // subsection iterator does not because it represent a range of coordinates + // instead of just one. + virtual std::pair getCurPosition() const { + llvm_unreachable("unsupported"); + }; + + // Initializes the iterator according to the parent iterator's state. + virtual void genInit(OpBuilder &, Location, const SparseIterator *) = 0; + + // Returns a pair of values for *upper*, *lower* bound respectively. + virtual std::pair genForCond(OpBuilder &b, Location l) { + assert(randomAccessible()); + // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB). + return {getCrd(), upperBound(b, l)}; + } + + // Returns a boolean value that equals `!it.end()` + virtual Value genNotEnd(OpBuilder &b, Location l) = 0; + std::pair genWhileCond(OpBuilder &b, Location l, + ValueRange vs) { + ValueRange rem = linkNewScope(vs); + return std::make_pair(genNotEnd(b, l), rem); + } + + // Dereference the iterator, loads the coordinate at the current position. + // + // The method assumes that the iterator is not currently exhausted (i.e., + // it != it.end()). + virtual Value deref(OpBuilder &b, Location l) = 0; + + virtual ValueRange forward(OpBuilder &b, Location l) = 0; + + // Generate a conditional it.next() in the following form + // + // if (cond) + // yield it.next + // else + // yield it + // + // The function is virtual to allow alternative implementation. For example, + // if it.next() is trivial to compute, we can use a select operation instead. + // E.g., + // + // it = select cond ? it+1 : it + virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond); + + // Locate the iterator to the position specified by *crd*, this can only + // be done on an iterator that supports randm access. + virtual void locate(OpBuilder &b, Location l, Value crd) { + llvm_unreachable("Unsupported"); + } + + // Update the SSA value for the iterator after entering a new scope. + ValueRange linkNewScope(ValueRange pos) { + assert(!randomAccessible() && "random accessible iterators are traversed " + "by coordinate, call locate() instead."); + seek(pos.take_front(itVals.size())); + return pos.drop_front(itVals.size()); + }; + +protected: + void updateCrd(Value crd) { this->crd = crd; } + void relinkItVals(MutableArrayRef itVals) { this->itVals = itVals; } + +public: + const IterKind kind; // For LLVM-style RTTI. + const unsigned tid, lvl; // tensor level identifier. + +private: + Value crd; // The sparse coordinate used to coiterate; + + // A range of value that together defines the current state of the + // iterator. Only loop variants should be included. + // + // For trivial iterators, it is the position; for dedup iterators, it consists + // of the positon and the segment high, for non-empty subsection iterator, it + // is the metadata that specifies the subsection. + MutableArrayRef itVals; }; /// Helper function to create a TensorLevel object from given `tensor`. -std::unique_ptr -makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t, Level l); +std::unique_ptr makeSparseTensorLevel(OpBuilder &builder, + Location loc, Value t, + unsigned tid, Level l); + +/// Helper function to create a simple SparseIterator object that iterate over +/// the SparseTensorLevel. +std::unique_ptr +makeSimpleIterator(const SparseTensorLevel &stl); + +/// Helper function to create a synthetic SparseIterator object that iterate +/// over a dense space specified by [0,`sz`). +std::pair, std::unique_ptr> +makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl); + +/// Helper function to create a SparseIterator object that iterate over a +/// sliced space, the orignal space (before slicing) is traversed by `sit`. +std::unique_ptr +makeSlicedLevelIterator(std::unique_ptr &&sit, Value offset, + Value stride, Value size); + +/// Helper function to create a SparseIterator object that iterate over the +/// non-empty subsections set. +std::unique_ptr makeNonEmptySubSectIterator( + OpBuilder &b, Location l, const SparseIterator *parent, + std::unique_ptr &&delegate, Value size, unsigned stride); + +/// Helper function to create a SparseIterator object that iterate over a +/// non-empty subsection created by NonEmptySubSectIterator. +std::unique_ptr makeTraverseSubSectIterator( + const SparseIterator &subsectIter, const SparseIterator &parent, + std::unique_ptr &&delegate, Value size, unsigned stride); } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir index 2d8dcfea9adc1..60a217e05e61e 100644 --- a/mlir/test/Dialect/SparseTensor/dense.mlir +++ b/mlir/test/Dialect/SparseTensor/dense.mlir @@ -42,9 +42,9 @@ // CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref // CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32> // CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index +// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index // CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_13]], %[[VAL_2]] : f32 // CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> @@ -82,9 +82,9 @@ func.func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>, // CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16xf32> // CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref // CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index +// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index // CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> // CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_13]], %[[VAL_2]] : f32 // CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref @@ -125,9 +125,9 @@ func.func @dense2(%arga: tensor<32x16xf32>, // CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16x8xf32> // CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref // CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index +// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index // CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref // CHECK: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) { // CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]], %[[VAL_15]]] : memref<32x16x8xf32> diff --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir index 91e7920b3a903..2b9a2dd8f4883 100644 --- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir +++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir @@ -1,3 +1,4 @@ +// TODO: re-enable after lowering coo.next to function call (such that loop structure is more clear). // RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --canonicalize | FileCheck %s #SortedCOO = #sparse_tensor.encoding<{ @@ -37,47 +38,47 @@ // // CHECK-LABEL: func.func @sparse_scale( -// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor to memref -// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor to memref> -// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor to memref -// CHECK-DAG: %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref -// CHECK-DAG: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_8]]) : (index) -> index { -// CHECK: %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_9]] : index -// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_11]] : index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_13:.*]]: index): -// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref> -// CHECK: %[[VAL_15:.*]] = scf.while (%[[VAL_16:.*]] = %[[VAL_13]]) : (index) -> index { -// CHECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_9]] : index -// CHECK: %[[VAL_18:.*]] = scf.if %[[VAL_17]] -> (i1) { -// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref> -// CHECK: %[[VAL_20:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_14]] : index -// CHECK: scf.yield %[[VAL_20]] : i1 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_1]] : i1 -// CHECK: } -// CHECK: scf.condition(%[[VAL_21:.*]]) %[[VAL_16]] : index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_22:.*]]: index): -// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_22]], %[[VAL_3]] : index -// CHECK: scf.yield %[[VAL_23]] : index -// CHECK: } -// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_13]] to %[[VAL_25:.*]] step %[[VAL_3]] { -// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref -// CHECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_26]], %[[VAL_4]] : f32 -// CHECK: memref.store %[[VAL_27]], %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref -// CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: scf.yield %[[VAL_28:.*]] : index -// CHECK: } attributes {"Emitted from" = "linalg.generic"} -// CHECK: %[[VAL_29:.*]] = sparse_tensor.load %[[VAL_0]] : tensor -// CHECK: return %[[VAL_29]] : tensor -// CHECK: } +// C_HECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// C_HECK-DAG: %[[VAL_1:.*]] = arith.constant false +// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index +// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index +// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f32 +// C_HECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor to memref +// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor to memref> +// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor to memref +// C_HECK-DAG: %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref +// C_HECK-DAG: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// C_HECK: %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_8]]) : (index) -> index { +// C_HECK: %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_9]] : index +// C_HECK: scf.condition(%[[VAL_12]]) %[[VAL_11]] : index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_13:.*]]: index): +// C_HECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref> +// C_HECK: %[[VAL_15:.*]] = scf.while (%[[VAL_16:.*]] = %[[VAL_13]]) : (index) -> index { +// C_HECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_9]] : index +// C_HECK: %[[VAL_18:.*]] = scf.if %[[VAL_17]] -> (i1) { +// C_HECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref> +// C_HECK: %[[VAL_20:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_14]] : index +// C_HECK: scf.yield %[[VAL_20]] : i1 +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_1]] : i1 +// C_HECK: } +// C_HECK: scf.condition(%[[VAL_21:.*]]) %[[VAL_16]] : index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_22:.*]]: index): +// C_HECK: %[[VAL_23:.*]] = arith.addi %[[VAL_22]], %[[VAL_3]] : index +// C_HECK: scf.yield %[[VAL_23]] : index +// C_HECK: } +// C_HECK: scf.for %[[VAL_24:.*]] = %[[VAL_13]] to %[[VAL_25:.*]] step %[[VAL_3]] { +// C_HECK: %[[VAL_26:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref +// C_HECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_26]], %[[VAL_4]] : f32 +// C_HECK: memref.store %[[VAL_27]], %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref +// C_HECK: } {"Emitted from" = "linalg.generic"} +// C_HECK: scf.yield %[[VAL_28:.*]] : index +// C_HECK: } attributes {"Emitted from" = "linalg.generic"} +// C_HECK: %[[VAL_29:.*]] = sparse_tensor.load %[[VAL_0]] : tensor +// C_HECK: return %[[VAL_29]] : tensor +// C_HECK: } func.func @sparse_scale(%argx: tensor) -> tensor { %c = arith.constant 2.0 : f32 %0 = linalg.generic #trait_scale @@ -89,57 +90,57 @@ func.func @sparse_scale(%argx: tensor) -> tensor } -// CHECK-LABEL: func.func @matvec( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse{{[0-9]*}}>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<64xf64>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> { -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref -// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> -// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> -// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref -// CHECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64> -// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref -// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_13:.*]] = scf.while (%[[VAL_14:.*]] = %[[VAL_11]]) : (index) -> index { -// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_14]], %[[VAL_12]] : index -// CHECK: scf.condition(%[[VAL_15]]) %[[VAL_14]] : index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_16:.*]]: index): -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref> -// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref> -// CHECK: %[[VAL_19:.*]] = scf.while (%[[VAL_20:.*]] = %[[VAL_16]]) : (index) -> index { -// CHECK: %[[VAL_21:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_12]] : index -// CHECK: %[[VAL_22:.*]] = scf.if %[[VAL_21]] -> (i1) { -// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref> -// CHECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_18]] : index -// CHECK: scf.yield %[[VAL_24]] : i1 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_3]] : i1 -// CHECK: } -// CHECK: scf.condition(%[[VAL_25:.*]]) %[[VAL_20]] : index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_26:.*]]: index): -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_5]] : index -// CHECK: scf.yield %[[VAL_27]] : index -// CHECK: } -// CHECK: %[[VAL_28:.*]] = tensor.extract %[[VAL_2]]{{\[}}%[[VAL_17]]] : tensor<32xf64> -// CHECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_16]] to %[[VAL_31:.*]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_28]]) -> (f64) { -// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref> -// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_30]]] : memref -// CHECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_33]]] : tensor<64xf64> -// CHECK: %[[VAL_36:.*]] = arith.mulf %[[VAL_34]], %[[VAL_35]] : f64 -// CHECK: %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64 -// CHECK: scf.yield %[[VAL_37]] : f64 -// CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: memref.store %[[VAL_38:.*]], %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref<32xf64> -// CHECK: scf.yield %[[VAL_39:.*]] : index -// CHECK: } attributes {"Emitted from" = "linalg.generic"} -// CHECK: %[[VAL_40:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<32xf64> -// CHECK: return %[[VAL_40]] : tensor<32xf64> -// CHECK: } +// C_HECK-LABEL: func.func @matvec( +// C_HECK-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse{{[0-9]*}}>, +// C_HECK-SAME: %[[VAL_1:.*]]: tensor<64xf64>, +// C_HECK-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> { +// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant false +// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index +// C_HECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index +// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref +// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> +// C_HECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> +// C_HECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref +// C_HECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64> +// C_HECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref +// C_HECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref +// C_HECK: %[[VAL_13:.*]] = scf.while (%[[VAL_14:.*]] = %[[VAL_11]]) : (index) -> index { +// C_HECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_14]], %[[VAL_12]] : index +// C_HECK: scf.condition(%[[VAL_15]]) %[[VAL_14]] : index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_16:.*]]: index): +// C_HECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref> +// C_HECK: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref> +// C_HECK: %[[VAL_19:.*]] = scf.while (%[[VAL_20:.*]] = %[[VAL_16]]) : (index) -> index { +// C_HECK: %[[VAL_21:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_12]] : index +// C_HECK: %[[VAL_22:.*]] = scf.if %[[VAL_21]] -> (i1) { +// C_HECK: %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref> +// C_HECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_18]] : index +// C_HECK: scf.yield %[[VAL_24]] : i1 +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_3]] : i1 +// C_HECK: } +// C_HECK: scf.condition(%[[VAL_25:.*]]) %[[VAL_20]] : index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_26:.*]]: index): +// C_HECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_5]] : index +// C_HECK: scf.yield %[[VAL_27]] : index +// C_HECK: } +// C_HECK: %[[VAL_28:.*]] = tensor.extract %[[VAL_2]]{{\[}}%[[VAL_17]]] : tensor<32xf64> +// C_HECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_16]] to %[[VAL_31:.*]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_28]]) -> (f64) { +// C_HECK: %[[VAL_33:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref> +// C_HECK: %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_30]]] : memref +// C_HECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_33]]] : tensor<64xf64> +// C_HECK: %[[VAL_36:.*]] = arith.mulf %[[VAL_34]], %[[VAL_35]] : f64 +// C_HECK: %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64 +// C_HECK: scf.yield %[[VAL_37]] : f64 +// C_HECK: } {"Emitted from" = "linalg.generic"} +// C_HECK: memref.store %[[VAL_38:.*]], %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref<32xf64> +// C_HECK: scf.yield %[[VAL_39:.*]] : index +// C_HECK: } attributes {"Emitted from" = "linalg.generic"} +// C_HECK: %[[VAL_40:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<32xf64> +// C_HECK: return %[[VAL_40]] : tensor<32xf64> +// C_HECK: } func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>, %argb: tensor<64xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> { @@ -154,112 +155,112 @@ func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>, return %0 : tensor<32xf64> } -// CHECK-LABEL: func.func @mateltmul( -// CHECK-SAME: %[[VAL_0:.*0]]: tensor<32x64xf64, #sparse{{[0-9]*}}>, %[[VAL_1:.*1]]: tensor<32x64xf64, #sparse{{[0-9]*}}>, -// CHECK-SAME: %[[VAL_2:.*2]]: tensor<32x64xf64>) -> tensor<32x64xf64> { -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref -// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> -// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> -// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref -// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref -// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> -// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> -// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref -// CHECK: %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x64xf64> -// CHECK: linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_15]] : memref<32x64xf64>) -// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref -// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref -// CHECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) { -// CHECK: %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_21]], %[[VAL_17]] : index -// CHECK: %[[VAL_24:.*]] = arith.cmpi ult, %[[VAL_22]], %[[VAL_19]] : index -// CHECK: %[[VAL_25:.*]] = arith.andi %[[VAL_23]], %[[VAL_24]] : i1 -// CHECK: scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index): -// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref> -// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref> -// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref> -// CHECK: %[[VAL_33:.*]] = scf.while (%[[VAL_34:.*]] = %[[VAL_26]]) : (index) -> index { -// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_17]] : index -// CHECK: %[[VAL_36:.*]] = scf.if %[[VAL_35]] -> (i1) { -// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_34]]] : memref> -// CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_32]] : index -// CHECK: scf.yield %[[VAL_38]] : i1 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_3]] : i1 -// CHECK: } -// CHECK: scf.condition(%[[VAL_39:.*]]) %[[VAL_34]] : index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_40:.*]]: index): -// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_40]], %[[VAL_6]] : index -// CHECK: scf.yield %[[VAL_41]] : index -// CHECK: } -// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref> -// CHECK: %[[VAL_43:.*]] = scf.while (%[[VAL_44:.*]] = %[[VAL_27]]) : (index) -> index { -// CHECK: %[[VAL_45:.*]] = arith.cmpi ult, %[[VAL_44]], %[[VAL_19]] : index -// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (i1) { -// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_44]]] : memref> -// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_47]], %[[VAL_42]] : index -// CHECK: scf.yield %[[VAL_48]] : i1 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_3]] : i1 -// CHECK: } -// CHECK: scf.condition(%[[VAL_49:.*]]) %[[VAL_44]] : index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_50:.*]]: index): -// CHECK: %[[VAL_51:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index -// CHECK: scf.yield %[[VAL_51]] : index -// CHECK: } -// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index -// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index -// CHECK: %[[VAL_52:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index -// CHECK: %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index -// CHECK: %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1 -// CHECK: scf.if %[[VAL_54]] { -// CHECK: %[[VAL_55:.*]]:2 = scf.while (%[[VAL_56:.*]] = %[[VAL_26]], %[[VAL_57:.*]] = %[[VAL_27]]) : (index, index) -> (index, index) { -// CHECK: %[[VAL_58:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_59:.*]] : index -// CHECK: %[[VAL_60:.*]] = arith.cmpi ult, %[[VAL_57]], %[[VAL_61:.*]] : index -// CHECK: %[[VAL_62:.*]] = arith.andi %[[VAL_58]], %[[VAL_60]] : i1 -// CHECK: scf.condition(%[[VAL_62]]) %[[VAL_56]], %[[VAL_57]] : index, index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_63:.*]]: index, %[[VAL_64:.*]]: index): -// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_63]]] : memref> -// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_64]]] : memref> -// CHECK: %[[VAL_67:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_65]] : index -// CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_67]], %[[VAL_66]], %[[VAL_65]] : index -// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_68]] : index -// CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] : index -// CHECK: %[[VAL_71:.*]] = arith.andi %[[VAL_69]], %[[VAL_70]] : i1 -// CHECK: scf.if %[[VAL_71]] { -// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_63]]] : memref -// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_64]]] : memref -// CHECK: %[[VAL_74:.*]] = arith.mulf %[[VAL_72]], %[[VAL_73]] : f64 -// CHECK: memref.store %[[VAL_74]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_68]]] : memref<32x64xf64> -// CHECK: } -// CHECK: %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_68]] : index -// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_63]], %[[VAL_6]] : index -// CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_75]], %[[VAL_76]], %[[VAL_63]] : index -// CHECK: %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] : index -// CHECK: %[[VAL_79:.*]] = arith.addi %[[VAL_64]], %[[VAL_6]] : index -// CHECK: %[[VAL_80:.*]] = arith.select %[[VAL_78]], %[[VAL_79]], %[[VAL_64]] : index -// CHECK: scf.yield %[[VAL_77]], %[[VAL_80]] : index, index -// CHECK: } attributes {"Emitted from" = "linalg.generic"} -// CHECK: } -// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index -// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_83:.*]], %[[VAL_26]] : index -// CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index -// CHECK: %[[VAL_85:.*]] = arith.select %[[VAL_84]], %[[VAL_86:.*]], %[[VAL_27]] : index -// CHECK: scf.yield %[[VAL_82]], %[[VAL_85]] : index, index -// CHECK: } attributes {"Emitted from" = "linalg.generic"} -// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x64xf64> -// CHECK: return %[[VAL_87]] : tensor<32x64xf64> -// CHECK: } +// C_HECK-LABEL: func.func @mateltmul( +// C_HECK-SAME: %[[VAL_0:.*0]]: tensor<32x64xf64, #sparse{{[0-9]*}}>, %[[VAL_1:.*1]]: tensor<32x64xf64, #sparse{{[0-9]*}}>, +// C_HECK-SAME: %[[VAL_2:.*2]]: tensor<32x64xf64>) -> tensor<32x64xf64> { +// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant false +// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64 +// C_HECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index +// C_HECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index +// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref +// C_HECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> +// C_HECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> +// C_HECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref +// C_HECK-DAG: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref +// C_HECK-DAG: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> +// C_HECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref> +// C_HECK-DAG: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref +// C_HECK: %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x64xf64> +// C_HECK: linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_15]] : memref<32x64xf64>) +// C_HECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref +// C_HECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref +// C_HECK: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref +// C_HECK: %[[VAL_19:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref +// C_HECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) { +// C_HECK: %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_21]], %[[VAL_17]] : index +// C_HECK: %[[VAL_24:.*]] = arith.cmpi ult, %[[VAL_22]], %[[VAL_19]] : index +// C_HECK: %[[VAL_25:.*]] = arith.andi %[[VAL_23]], %[[VAL_24]] : i1 +// C_HECK: scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index): +// C_HECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref> +// C_HECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref> +// C_HECK: %[[VAL_32:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref> +// C_HECK: %[[VAL_33:.*]] = scf.while (%[[VAL_34:.*]] = %[[VAL_26]]) : (index) -> index { +// C_HECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_17]] : index +// C_HECK: %[[VAL_36:.*]] = scf.if %[[VAL_35]] -> (i1) { +// C_HECK: %[[VAL_37:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_34]]] : memref> +// C_HECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_32]] : index +// C_HECK: scf.yield %[[VAL_38]] : i1 +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_3]] : i1 +// C_HECK: } +// C_HECK: scf.condition(%[[VAL_39:.*]]) %[[VAL_34]] : index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_40:.*]]: index): +// C_HECK: %[[VAL_41:.*]] = arith.addi %[[VAL_40]], %[[VAL_6]] : index +// C_HECK: scf.yield %[[VAL_41]] : index +// C_HECK: } +// C_HECK: %[[VAL_42:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref> +// C_HECK: %[[VAL_43:.*]] = scf.while (%[[VAL_44:.*]] = %[[VAL_27]]) : (index) -> index { +// C_HECK: %[[VAL_45:.*]] = arith.cmpi ult, %[[VAL_44]], %[[VAL_19]] : index +// C_HECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (i1) { +// C_HECK: %[[VAL_47:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_44]]] : memref> +// C_HECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_47]], %[[VAL_42]] : index +// C_HECK: scf.yield %[[VAL_48]] : i1 +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_3]] : i1 +// C_HECK: } +// C_HECK: scf.condition(%[[VAL_49:.*]]) %[[VAL_44]] : index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_50:.*]]: index): +// C_HECK: %[[VAL_51:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index +// C_HECK: scf.yield %[[VAL_51]] : index +// C_HECK: } +// C_HECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index +// C_HECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index +// C_HECK: %[[VAL_52:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index +// C_HECK: %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index +// C_HECK: %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1 +// C_HECK: scf.if %[[VAL_54]] { +// C_HECK: %[[VAL_55:.*]]:2 = scf.while (%[[VAL_56:.*]] = %[[VAL_26]], %[[VAL_57:.*]] = %[[VAL_27]]) : (index, index) -> (index, index) { +// C_HECK: %[[VAL_58:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_59:.*]] : index +// C_HECK: %[[VAL_60:.*]] = arith.cmpi ult, %[[VAL_57]], %[[VAL_61:.*]] : index +// C_HECK: %[[VAL_62:.*]] = arith.andi %[[VAL_58]], %[[VAL_60]] : i1 +// C_HECK: scf.condition(%[[VAL_62]]) %[[VAL_56]], %[[VAL_57]] : index, index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_63:.*]]: index, %[[VAL_64:.*]]: index): +// C_HECK: %[[VAL_65:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_63]]] : memref> +// C_HECK: %[[VAL_66:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_64]]] : memref> +// C_HECK: %[[VAL_67:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_65]] : index +// C_HECK: %[[VAL_68:.*]] = arith.select %[[VAL_67]], %[[VAL_66]], %[[VAL_65]] : index +// C_HECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_68]] : index +// C_HECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] : index +// C_HECK: %[[VAL_71:.*]] = arith.andi %[[VAL_69]], %[[VAL_70]] : i1 +// C_HECK: scf.if %[[VAL_71]] { +// C_HECK: %[[VAL_72:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_63]]] : memref +// C_HECK: %[[VAL_73:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_64]]] : memref +// C_HECK: %[[VAL_74:.*]] = arith.mulf %[[VAL_72]], %[[VAL_73]] : f64 +// C_HECK: memref.store %[[VAL_74]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_68]]] : memref<32x64xf64> +// C_HECK: } +// C_HECK: %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_68]] : index +// C_HECK: %[[VAL_76:.*]] = arith.addi %[[VAL_63]], %[[VAL_6]] : index +// C_HECK: %[[VAL_77:.*]] = arith.select %[[VAL_75]], %[[VAL_76]], %[[VAL_63]] : index +// C_HECK: %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] : index +// C_HECK: %[[VAL_79:.*]] = arith.addi %[[VAL_64]], %[[VAL_6]] : index +// C_HECK: %[[VAL_80:.*]] = arith.select %[[VAL_78]], %[[VAL_79]], %[[VAL_64]] : index +// C_HECK: scf.yield %[[VAL_77]], %[[VAL_80]] : index, index +// C_HECK: } attributes {"Emitted from" = "linalg.generic"} +// C_HECK: } +// C_HECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index +// C_HECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_83:.*]], %[[VAL_26]] : index +// C_HECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index +// C_HECK: %[[VAL_85:.*]] = arith.select %[[VAL_84]], %[[VAL_86:.*]], %[[VAL_27]] : index +// C_HECK: scf.yield %[[VAL_82]], %[[VAL_85]] : index, index +// C_HECK: } attributes {"Emitted from" = "linalg.generic"} +// C_HECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x64xf64> +// C_HECK: return %[[VAL_87]] : tensor<32x64xf64> +// C_HECK: } func.func @mateltmul(%argx: tensor<32x64xf64, #SortedCOO>, %argy: tensor<32x64xf64, #SortedCOO>, %argz: tensor<32x64xf64>) -> tensor<32x64xf64> { diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir index 57ae18391daf8..85ae0db916899 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir @@ -29,9 +29,9 @@ // CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32> // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_10]] : memref<32x16xf32>) // CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index // CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref // CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32> // CHECK: %[[VAL_17:.*]] = arith.addf %[[VAL_15]], %[[VAL_16]] : f32 @@ -66,9 +66,9 @@ func.func @add_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %arg // CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xi1> // CHECK: linalg.fill ins(%[[VAL_5]] : i1) outs(%[[VAL_10]] : memref<32x16xi1>) // CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] { +// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { -// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index // CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref // CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32> // CHECK: %[[VAL_17:.*]] = arith.cmpf ult, %[[VAL_15]], %[[VAL_16]] : f32 @@ -102,9 +102,9 @@ func.func @cmp_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %arg // CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32> // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_10]] : memref<32x16xf32>) // CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index // CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref // CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32> // CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_15]], %[[VAL_16]] : f32 @@ -319,9 +319,9 @@ func.func @mul_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %arg // CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref // CHECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_22]], %[[VAL_21]] : index // CHECK: scf.if %[[VAL_23]] { +// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_20]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { -// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_20]], %[[VAL_4]] : index -// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]] : index +// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_24]], %[[VAL_25]] : index // CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref // CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_24]]] : memref<32x16xf32> // CHECK: %[[VAL_29:.*]] = arith.addf %[[VAL_27]], %[[VAL_28]] : f32 @@ -389,9 +389,9 @@ func.func @add_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref // CHECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index // CHECK: scf.if %[[VAL_24]] { +// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index // CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] { -// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index // CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_25]]] : memref<32x16xf32> // CHECK: %[[VAL_30:.*]] = arith.cmpf ult, %[[VAL_28]], %[[VAL_29]] : f32 @@ -451,9 +451,9 @@ func.func @cmp_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg // CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref // CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_5]] { // CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_14]], %[[VAL_3]] : index // CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { -// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_14]], %[[VAL_3]] : index -// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : index +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : index // CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]], %[[VAL_16]]] : memref<32x16xf32> // CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32 @@ -1272,6 +1272,7 @@ func.func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T // CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref // CHECK: %[[VAL_25:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_23]] : index // CHECK: scf.if %[[VAL_25]] { +// CHECK: %[[VAL_36:.*]] = arith.muli %[[VAL_22]], %[[VAL_4]] : index // CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref // CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_23]], %[[VAL_7]] : index // CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref @@ -1281,8 +1282,7 @@ func.func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T // CHECK: } do { // CHECK: ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index): // CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_33]]] : memref -// CHECK: %[[VAL_36:.*]] = arith.muli %[[VAL_22]], %[[VAL_4]] : index -// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_36]], %[[VAL_34]] : index +// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_34]], %[[VAL_36]] : index // CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_34]] : index // CHECK: scf.if %[[VAL_38]] { // CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_37]]] : memref @@ -1303,8 +1303,7 @@ func.func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T // CHECK: scf.yield %[[VAL_45]], %[[VAL_46]] : index, index // CHECK: } // CHECK: scf.for %[[VAL_47:.*]] = %[[VAL_48:.*]]#1 to %[[VAL_4]] step %[[VAL_7]] { -// CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_22]], %[[VAL_4]] : index -// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_47]] : index +// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_47]], %[[VAL_36]] : index // CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_50]]] : memref // CHECK: memref.store %[[VAL_51]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_47]]] : memref<32x16xf32> // CHECK: } @@ -1369,13 +1368,13 @@ func.func @add_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #T // CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref // CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] { // CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref // CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_5]] : index // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref // CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] { // CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref -// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : index +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : index // CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref // CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref // CHECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_25]], %[[VAL_26]] : f32 diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir index 4911c78bcff34..b2f528fc7a25e 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir @@ -37,12 +37,12 @@ // CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32> // CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>) // CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] { +// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { -// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_13]] : index +// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] { -// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index -// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : index +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : index // CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32> // CHECK: %[[VAL_21:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] : f32 @@ -79,12 +79,12 @@ func.func @add_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32> // CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32> // CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>) // CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] { +// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { -// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_13]] : index +// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] { -// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index -// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : index +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : index // CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32> // CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32 @@ -124,9 +124,9 @@ func.func @mul_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32> // CHECK-DAG: %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32> // CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>) // CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_9]] { +// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_9]] { -// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_5]] : index -// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_17]] : index +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref // CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_9]] : index // CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref @@ -191,9 +191,9 @@ func.func @add_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32> // CHECK-DAG: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32> // CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>) // CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] { -// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_5]] : index -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : index // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref // CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_7]] : index // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref @@ -249,9 +249,9 @@ func.func @mul_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32> // CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref // CHECK: %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index // CHECK: scf.if %[[VAL_26]] { +// CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { -// CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index -// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_28]], %[[VAL_27]] : index +// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index // CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref // CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32> // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[VAL_31]] : f32 @@ -314,9 +314,9 @@ func.func @add_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32> // CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref // CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_6]] { // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : index +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_20]] : index // CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32> // CHECK: %[[VAL_24:.*]] = arith.mulf %[[VAL_22]], %[[VAL_23]] : f32 @@ -512,12 +512,12 @@ func.func @mul_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32> // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref // CHECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index // CHECK: scf.if %[[VAL_24]] { +// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { -// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_4]] : index -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index +// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_27]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { -// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_27]], %[[VAL_5]] : index -// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_28]], %[[VAL_29]] : index // CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref // CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32> // CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32 @@ -582,12 +582,12 @@ func.func @add_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32> // CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref // CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_6]] { // CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_15]], %[[VAL_3]] : index // CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { -// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_15]], %[[VAL_3]] : index -// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_17]] : index +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index +// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_19]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_19]], %[[VAL_4]] : index -// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : index // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref // CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32> // CHECK: %[[VAL_25:.*]] = arith.mulf %[[VAL_23]], %[[VAL_24]] : f32 @@ -638,9 +638,9 @@ func.func @mul_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32> // CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref // CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_25]] : index // CHECK: scf.if %[[VAL_27]] { +// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_24]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] { -// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_24]], %[[VAL_5]] : index -// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_28]], %[[VAL_29]] : index // CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_30]]] : memref // CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_30]], %[[VAL_9]] : index // CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref @@ -733,9 +733,9 @@ func.func @add_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32> // CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref // CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_6]] { // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : index +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_20]] : index // CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref // CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_6]] : index // CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref @@ -802,9 +802,9 @@ func.func @mul_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32> // CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_34]]] : memref // CHECK: %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_35]] : index // CHECK: scf.if %[[VAL_37]] { +// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_34]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { -// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_34]], %[[VAL_5]] : index -// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_38]] : index +// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_38]], %[[VAL_39]] : index // CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref // CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32> // CHECK: %[[VAL_43:.*]] = arith.addf %[[VAL_41]], %[[VAL_42]] : f32 @@ -895,9 +895,9 @@ func.func @add_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32> // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref // CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] { // CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref +// CHECK: %[[VAL_24:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index // CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { -// CHECK: %[[VAL_24:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_23]] : index +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : index // CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref // CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32> // CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_26]], %[[VAL_27]] : f32 @@ -1133,9 +1133,9 @@ func.func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32> // CHECK-DAG: %[[VAL_14:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor // CHECK-DAG: %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_0]] : memref // CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] { +// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_10]] : index // CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_5]] to %[[VAL_10]] step %[[VAL_6]] { -// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_10]], %[[VAL_17]] : index -// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index +// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : index // CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref // CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_6]] : index // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_22]]] : memref diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir index 886b21fa97567..2128ca7539fa0 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir @@ -234,9 +234,9 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>, // CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref // CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] { // CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_7]] : index // CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_17]], %[[VAL_3]] : index -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_7]] : index +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index // CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_24]]] : memref<32x16xf64> // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref // CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_27]]] : memref diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir index bf61e792ffbe0..70cf0f9af45b5 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir @@ -1,3 +1,4 @@ +// TODO: re-enable after lowering coo.next to function call (such that loop structure is more clear). // RUN: mlir-opt %s --sparse-reinterpret-map --sparsification --canonicalize --cse | FileCheck %s #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> @@ -8,232 +9,232 @@ // CHECK-LABEL: func.func @conv2d_all_sparse_CSR( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32, #sparse>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> { -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant true -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant -2 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse> -// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref -// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref -// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref -// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref -// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref -// CHECK-DAG: %[[VAL_17:.*]] = memref.alloca() : memref<9xindex> -// CHECK-DAG: %[[VAL_18:.*]] = memref.alloca() : memref<3xindex> -// CHECK-DAG: %[[POS_LO:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref -// CHECK-DAG: %[[POS_HI:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref -// CHECK: memref.store %[[POS_LO]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> -// CHECK: memref.store %[[POS_HI]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex> -// CHECK: %[[VAL_20:.*]] = arith.cmpi ult, %[[POS_LO]], %[[POS_HI]] : index -// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[POS_LO]]] : memref -// CHECK: %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index -// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1 -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index -// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index -// CHECK: %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) { -// CHECK: scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse> -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: tensor<6x6xi32, #sparse>): -// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> -// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex> -// CHECK: memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<3xindex> -// CHECK: %[[VAL_36:.*]] = arith.addi %[[VAL_32]], %[[VAL_6]] : index -// CHECK: %[[VAL_37:.*]]:5 = scf.while (%[[VAL_38:.*]] = %[[VAL_34]], %[[VAL_39:.*]] = %[[VAL_10]], %[[VAL_40:.*]] = %[[VAL_5]], %[[VAL_41:.*]] = %[[VAL_8]], %[[VAL_42:.*]] = %[[VAL_8]]) : (index, i1, index, index, index) -> (index, i1, index, index, index) { -// CHECK: %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_38]], %[[VAL_35]] : index -// CHECK: %[[VAL_44:.*]] = scf.if %[[VAL_43]] -> (i1) { -// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref -// CHECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_36]] : index -// CHECK: scf.yield %[[VAL_46]] : i1 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_10]] : i1 -// CHECK: } -// CHECK: scf.condition(%[[VAL_44]]) %[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]] : index, i1, index, index, index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: i1, %[[VAL_49:.*]]: index, %[[VAL_50:.*]]: index, %[[VAL_51:.*]]: index): -// CHECK-DAG: %[[VAL_52:.*]] = arith.addi %[[VAL_47]], %[[VAL_7]] : index -// CHECK-DAG: %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_47]]] : memref -// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_52]]] : memref -// CHECK: %[[VAL_55:.*]] = arith.cmpi ult, %[[VAL_53]], %[[VAL_54]] : index -// CHECK: %[[VAL_56:.*]] = arith.ori %[[VAL_55]], %[[VAL_48]] : i1 -// CHECK: %[[VAL_57:.*]] = scf.if %[[VAL_55]] -> (index) { -// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_53]]] : memref -// CHECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_49]] : index -// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_49]] : index -// CHECK: scf.yield %[[VAL_60]] : index -// CHECK: } else { -// CHECK: scf.yield %[[VAL_49]] : index -// CHECK: } -// CHECK: memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_50]]] : memref<9xindex> -// CHECK: %[[VAL_61:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index -// CHECK: memref.store %[[VAL_54]], %[[VAL_17]]{{\[}}%[[VAL_61]]] : memref<9xindex> -// CHECK: %[[VAL_62:.*]] = arith.addi %[[VAL_50]], %[[VAL_7]] : index -// CHECK: %[[VAL_63:.*]] = arith.addi %[[VAL_51]], %[[VAL_7]] : index -// CHECK: scf.yield %[[VAL_52]], %[[VAL_56]], %[[VAL_57]], %[[VAL_62]], %[[VAL_63]] : index, i1, index, index, index -// CHECK: } -// CHECK: %[[VAL_64:.*]] = arith.cmpi uge, %[[VAL_65:.*]]#2, %[[VAL_6]] : index -// CHECK: %[[VAL_66:.*]] = arith.andi %[[VAL_65]]#1, %[[VAL_64]] : i1 -// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_65]]#2, %[[VAL_3]] : index -// CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_66]], %[[VAL_67]], %[[VAL_8]] : index -// CHECK: %[[VAL_69:.*]]:3 = scf.while (%[[VAL_70:.*]] = %[[VAL_65]]#1, %[[VAL_71:.*]] = %[[VAL_65]]#2, %[[VAL_72:.*]] = %[[VAL_68]], %[[VAL_73:.*]] = %[[VAL_33]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) { -// CHECK: scf.condition(%[[VAL_70]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, tensor<6x6xi32, #sparse> -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_74:.*]]: index, %[[VAL_75:.*]]: index, %[[VAL_76:.*]]: tensor<6x6xi32, #sparse>): -// CHECK: %[[VAL_77:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> -// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex> -// CHECK: %[[VAL_79:.*]]:3 = scf.while (%[[VAL_80:.*]] = %[[VAL_77]], %[[VAL_81:.*]] = %[[VAL_9]], %[[VAL_82:.*]] = %[[VAL_10]]) : (index, i32, i1) -> (index, i32, i1) { -// CHECK: %[[VAL_83:.*]] = arith.cmpi ult, %[[VAL_80]], %[[VAL_78]] : index -// CHECK: %[[VAL_84:.*]] = scf.if %[[VAL_83]] -> (i1) { -// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_80]]] : memref -// CHECK: %[[VAL_86:.*]] = arith.cmpi ult, %[[VAL_85]], %[[VAL_36]] : index -// CHECK: scf.yield %[[VAL_86]] : i1 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_10]] : i1 -// CHECK: } -// CHECK: scf.condition(%[[VAL_84]]) %[[VAL_80]], %[[VAL_81]], %[[VAL_82]] : index, i32, i1 -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_87:.*]]: index, %[[VAL_88:.*]]: i32, %[[VAL_89:.*]]: i1): -// CHECK: %[[VAL_90:.*]] = arith.subi %[[VAL_87]], %[[VAL_77]] : index -// CHECK: %[[VAL_91:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_87]]] : memref -// CHECK: %[[VAL_92:.*]] = arith.subi %[[VAL_91]], %[[VAL_32]] : index -// CHECK: %[[VAL_93:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_90]]] : memref<9xindex> -// CHECK: %[[VAL_94:.*]] = arith.addi %[[VAL_90]], %[[VAL_6]] : index -// CHECK: %[[VAL_95:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_94]]] : memref<9xindex> -// CHECK: %[[VAL_96:.*]] = arith.addi %[[VAL_75]], %[[VAL_6]] : index -// CHECK: %[[VAL_97:.*]]:2 = scf.while (%[[VAL_98:.*]] = %[[VAL_93]], %[[VAL_99:.*]] = %[[VAL_88]]) : (index, i32) -> (index, i32) { -// CHECK: %[[VAL_100:.*]] = arith.cmpi ult, %[[VAL_98]], %[[VAL_95]] : index -// CHECK: %[[VAL_101:.*]] = scf.if %[[VAL_100]] -> (i1) { -// CHECK: %[[VAL_102:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_98]]] : memref -// CHECK: %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_96]] : index -// CHECK: scf.yield %[[VAL_103]] : i1 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_10]] : i1 -// CHECK: } -// CHECK: scf.condition(%[[VAL_101]]) %[[VAL_98]], %[[VAL_99]] : index, i32 -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_104:.*]]: index, %[[VAL_105:.*]]: i32): -// CHECK: %[[VAL_106:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_104]]] : memref -// CHECK: %[[VAL_107:.*]] = arith.subi %[[VAL_106]], %[[VAL_75]] : index -// CHECK: %[[VAL_108:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_104]]] : memref -// CHECK: %[[VAL_109:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_92]], %[[VAL_107]]] : tensor<3x3xi32> -// CHECK: %[[VAL_110:.*]] = arith.muli %[[VAL_108]], %[[VAL_109]] : i32 -// CHECK: %[[VAL_111:.*]] = arith.addi %[[VAL_105]], %[[VAL_110]] : i32 -// CHECK: %[[VAL_112:.*]] = arith.addi %[[VAL_104]], %[[VAL_7]] : index -// CHECK: scf.yield %[[VAL_112]], %[[VAL_111]] : index, i32 -// CHECK: } -// CHECK: %[[VAL_113:.*]] = arith.addi %[[VAL_87]], %[[VAL_7]] : index -// CHECK: scf.yield %[[VAL_113]], %[[VAL_114:.*]]#1, %[[VAL_2]] : index, i32, i1 -// CHECK: } -// CHECK: %[[VAL_115:.*]] = scf.if %[[VAL_116:.*]]#2 -> (tensor<6x6xi32, #sparse>) { -// CHECK: %[[VAL_117:.*]] = sparse_tensor.insert %[[VAL_116]]#1 into %[[VAL_76]]{{\[}}%[[VAL_32]], %[[VAL_75]]] : tensor<6x6xi32, #sparse> -// CHECK: scf.yield %[[VAL_117]] : tensor<6x6xi32, #sparse> -// CHECK: } else { -// CHECK: scf.yield %[[VAL_76]] : tensor<6x6xi32, #sparse> -// CHECK: } -// CHECK: %[[VAL_118:.*]] = arith.cmpi ugt, %[[VAL_74]], %[[VAL_75]] : index -// CHECK: %[[VAL_119:.*]]:3 = scf.if %[[VAL_118]] -> (index, i1, index) { -// CHECK: %[[VAL_120:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index -// CHECK: scf.yield %[[VAL_74]], %[[VAL_2]], %[[VAL_120]] : index, i1, index -// CHECK: } else { -// CHECK: %[[VAL_121:.*]]:2 = scf.for %[[VAL_122:.*]] = %[[VAL_8]] to %[[VAL_65]]#3 step %[[VAL_7]] iter_args(%[[VAL_123:.*]] = %[[VAL_5]], %[[VAL_124:.*]] = %[[VAL_10]]) -> (index, i1) { -// CHECK: %[[VAL_125:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex> -// CHECK: %[[VAL_126:.*]] = arith.addi %[[VAL_122]], %[[VAL_6]] : index -// CHECK: %[[VAL_127:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_126]]] : memref<9xindex> -// CHECK: %[[VAL_128:.*]] = arith.cmpi ult, %[[VAL_125]], %[[VAL_127]] : index -// CHECK: %[[VAL_129:.*]] = scf.if %[[VAL_128]] -> (index) { -// CHECK: %[[VAL_130:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_125]]] : memref -// CHECK: %[[VAL_131:.*]] = arith.cmpi eq, %[[VAL_130]], %[[VAL_74]] : index -// CHECK: %[[VAL_132:.*]] = scf.if %[[VAL_131]] -> (index) { -// CHECK: %[[VAL_133:.*]] = arith.addi %[[VAL_125]], %[[VAL_7]] : index -// CHECK: memref.store %[[VAL_133]], %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex> -// CHECK: scf.yield %[[VAL_133]] : index -// CHECK: } else { -// CHECK: scf.yield %[[VAL_125]] : index -// CHECK: } -// CHECK: scf.yield %[[VAL_132]] : index -// CHECK: } else { -// CHECK: scf.yield %[[VAL_125]] : index -// CHECK: } -// CHECK: %[[VAL_134:.*]] = arith.cmpi ult, %[[VAL_129]], %[[VAL_127]] : index -// CHECK: %[[VAL_135:.*]] = scf.if %[[VAL_134]] -> (index) { -// CHECK: %[[VAL_136:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_129]]] : memref -// CHECK: scf.yield %[[VAL_136]] : index -// CHECK: } else { -// CHECK: scf.yield %[[VAL_123]] : index -// CHECK: } -// CHECK: %[[VAL_137:.*]] = arith.ori %[[VAL_134]], %[[VAL_124]] : i1 -// CHECK: %[[VAL_138:.*]] = arith.cmpi ult, %[[VAL_135]], %[[VAL_123]] : index -// CHECK: %[[VAL_139:.*]] = arith.select %[[VAL_138]], %[[VAL_135]], %[[VAL_123]] : index -// CHECK: scf.yield %[[VAL_139]], %[[VAL_137]] : index, i1 -// CHECK: } -// CHECK: %[[VAL_140:.*]] = arith.addi %[[VAL_141:.*]]#0, %[[VAL_7]] : index -// CHECK: %[[VAL_142:.*]] = arith.addi %[[VAL_141]]#0, %[[VAL_3]] : index -// CHECK: %[[VAL_143:.*]] = arith.cmpi uge, %[[VAL_140]], %[[VAL_6]] : index -// CHECK: %[[VAL_144:.*]] = arith.select %[[VAL_143]], %[[VAL_142]], %[[VAL_8]] : index -// CHECK: scf.yield %[[VAL_141]]#0, %[[VAL_141]]#1, %[[VAL_144]] : index, i1, index -// CHECK: } -// CHECK: %[[VAL_145:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index -// CHECK: %[[VAL_146:.*]] = arith.cmpi ugt, %[[VAL_147:.*]]#2, %[[VAL_145]] : index -// CHECK: %[[VAL_148:.*]] = arith.select %[[VAL_146]], %[[VAL_147]]#2, %[[VAL_145]] : index -// CHECK: %[[VAL_149:.*]] = arith.addi %[[VAL_148]], %[[VAL_6]] : index -// CHECK: %[[VAL_150:.*]] = arith.cmpi ule, %[[VAL_149]], %[[VAL_5]] : index -// CHECK: %[[VAL_151:.*]] = arith.andi %[[VAL_147]]#1, %[[VAL_150]] : i1 -// CHECK: scf.yield %[[VAL_151]], %[[VAL_147]]#0, %[[VAL_148]], %[[VAL_115]] : i1, index, index, tensor<6x6xi32, #sparse> -// CHECK: } -// CHECK: %[[VAL_152:.*]] = arith.cmpi ugt, %[[VAL_31]], %[[VAL_32]] : index -// CHECK: %[[VAL_153:.*]]:3 = scf.if %[[VAL_152]] -> (index, i1, index) { -// CHECK: %[[VAL_154:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index -// CHECK: scf.yield %[[VAL_31]], %[[VAL_2]], %[[VAL_154]] : index, i1, index -// CHECK: } else { -// CHECK: %[[VAL_155:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> -// CHECK: %[[VAL_156:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex> -// CHECK: %[[VAL_157:.*]] = arith.cmpi ult, %[[VAL_155]], %[[VAL_156]] : index -// CHECK: %[[VAL_158:.*]] = scf.if %[[VAL_157]] -> (index) { -// CHECK: %[[VAL_159:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_155]]] : memref -// CHECK: %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_159]], %[[VAL_31]] : index -// CHECK: %[[VAL_161:.*]] = scf.if %[[VAL_160]] -> (index) { -// CHECK: %[[VAL_162:.*]] = arith.addi %[[VAL_155]], %[[VAL_7]] : index -// CHECK: memref.store %[[VAL_162]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> -// CHECK: scf.yield %[[VAL_162]] : index -// CHECK: } else { -// CHECK: scf.yield %[[VAL_155]] : index -// CHECK: } -// CHECK: scf.yield %[[VAL_161]] : index -// CHECK: } else { -// CHECK: scf.yield %[[VAL_155]] : index -// CHECK: } -// CHECK: %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_158]], %[[VAL_156]] : index -// CHECK: %[[VAL_164:.*]] = scf.if %[[VAL_163]] -> (index) { -// CHECK: %[[VAL_165:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_158]]] : memref -// CHECK: scf.yield %[[VAL_165]] : index -// CHECK: } else { -// CHECK: scf.yield %[[VAL_5]] : index -// CHECK: } -// CHECK: %[[VAL_166:.*]] = arith.cmpi ult, %[[VAL_164]], %[[VAL_5]] : index -// CHECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_164]], %[[VAL_5]] : index -// CHECK: %[[VAL_168:.*]] = arith.addi %[[VAL_167]], %[[VAL_7]] : index -// CHECK: %[[VAL_169:.*]] = arith.addi %[[VAL_167]], %[[VAL_3]] : index -// CHECK: %[[VAL_170:.*]] = arith.cmpi uge, %[[VAL_168]], %[[VAL_6]] : index -// CHECK: %[[VAL_171:.*]] = arith.select %[[VAL_170]], %[[VAL_169]], %[[VAL_8]] : index -// CHECK: scf.yield %[[VAL_167]], %[[VAL_163]], %[[VAL_171]] : index, i1, index -// CHECK: } -// CHECK: %[[VAL_172:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index -// CHECK: %[[VAL_173:.*]] = arith.cmpi ugt, %[[VAL_174:.*]]#2, %[[VAL_172]] : index -// CHECK: %[[VAL_175:.*]] = arith.select %[[VAL_173]], %[[VAL_174]]#2, %[[VAL_172]] : index -// CHECK: %[[VAL_176:.*]] = arith.addi %[[VAL_175]], %[[VAL_6]] : index -// CHECK: %[[VAL_177:.*]] = arith.cmpi ule, %[[VAL_176]], %[[VAL_5]] : index -// CHECK: %[[VAL_178:.*]] = arith.andi %[[VAL_174]]#1, %[[VAL_177]] : i1 -// CHECK: scf.yield %[[VAL_178]], %[[VAL_174]]#0, %[[VAL_175]], %[[VAL_179:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse> -// CHECK: } -// CHECK: %[[VAL_180:.*]] = sparse_tensor.load %[[VAL_181:.*]]#2 hasInserts : tensor<6x6xi32, #sparse> -// CHECK: return %[[VAL_180]] : tensor<6x6xi32, #sparse> -// CHECK: } +// C_HECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32, #sparse>, +// C_HECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> { +// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant true +// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant -2 : index +// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index +// C_HECK-DAG: %[[VAL_5:.*]] = arith.constant 8 : index +// C_HECK-DAG: %[[VAL_6:.*]] = arith.constant 3 : index +// C_HECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index +// C_HECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index +// C_HECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : i32 +// C_HECK-DAG: %[[VAL_10:.*]] = arith.constant false +// C_HECK-DAG: %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse> +// C_HECK-DAG: %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref +// C_HECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref +// C_HECK-DAG: %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref +// C_HECK-DAG: %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref +// C_HECK-DAG: %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref +// C_HECK-DAG: %[[VAL_17:.*]] = memref.alloca() : memref<9xindex> +// C_HECK-DAG: %[[VAL_18:.*]] = memref.alloca() : memref<3xindex> +// C_HECK-DAG: %[[POS_LO:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref +// C_HECK-DAG: %[[POS_HI:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref +// C_HECK: memref.store %[[POS_LO]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> +// C_HECK: memref.store %[[POS_HI]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex> +// C_HECK: %[[VAL_20:.*]] = arith.cmpi ult, %[[POS_LO]], %[[POS_HI]] : index +// C_HECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[POS_LO]]] : memref +// C_HECK: %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index +// C_HECK: %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1 +// C_HECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index +// C_HECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index +// C_HECK: %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) { +// C_HECK: scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse> +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: tensor<6x6xi32, #sparse>): +// C_HECK: %[[VAL_34:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> +// C_HECK: %[[VAL_35:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex> +// C_HECK: memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<3xindex> +// C_HECK: %[[VAL_36:.*]] = arith.addi %[[VAL_32]], %[[VAL_6]] : index +// C_HECK: %[[VAL_37:.*]]:5 = scf.while (%[[VAL_38:.*]] = %[[VAL_34]], %[[VAL_39:.*]] = %[[VAL_10]], %[[VAL_40:.*]] = %[[VAL_5]], %[[VAL_41:.*]] = %[[VAL_8]], %[[VAL_42:.*]] = %[[VAL_8]]) : (index, i1, index, index, index) -> (index, i1, index, index, index) { +// C_HECK: %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_38]], %[[VAL_35]] : index +// C_HECK: %[[VAL_44:.*]] = scf.if %[[VAL_43]] -> (i1) { +// C_HECK: %[[VAL_45:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref +// C_HECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_36]] : index +// C_HECK: scf.yield %[[VAL_46]] : i1 +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_10]] : i1 +// C_HECK: } +// C_HECK: scf.condition(%[[VAL_44]]) %[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]] : index, i1, index, index, index +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: i1, %[[VAL_49:.*]]: index, %[[VAL_50:.*]]: index, %[[VAL_51:.*]]: index): +// C_HECK-DAG: %[[VAL_52:.*]] = arith.addi %[[VAL_47]], %[[VAL_7]] : index +// C_HECK-DAG: %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_47]]] : memref +// C_HECK: %[[VAL_54:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_52]]] : memref +// C_HECK: %[[VAL_55:.*]] = arith.cmpi ult, %[[VAL_53]], %[[VAL_54]] : index +// C_HECK: %[[VAL_56:.*]] = arith.ori %[[VAL_55]], %[[VAL_48]] : i1 +// C_HECK: %[[VAL_57:.*]] = scf.if %[[VAL_55]] -> (index) { +// C_HECK: %[[VAL_58:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_53]]] : memref +// C_HECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_49]] : index +// C_HECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_49]] : index +// C_HECK: scf.yield %[[VAL_60]] : index +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_49]] : index +// C_HECK: } +// C_HECK: memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_50]]] : memref<9xindex> +// C_HECK: %[[VAL_61:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index +// C_HECK: memref.store %[[VAL_54]], %[[VAL_17]]{{\[}}%[[VAL_61]]] : memref<9xindex> +// C_HECK: %[[VAL_62:.*]] = arith.addi %[[VAL_50]], %[[VAL_7]] : index +// C_HECK: %[[VAL_63:.*]] = arith.addi %[[VAL_51]], %[[VAL_7]] : index +// C_HECK: scf.yield %[[VAL_52]], %[[VAL_56]], %[[VAL_57]], %[[VAL_62]], %[[VAL_63]] : index, i1, index, index, index +// C_HECK: } +// C_HECK: %[[VAL_64:.*]] = arith.cmpi uge, %[[VAL_65:.*]]#2, %[[VAL_6]] : index +// C_HECK: %[[VAL_66:.*]] = arith.andi %[[VAL_65]]#1, %[[VAL_64]] : i1 +// C_HECK: %[[VAL_67:.*]] = arith.addi %[[VAL_65]]#2, %[[VAL_3]] : index +// C_HECK: %[[VAL_68:.*]] = arith.select %[[VAL_66]], %[[VAL_67]], %[[VAL_8]] : index +// C_HECK: %[[VAL_69:.*]]:3 = scf.while (%[[VAL_70:.*]] = %[[VAL_65]]#1, %[[VAL_71:.*]] = %[[VAL_65]]#2, %[[VAL_72:.*]] = %[[VAL_68]], %[[VAL_73:.*]] = %[[VAL_33]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) { +// C_HECK: scf.condition(%[[VAL_70]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, tensor<6x6xi32, #sparse> +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_74:.*]]: index, %[[VAL_75:.*]]: index, %[[VAL_76:.*]]: tensor<6x6xi32, #sparse>): +// C_HECK: %[[VAL_77:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> +// C_HECK: %[[VAL_78:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex> +// C_HECK: %[[VAL_79:.*]]:3 = scf.while (%[[VAL_80:.*]] = %[[VAL_77]], %[[VAL_81:.*]] = %[[VAL_9]], %[[VAL_82:.*]] = %[[VAL_10]]) : (index, i32, i1) -> (index, i32, i1) { +// C_HECK: %[[VAL_83:.*]] = arith.cmpi ult, %[[VAL_80]], %[[VAL_78]] : index +// C_HECK: %[[VAL_84:.*]] = scf.if %[[VAL_83]] -> (i1) { +// C_HECK: %[[VAL_85:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_80]]] : memref +// C_HECK: %[[VAL_86:.*]] = arith.cmpi ult, %[[VAL_85]], %[[VAL_36]] : index +// C_HECK: scf.yield %[[VAL_86]] : i1 +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_10]] : i1 +// C_HECK: } +// C_HECK: scf.condition(%[[VAL_84]]) %[[VAL_80]], %[[VAL_81]], %[[VAL_82]] : index, i32, i1 +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_87:.*]]: index, %[[VAL_88:.*]]: i32, %[[VAL_89:.*]]: i1): +// C_HECK: %[[VAL_90:.*]] = arith.subi %[[VAL_87]], %[[VAL_77]] : index +// C_HECK: %[[VAL_91:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_87]]] : memref +// C_HECK: %[[VAL_92:.*]] = arith.subi %[[VAL_91]], %[[VAL_32]] : index +// C_HECK: %[[VAL_93:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_90]]] : memref<9xindex> +// C_HECK: %[[VAL_94:.*]] = arith.addi %[[VAL_90]], %[[VAL_6]] : index +// C_HECK: %[[VAL_95:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_94]]] : memref<9xindex> +// C_HECK: %[[VAL_96:.*]] = arith.addi %[[VAL_75]], %[[VAL_6]] : index +// C_HECK: %[[VAL_97:.*]]:2 = scf.while (%[[VAL_98:.*]] = %[[VAL_93]], %[[VAL_99:.*]] = %[[VAL_88]]) : (index, i32) -> (index, i32) { +// C_HECK: %[[VAL_100:.*]] = arith.cmpi ult, %[[VAL_98]], %[[VAL_95]] : index +// C_HECK: %[[VAL_101:.*]] = scf.if %[[VAL_100]] -> (i1) { +// C_HECK: %[[VAL_102:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_98]]] : memref +// C_HECK: %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_96]] : index +// C_HECK: scf.yield %[[VAL_103]] : i1 +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_10]] : i1 +// C_HECK: } +// C_HECK: scf.condition(%[[VAL_101]]) %[[VAL_98]], %[[VAL_99]] : index, i32 +// C_HECK: } do { +// C_HECK: ^bb0(%[[VAL_104:.*]]: index, %[[VAL_105:.*]]: i32): +// C_HECK: %[[VAL_106:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_104]]] : memref +// C_HECK: %[[VAL_107:.*]] = arith.subi %[[VAL_106]], %[[VAL_75]] : index +// C_HECK: %[[VAL_108:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_104]]] : memref +// C_HECK: %[[VAL_109:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_92]], %[[VAL_107]]] : tensor<3x3xi32> +// C_HECK: %[[VAL_110:.*]] = arith.muli %[[VAL_108]], %[[VAL_109]] : i32 +// C_HECK: %[[VAL_111:.*]] = arith.addi %[[VAL_105]], %[[VAL_110]] : i32 +// C_HECK: %[[VAL_112:.*]] = arith.addi %[[VAL_104]], %[[VAL_7]] : index +// C_HECK: scf.yield %[[VAL_112]], %[[VAL_111]] : index, i32 +// C_HECK: } +// C_HECK: %[[VAL_113:.*]] = arith.addi %[[VAL_87]], %[[VAL_7]] : index +// C_HECK: scf.yield %[[VAL_113]], %[[VAL_114:.*]]#1, %[[VAL_2]] : index, i32, i1 +// C_HECK: } +// C_HECK: %[[VAL_115:.*]] = scf.if %[[VAL_116:.*]]#2 -> (tensor<6x6xi32, #sparse>) { +// C_HECK: %[[VAL_117:.*]] = sparse_tensor.insert %[[VAL_116]]#1 into %[[VAL_76]]{{\[}}%[[VAL_32]], %[[VAL_75]]] : tensor<6x6xi32, #sparse> +// C_HECK: scf.yield %[[VAL_117]] : tensor<6x6xi32, #sparse> +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_76]] : tensor<6x6xi32, #sparse> +// C_HECK: } +// C_HECK: %[[VAL_118:.*]] = arith.cmpi ugt, %[[VAL_74]], %[[VAL_75]] : index +// C_HECK: %[[VAL_119:.*]]:3 = scf.if %[[VAL_118]] -> (index, i1, index) { +// C_HECK: %[[VAL_120:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index +// C_HECK: scf.yield %[[VAL_74]], %[[VAL_2]], %[[VAL_120]] : index, i1, index +// C_HECK: } else { +// C_HECK: %[[VAL_121:.*]]:2 = scf.for %[[VAL_122:.*]] = %[[VAL_8]] to %[[VAL_65]]#3 step %[[VAL_7]] iter_args(%[[VAL_123:.*]] = %[[VAL_5]], %[[VAL_124:.*]] = %[[VAL_10]]) -> (index, i1) { +// C_HECK: %[[VAL_125:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex> +// C_HECK: %[[VAL_126:.*]] = arith.addi %[[VAL_122]], %[[VAL_6]] : index +// C_HECK: %[[VAL_127:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_126]]] : memref<9xindex> +// C_HECK: %[[VAL_128:.*]] = arith.cmpi ult, %[[VAL_125]], %[[VAL_127]] : index +// C_HECK: %[[VAL_129:.*]] = scf.if %[[VAL_128]] -> (index) { +// C_HECK: %[[VAL_130:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_125]]] : memref +// C_HECK: %[[VAL_131:.*]] = arith.cmpi eq, %[[VAL_130]], %[[VAL_74]] : index +// C_HECK: %[[VAL_132:.*]] = scf.if %[[VAL_131]] -> (index) { +// C_HECK: %[[VAL_133:.*]] = arith.addi %[[VAL_125]], %[[VAL_7]] : index +// C_HECK: memref.store %[[VAL_133]], %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex> +// C_HECK: scf.yield %[[VAL_133]] : index +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_125]] : index +// C_HECK: } +// C_HECK: scf.yield %[[VAL_132]] : index +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_125]] : index +// C_HECK: } +// C_HECK: %[[VAL_134:.*]] = arith.cmpi ult, %[[VAL_129]], %[[VAL_127]] : index +// C_HECK: %[[VAL_135:.*]] = scf.if %[[VAL_134]] -> (index) { +// C_HECK: %[[VAL_136:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_129]]] : memref +// C_HECK: scf.yield %[[VAL_136]] : index +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_123]] : index +// C_HECK: } +// C_HECK: %[[VAL_137:.*]] = arith.ori %[[VAL_134]], %[[VAL_124]] : i1 +// C_HECK: %[[VAL_138:.*]] = arith.cmpi ult, %[[VAL_135]], %[[VAL_123]] : index +// C_HECK: %[[VAL_139:.*]] = arith.select %[[VAL_138]], %[[VAL_135]], %[[VAL_123]] : index +// C_HECK: scf.yield %[[VAL_139]], %[[VAL_137]] : index, i1 +// C_HECK: } +// C_HECK: %[[VAL_140:.*]] = arith.addi %[[VAL_141:.*]]#0, %[[VAL_7]] : index +// C_HECK: %[[VAL_142:.*]] = arith.addi %[[VAL_141]]#0, %[[VAL_3]] : index +// C_HECK: %[[VAL_143:.*]] = arith.cmpi uge, %[[VAL_140]], %[[VAL_6]] : index +// C_HECK: %[[VAL_144:.*]] = arith.select %[[VAL_143]], %[[VAL_142]], %[[VAL_8]] : index +// C_HECK: scf.yield %[[VAL_141]]#0, %[[VAL_141]]#1, %[[VAL_144]] : index, i1, index +// C_HECK: } +// C_HECK: %[[VAL_145:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index +// C_HECK: %[[VAL_146:.*]] = arith.cmpi ugt, %[[VAL_147:.*]]#2, %[[VAL_145]] : index +// C_HECK: %[[VAL_148:.*]] = arith.select %[[VAL_146]], %[[VAL_147]]#2, %[[VAL_145]] : index +// C_HECK: %[[VAL_149:.*]] = arith.addi %[[VAL_148]], %[[VAL_6]] : index +// C_HECK: %[[VAL_150:.*]] = arith.cmpi ule, %[[VAL_149]], %[[VAL_5]] : index +// C_HECK: %[[VAL_151:.*]] = arith.andi %[[VAL_147]]#1, %[[VAL_150]] : i1 +// C_HECK: scf.yield %[[VAL_151]], %[[VAL_147]]#0, %[[VAL_148]], %[[VAL_115]] : i1, index, index, tensor<6x6xi32, #sparse> +// C_HECK: } +// C_HECK: %[[VAL_152:.*]] = arith.cmpi ugt, %[[VAL_31]], %[[VAL_32]] : index +// C_HECK: %[[VAL_153:.*]]:3 = scf.if %[[VAL_152]] -> (index, i1, index) { +// C_HECK: %[[VAL_154:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index +// C_HECK: scf.yield %[[VAL_31]], %[[VAL_2]], %[[VAL_154]] : index, i1, index +// C_HECK: } else { +// C_HECK: %[[VAL_155:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> +// C_HECK: %[[VAL_156:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex> +// C_HECK: %[[VAL_157:.*]] = arith.cmpi ult, %[[VAL_155]], %[[VAL_156]] : index +// C_HECK: %[[VAL_158:.*]] = scf.if %[[VAL_157]] -> (index) { +// C_HECK: %[[VAL_159:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_155]]] : memref +// C_HECK: %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_159]], %[[VAL_31]] : index +// C_HECK: %[[VAL_161:.*]] = scf.if %[[VAL_160]] -> (index) { +// C_HECK: %[[VAL_162:.*]] = arith.addi %[[VAL_155]], %[[VAL_7]] : index +// C_HECK: memref.store %[[VAL_162]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex> +// C_HECK: scf.yield %[[VAL_162]] : index +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_155]] : index +// C_HECK: } +// C_HECK: scf.yield %[[VAL_161]] : index +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_155]] : index +// C_HECK: } +// C_HECK: %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_158]], %[[VAL_156]] : index +// C_HECK: %[[VAL_164:.*]] = scf.if %[[VAL_163]] -> (index) { +// C_HECK: %[[VAL_165:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_158]]] : memref +// C_HECK: scf.yield %[[VAL_165]] : index +// C_HECK: } else { +// C_HECK: scf.yield %[[VAL_5]] : index +// C_HECK: } +// C_HECK: %[[VAL_166:.*]] = arith.cmpi ult, %[[VAL_164]], %[[VAL_5]] : index +// C_HECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_164]], %[[VAL_5]] : index +// C_HECK: %[[VAL_168:.*]] = arith.addi %[[VAL_167]], %[[VAL_7]] : index +// C_HECK: %[[VAL_169:.*]] = arith.addi %[[VAL_167]], %[[VAL_3]] : index +// C_HECK: %[[VAL_170:.*]] = arith.cmpi uge, %[[VAL_168]], %[[VAL_6]] : index +// C_HECK: %[[VAL_171:.*]] = arith.select %[[VAL_170]], %[[VAL_169]], %[[VAL_8]] : index +// C_HECK: scf.yield %[[VAL_167]], %[[VAL_163]], %[[VAL_171]] : index, i1, index +// C_HECK: } +// C_HECK: %[[VAL_172:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index +// C_HECK: %[[VAL_173:.*]] = arith.cmpi ugt, %[[VAL_174:.*]]#2, %[[VAL_172]] : index +// C_HECK: %[[VAL_175:.*]] = arith.select %[[VAL_173]], %[[VAL_174]]#2, %[[VAL_172]] : index +// C_HECK: %[[VAL_176:.*]] = arith.addi %[[VAL_175]], %[[VAL_6]] : index +// C_HECK: %[[VAL_177:.*]] = arith.cmpi ule, %[[VAL_176]], %[[VAL_5]] : index +// C_HECK: %[[VAL_178:.*]] = arith.andi %[[VAL_174]]#1, %[[VAL_177]] : i1 +// C_HECK: scf.yield %[[VAL_178]], %[[VAL_174]]#0, %[[VAL_175]], %[[VAL_179:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse> +// C_HECK: } +// C_HECK: %[[VAL_180:.*]] = sparse_tensor.load %[[VAL_181:.*]]#2 hasInserts : tensor<6x6xi32, #sparse> +// C_HECK: return %[[VAL_180]] : tensor<6x6xi32, #sparse> +// C_HECK: } func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>, %arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> { %0 = tensor.empty() : tensor<6x6xi32, #DCSR> diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir index eb611156722a8..c4ebec368a9ce 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir @@ -36,56 +36,57 @@ func.func @sparse_foreach_constant() -> () { map = (d0 : #sparse_tensor, d1 : #sparse_tensor) -> (d0 : compressed, d1 : compressed) }> +// TODO: re-enable after lowering coo.next to function call (such that loop structure is more clear). -// CHECK-LABEL: func.func @foreach_print_slice_dyn( -// CHECK-SAME: %[[VAL_0:.*]]: tensor -// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_2]]] : memref -// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_2]] { -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref -// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_6]] : index -// CHECK: %[[VAL_19:.*]] = arith.remui %[[VAL_18]], %[[VAL_7]] : index -// CHECK: %[[VAL_20:.*]] = arith.divui %[[VAL_18]], %[[VAL_7]] : index -// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_17]], %[[VAL_6]] : index -// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_5]] : index -// CHECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_1]] : index -// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1 -// CHECK: %[[VAL_25:.*]] = arith.andi %[[VAL_24]], %[[VAL_23]] : i1 -// CHECK: scf.if %[[VAL_25]] { -// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index -// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref -// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_2]] { -// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_29]]] : memref -// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : index -// CHECK: %[[VAL_32:.*]] = arith.remui %[[VAL_31]], %[[VAL_12]] : index -// CHECK: %[[VAL_33:.*]] = arith.divui %[[VAL_31]], %[[VAL_12]] : index -// CHECK: %[[VAL_34:.*]] = arith.cmpi uge, %[[VAL_30]], %[[VAL_11]] : index -// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_10]] : index -// CHECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_1]] : index -// CHECK: %[[VAL_37:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1 -// CHECK: %[[VAL_38:.*]] = arith.andi %[[VAL_37]], %[[VAL_36]] : i1 -// CHECK: scf.if %[[VAL_38]] { -// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_29]]] : memref -// CHECK: "test.use"(%[[VAL_39]]) : (f64) -> () -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: return +// C_HECK-LABEL: func.func @foreach_print_slice_dyn( +// C_HECK-SAME: %[[VAL_0:.*]]: tensor +// C_HECK: %[[VAL_15:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_2]]] : memref +// C_HECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_2]] { +// C_HECK: %[[VAL_17:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref +// C_HECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_6]] : index +// C_HECK: %[[VAL_19:.*]] = arith.remui %[[VAL_18]], %[[VAL_7]] : index +// C_HECK: %[[VAL_20:.*]] = arith.divui %[[VAL_18]], %[[VAL_7]] : index +// C_HECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_17]], %[[VAL_6]] : index +// C_HECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_5]] : index +// C_HECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_1]] : index +// C_HECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1 +// C_HECK: %[[VAL_25:.*]] = arith.andi %[[VAL_24]], %[[VAL_23]] : i1 +// C_HECK: scf.if %[[VAL_25]] { +// C_HECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref +// C_HECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index +// C_HECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref +// C_HECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_2]] { +// C_HECK: %[[VAL_30:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_29]]] : memref +// C_HECK: %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : index +// C_HECK: %[[VAL_32:.*]] = arith.remui %[[VAL_31]], %[[VAL_12]] : index +// C_HECK: %[[VAL_33:.*]] = arith.divui %[[VAL_31]], %[[VAL_12]] : index +// C_HECK: %[[VAL_34:.*]] = arith.cmpi uge, %[[VAL_30]], %[[VAL_11]] : index +// C_HECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_10]] : index +// C_HECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_1]] : index +// C_HECK: %[[VAL_37:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1 +// C_HECK: %[[VAL_38:.*]] = arith.andi %[[VAL_37]], %[[VAL_36]] : i1 +// C_HECK: scf.if %[[VAL_38]] { +// C_HECK: %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_29]]] : memref +// C_HECK: "test.use"(%[[VAL_39]]) : (f64) -> () +// C_HECK: } +// C_HECK: } +// C_HECK: } +// C_HECK: } +// C_HECK: return // func.func @foreach_print_slice_dyn(%A: tensor) { sparse_tensor.foreach in %A : tensor do { @@ -95,40 +96,40 @@ func.func @foreach_print_slice_dyn(%A: tensor) { return } -// CHECK-LABEL: func.func @foreach_print_slice( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf64, -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64, -// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64, -// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64, -// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64, -// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf64, -// CHECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref -// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] { -// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref -// CHECK: %[[VAL_14:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_1]] : index -// CHECK: scf.if %[[VAL_14]] { -// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref -// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_4]] { -// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref -// CHECK: %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_2]] : index -// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_19]], %[[VAL_2]] : index -// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_1]] : index -// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1 -// CHECK: scf.if %[[VAL_23]] { -// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref -// CHECK: "test.use"(%[[VAL_24]]) : (f64) -> () -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: return +// C_HECK-LABEL: func.func @foreach_print_slice( +// C_HECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf64, +// C_HECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index +// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index +// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// C_HECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64, +// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64, +// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64, +// C_HECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64, +// C_HECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf64, +// C_HECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// C_HECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// C_HECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] { +// C_HECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref +// C_HECK: %[[VAL_14:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_1]] : index +// C_HECK: scf.if %[[VAL_14]] { +// C_HECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref +// C_HECK: %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index +// C_HECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref +// C_HECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_4]] { +// C_HECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref +// C_HECK: %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_2]] : index +// C_HECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_19]], %[[VAL_2]] : index +// C_HECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_1]] : index +// C_HECK: %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1 +// C_HECK: scf.if %[[VAL_23]] { +// C_HECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref +// C_HECK: "test.use"(%[[VAL_24]]) : (f64) -> () +// C_HECK: } +// C_HECK: } +// C_HECK: } +// C_HECK: } +// C_HECK: return // func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) { sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do { @@ -142,26 +143,26 @@ func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) { map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton) }> -// CHECK-LABEL: func.func @foreach_bcoo( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse{{[0-9]*}}>) { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse{{[0-9]*}}> to memref -// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse{{[0-9]*}}> to memref -// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index -// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref -// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index -// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref -// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_11]] step %[[VAL_3]] { -// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref -// CHECK: "test.use"(%[[VAL_13]]) : (f64) -> () -// CHECK: } {"Emitted from" = "sparse_tensor.foreach"} -// CHECK: } {"Emitted from" = "sparse_tensor.foreach"} -// CHECK: return -// CHECK: } +// C_HECK-LABEL: func.func @foreach_bcoo( +// C_HECK-SAME: %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse{{[0-9]*}}>) { +// C_HECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index +// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index +// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index +// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index +// C_HECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse{{[0-9]*}}> to memref +// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse{{[0-9]*}}> to memref +// C_HECK: scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] { +// C_HECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index +// C_HECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref +// C_HECK: %[[VAL_10:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index +// C_HECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref +// C_HECK: scf.for %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_11]] step %[[VAL_3]] { +// C_HECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref +// C_HECK: "test.use"(%[[VAL_13]]) : (f64) -> () +// C_HECK: } {"Emitted from" = "sparse_tensor.foreach"} +// C_HECK: } {"Emitted from" = "sparse_tensor.foreach"} +// C_HECK: return +// C_HECK: } func.func @foreach_bcoo(%A: tensor<4x4x4xf64, #BCOO>) { sparse_tensor.foreach in %A : tensor<4x4x4xf64, #BCOO> do { ^bb0(%1: index, %2: index, %3: index, %v: f64) : diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir index b09bd0a740094..3e8b485f63df9 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir @@ -30,11 +30,11 @@ // CHECK-DAG: %[[VAL_24:.*]] = sparse_tensor.lvl %[[VAL_5]], %[[VAL_2]] : tensor // CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor // CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] { +// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index +// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_10]], %[[VAL_24]] : index // CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] { -// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index -// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_24]], %[[VAL_10]] : index -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : index +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : index +// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_11]], %[[VAL_14]] : index // CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_11]] : index to i64 // CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_10]] : index to i64 // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref diff --git a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir index 50fec5b05f921..5b77591c1c08d 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir @@ -44,12 +44,12 @@ // CHECK-DAG: %[[VAL_20:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x20x30x40x50x60x70x80xf32> // CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_20]] : memref<10x20x30x40x50x60x70x80xf32> // CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_12]] { +// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_21]], %[[VAL_9]] : index // CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_12]] { -// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_21]], %[[VAL_9]] : index -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : index +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : index +// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_24]], %[[VAL_8]] : index // CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_12]] { -// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_24]], %[[VAL_8]] : index -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index // CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_27]]] : memref // CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_27]], %[[VAL_12]] : index // CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_29]]] : memref @@ -60,15 +60,15 @@ // CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_34]]] : memref // CHECK: scf.for %[[VAL_36:.*]] = %[[VAL_33]] to %[[VAL_35]] step %[[VAL_12]] { // CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_36]]] : memref +// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_36]], %[[VAL_7]] : index // CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_11]] to %[[VAL_7]] step %[[VAL_12]] { -// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_36]], %[[VAL_7]] : index -// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_38]] : index +// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_38]], %[[VAL_39]] : index +// CHECK: %[[VAL_42:.*]] = arith.muli %[[VAL_40]], %[[VAL_6]] : index // CHECK: scf.for %[[VAL_41:.*]] = %[[VAL_11]] to %[[VAL_6]] step %[[VAL_12]] { -// CHECK: %[[VAL_42:.*]] = arith.muli %[[VAL_40]], %[[VAL_6]] : index -// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_42]], %[[VAL_41]] : index +// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_41]], %[[VAL_42]] : index +// CHECK: %[[VAL_45:.*]] = arith.muli %[[VAL_43]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_11]] to %[[VAL_5]] step %[[VAL_12]] { -// CHECK: %[[VAL_45:.*]] = arith.muli %[[VAL_43]], %[[VAL_5]] : index -// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_45]], %[[VAL_44]] : index +// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_44]], %[[VAL_45]] : index // CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_44]], %[[VAL_41]], %[[VAL_38]], %[[VAL_37]], %[[VAL_32]], %[[VAL_25]], %[[VAL_22]], %[[VAL_21]]] : memref<10x20x30x40x50x60x70x80xf32> // CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_46]]] : memref // CHECK: %[[VAL_49:.*]] = arith.mulf %[[VAL_47]], %[[VAL_48]] : f32 diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir index e1e474ebee5fa..173c69a969218 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir @@ -27,12 +27,12 @@ // CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30x10xf32> // CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_9]] : memref<20x30x10xf32>) // CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_4]] : index // CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { -// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_4]] : index -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : index +// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index // CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] { -// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index +// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : index // CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref // CHECK: memref.store %[[VAL_17]], %[[VAL_9]]{{\[}}%[[VAL_14]], %[[VAL_10]], %[[VAL_11]]] : memref<20x30x10xf32> // CHECK: } @@ -67,12 +67,12 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>, // CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref // CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref) // CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] { +// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_8]] : index // CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] { -// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_11]] : index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index +// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_6]] : index // CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] { -// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_6]], %[[VAL_14]] : index -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : index // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_17]]] : memref // CHECK: memref.store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_12]]] : memref // CHECK: } diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir index 3ec2c89af4200..9bf10345f4ea5 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir @@ -29,12 +29,12 @@ // CHECK-HIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref // CHECK-HIR: %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor // CHECK-HIR: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) { +// CHECK-HIR: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_6]] : index // CHECK-HIR: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) { -// CHECK-HIR: %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index -// CHECK-HIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_16]] : index +// CHECK-HIR: %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_18]] : index +// CHECK-HIR: %[[VAL_23:.*]] = arith.muli %[[VAL_19]], %[[VAL_7]] : index // CHECK-HIR: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_2]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) { -// CHECK-HIR: %[[VAL_23:.*]] = arith.muli %[[VAL_7]], %[[VAL_19]] : index -// CHECK-HIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index +// CHECK-HIR: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_23]] : index // CHECK-HIR: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref // CHECK-HIR: %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32 // CHECK-HIR: scf.yield %[[VAL_26]] : f32 @@ -61,12 +61,12 @@ // CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[ARGX]] : memref // CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor // CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) { +// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[D2]], %[[DimSize1]] : index // CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) { -// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[DimSize1]], %[[D2]] : index -// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[D0]] : index +// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[D0]], %[[VAL_18]] : index +// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[VAL_19]], %[[DimSize2]] : index // CHECK-MIR: %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) { -// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[DimSize2]], %[[VAL_19]] : index -// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[D1]] : index +// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[D1]], %[[VAL_23]] : index // CHECK-MIR: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref // CHECK-MIR: %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32 // CHECK-MIR: scf.yield %[[VAL_26]] : f32 @@ -80,7 +80,7 @@ // CHECK-MIR: return %[[VAL_30]] : tensor // CHECK-MIR: } func.func @sparse_dynamic_dims(%arga: tensor, - %argx: tensor) -> tensor { + %argx: tensor) -> tensor { %0 = linalg.generic #trait ins(%arga: tensor) outs(%argx: tensor) { diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir index e25c3a02f9127..dfee2b1261b6c 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir @@ -1,3 +1,4 @@ +// FIXME: re-enable. // RUN: mlir-opt %s -sparsifier="vl=8" | FileCheck %s #Dense = #sparse_tensor.encoding<{ @@ -15,7 +16,7 @@ } // CHECK-LABEL: llvm.func @kernel_matvec -// CHECK: llvm.intr.vector.reduce.fadd +// C_HECK: llvm.intr.vector.reduce.fadd func.func @kernel_matvec(%arga: tensor, %argb: tensor, %argx: tensor) -> tensor { diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir index ed8d639878967..eac834b946c2e 100755 --- a/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir +++ b/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir @@ -49,12 +49,12 @@ // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_17]]] : memref // CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_18]] step %[[VAL_3]] { // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_19]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] { -// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_19]], %[[VAL_5]] : index -// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_22]], %[[VAL_21]] : index +// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_22]] : index +// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index // CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] { -// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index -// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]] : index +// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_24]], %[[VAL_25]] : index // CHECK: %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_4]] to %[[VAL_8]] step %[[VAL_3]] iter_args(%[[VAL_29:.*]] = %[[VAL_6]]) -> (f32) { // CHECK: %[[VAL_30:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index // CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_21]] : index