@@ -357,91 +357,24 @@ struct ScalarizeWhileConditionProducers
357357};
358358} // namespace
359359
360- // / Check if the add op is a valid induction variable increment.
361- static bool matchInductionVariableIncrement (stablehlo::AddOp op,
362- scf::WhileOp parentWhile) {
363- Value lhs = op.getLhs ();
364- Value rhs = op.getRhs ();
365- if (matchPattern (lhs, m_Constant ()) || matchPattern (rhs, m_Constant ()))
366- return true ;
367- Region *whileRegion = parentWhile->getParentRegion ();
368- return lhs.getParentRegion ()->isAncestor (whileRegion) ||
369- rhs.getParentRegion ()->isAncestor (whileRegion);
370- }
371-
372360namespace {
373361// / Scalarize any `stablehlo.add` operations in the 'after' region of
374362// / a scf.while op.
375- struct ScalarizeStablehloAddOp : public OpRewritePattern <stablehlo::AddOp > {
376- using OpRewritePattern<stablehlo::AddOp >::OpRewritePattern;
377- LogicalResult matchAndRewrite (stablehlo::AddOp op,
363+ struct ScalarizeStablehloAddOp : public OpRewritePattern <tensor::ExtractOp > {
364+ using OpRewritePattern<tensor::ExtractOp >::OpRewritePattern;
365+ LogicalResult matchAndRewrite (tensor::ExtractOp op,
378366 PatternRewriter &rewriter) const override {
379- if (!op->hasOneUse ())
380- return rewriter.notifyMatchFailure (
381- op, " op has more than one use, cannot scalarize" );
382- auto extractUser = dyn_cast<tensor::ExtractOp>(*op->user_begin ());
383- if (!extractUser || !extractUser->hasOneUse () ||
384- !isa<scf::YieldOp>(*extractUser->user_begin ()))
385- return rewriter.notifyMatchFailure (
386- op, " op result is not extracted and yielded from region" );
387-
388- auto scfWhile = extractUser->getParentOfType <scf::WhileOp>();
389- if (!scfWhile || scfWhile.getAfter () != op->getParentRegion ())
390- return rewriter.notifyMatchFailure (
391- op, " op is not in the after region of a scf.while op" );
392-
393- // One operand must be a constant or defined above in order to be
394- // considered as the loop step.
395- if (!matchInductionVariableIncrement (op, scfWhile))
396- return rewriter.notifyMatchFailure (
397- op, " op is not a valid induction variable increment" );
398-
399- // Find a block argument that has been scalarized.
400- auto findBlockArgument = [](Value v) -> BlockArgument {
401- Value source{};
402- if (matchPattern (v,
403- m_Op<tensor::FromElementsOp>(matchers::m_Any (&source))))
404- return dyn_cast<BlockArgument>(source);
405- return {};
406- };
407- BlockArgument arg = findBlockArgument (op.getLhs ());
408- if (!arg)
409- arg = findBlockArgument (op.getRhs ());
410- if (!arg || arg.getParentRegion () != scfWhile.getAfter ())
411- return rewriter.notifyMatchFailure (
412- op, " could not find block argument in after region" );
413-
414- // Check that the corresponding block argument in the `before` region feeds
415- // into a comparison.
416- Region &before = scfWhile.getBefore ();
417- if (arg.getArgNumber () >= before.getNumArguments () ||
418- before.getArgument (arg.getArgNumber ()).getType () != arg.getType ())
419- return rewriter.notifyMatchFailure (
420- op, " could not find block argument in before region" );
421- auto beforeArg = before.getArgument (arg.getArgNumber ());
422- if (!llvm::all_of (beforeArg.getUsers (),
423- llvm::IsaPred<scf::ConditionOp, arith::CmpIOp>))
424- return rewriter.notifyMatchFailure (
425- op, " block argument is not consumed by a comparison op" );
426-
427- // Check that the before region has a block argument in the same position
428- // and is consumed by a comparison op.
429- RankedTensorType rtt = op.getType ();
430- Type elementType = rtt.getElementType ();
431- if (!rtt.hasStaticShape () || rtt.getNumElements () != 1 ||
432- !elementType.isSignlessIntOrIndex ())
433- return rewriter.notifyMatchFailure (op, " op is not a scalar add op" );
434-
435- auto scalarOperands = llvm::map_to_vector (op.getOperands (), [&](Value v) {
436- return extractScalarFromTensorValue (rewriter, v);
437- });
438-
439- auto scalarAdd =
440- stablehlo::StablehloOpToStdScalarOp::mapOp<stablehlo::AddOp>(
441- op, elementType, scalarOperands, &rewriter);
442- auto fromElements =
443- rewriter.create <tensor::FromElementsOp>(op.getLoc (), rtt, scalarAdd);
444- rewriter.replaceOp (op, fromElements);
367+ auto addOp = op.getTensor ().getDefiningOp <stablehlo::AddOp>();
368+ if (!addOp || !addOp.getType ().hasStaticShape () ||
369+ addOp.getType ().getNumElements () != 1 )
370+ return failure ();
371+ rewriter.setInsertionPoint (addOp);
372+ SmallVector<Value> scalarOperands;
373+ for (Value operand : addOp.getOperands ())
374+ scalarOperands.push_back (extractScalarFromTensorValue (rewriter, operand));
375+ auto scalarAdd = stablehlo::StablehloOpToStdScalarOp::mapOp (
376+ addOp, addOp.getType ().getElementType (), scalarOperands, &rewriter);
377+ rewriter.replaceOp (op, scalarAdd);
445378 return success ();
446379 }
447380};
@@ -453,9 +386,7 @@ struct ScalarizeStablehloAddOp : public OpRewritePattern<stablehlo::AddOp> {
453386// / for loop. It will have a user like `stablehlo.compare` or `tensor.extract`.
454387static bool shouldScalarizeWhileBeforeArg (BlockArgument arg, Value initOperand,
455388 Value yieldOperand) {
456- return cast<RankedTensorType>(arg.getType ())
457- .getElementType ()
458- .isSignlessIntOrIndex () &&
389+ return cast<RankedTensorType>(arg.getType ()).getElementType () &&
459390 llvm::count_if (arg.getUsers (),
460391 llvm::IsaPred<stablehlo::CompareOp, arith::CmpIOp,
461392 tensor::ExtractOp>) >= 1 ;
@@ -473,17 +404,17 @@ static bool shouldScalarizeWhileAfterArg(BlockArgument arg, Value condOperand,
473404 if (before.getNumArguments () <= arg.getArgNumber () ||
474405 before.getArgument (arg.getArgNumber ()).getType () !=
475406 rtt.getElementType () ||
476- !llvm::all_of (before.getArgument (arg.getArgNumber ()).getUsers (),
477- llvm::IsaPred<arith::CmpIOp, tensor::FromElementsOp>))
407+ !llvm::all_of (
408+ before.getArgument (arg.getArgNumber ()).getUsers (),
409+ llvm::IsaPred<arith::CmpIOp, arith::CmpFOp, tensor::FromElementsOp>))
478410 return false ;
479411
480412 auto condProducer = condOperand.getDefiningOp <tensor::FromElementsOp>();
481413 if (!condProducer || condProducer.getElements ().size () != 1 ||
482414 !isa<BlockArgument>(condProducer.getElements ().front ()))
483415 return false ;
484416
485- return rtt.getElementType ().isSignlessIntOrIndex () &&
486- llvm::count_if (arg.getUsers (),
417+ return llvm::count_if (arg.getUsers (),
487418 llvm::IsaPred<stablehlo::AddOp, arith::AddIOp,
488419 tensor::ExtractOp>) >= 1 ;
489420}
0 commit comments