diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index ada9539e87121..70092908d961f 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -17,6 +17,7 @@ #include #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" namespace mlir { class OpBuilder; @@ -259,18 +260,18 @@ struct BufferizationOptions { std::function; /// Initializer function for analysis state. using AnalysisStateInitFn = std::function; - /// Tensor -> MemRef type converter. - /// Parameters: tensor type, memory space, func op, bufferization options + /// TensorLike -> BufferLike type converter. + /// Parameters: tensor like type, memory space, func op, bufferization options using FunctionArgTypeConverterFn = - std::function; - /// Tensor -> MemRef type converter. + /// TensorLike -> BufferLike type converter. /// Parameters: Value, memory space, bufferization options - using UnknownTypeConverterFn = std::function; // Produce a MemorySpace attribute from a tensor type using DefaultMemorySpaceFn = - std::function(TensorType t)>; + std::function(TensorLikeType t)>; BufferizationOptions(); @@ -360,7 +361,7 @@ struct BufferizationOptions { // Returning std::nullopt will cause bufferization to fail (useful to indicate // failure to determine memory space for a tensor type). DefaultMemorySpaceFn defaultMemorySpaceFn = - [](TensorType t) -> std::optional { return Attribute(); }; + [](TensorLikeType t) -> std::optional { return Attribute(); }; /// If set to `true`, the analysis is skipped. A buffer is copied before every /// write. This flag cannot be used together with `testAnalysisOnly = true`. @@ -600,7 +601,7 @@ FailureOr getBuffer(RewriterBase &rewriter, Value value, /// IR, this function can be used. /// /// This function is a wrapper around BufferizableOpInterface::getBufferType. -FailureOr getBufferType(Value value, +FailureOr getBufferType(Value value, const BufferizationOptions &options); /// Return the buffer type for a given Value (tensor) after bufferization @@ -613,7 +614,7 @@ FailureOr getBufferType(Value value, /// IR, this function can be used. /// /// This function is a wrapper around `BufferizableOpInterface::getBufferType`. -FailureOr getBufferType(Value value, +FailureOr getBufferType(Value value, const BufferizationOptions &options, SmallVector &invocationStack); @@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value, /// This is the default implementation of /// BufferizableOpInterface::getBufferType. Should not be called from other /// places. -FailureOr +FailureOr defaultGetBufferType(Value value, const BufferizationOptions &options, SmallVector &invocationStack); diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td index 95022d7d665d2..1de1742fab81a 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { Note: This interface method should never be called directly from user code. Always use `bufferization::getBufferType`. }], - /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>", + /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>", /*methodName=*/"getBufferType", /*args=*/(ins "::mlir::Value":$value, "const ::mlir::bufferization::BufferizationOptions &":$options, diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index fad78a63444b9..81ce0f3fb650b 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizationBase.td" +include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor", AliasingValueList getAliasingValues( OpOperand &opOperand, const AnalysisState &state); - FailureOr getBufferType( + FailureOr getBufferType( Value value, const BufferizationOptions &options, SmallVector &invocationStack); @@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ away. However, such IR is no longer bufferizable with One-Shot Bufferize. }]; - let arguments = (ins Arg]>:$memref, UnitAttr:$restrict, UnitAttr:$writable); - let results = (outs AnyTensor:$result); + let results = (outs Bufferization_TensorLikeTypeInterface:$result); let extraClassDeclaration = [{ /// The result of a to_tensor is always a tensor. @@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ bool isWritable(Value value, const AnalysisState &state); - FailureOr getBufferType( + FailureOr getBufferType( Value value, const BufferizationOptions &options, SmallVector &invocationStack) { - return ::llvm::cast(getMemref().getType()); + return ::llvm::cast(getMemref().getType()); } }]; @@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ // ToMemrefOp //===----------------------------------------------------------------------===// +// TODO: rename to "to_buffer" def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [ BufferizableOpInterface, SameOperandsAndResultShape, @@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [ the returned buffer) will not be written to. }]; - let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only); - let results = (outs AnyRankedOrUnrankedMemRef:$memref); + let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, + UnitAttr:$read_only); + let results = (outs Bufferization_BufferLikeTypeInterface:$memref); let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h index 5faa1479ee542..290f1298f2501 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h @@ -13,6 +13,7 @@ // Bufferization Type Interfaces //===----------------------------------------------------------------------===// +#include "mlir/IR/Attributes.h" // mlir::Attribute #include "mlir/IR/Types.h" #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc" diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td index f19224a295648..c053a6bdc1a91 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td @@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface let description = [{ Indicates that this type is a buffer type (similarly to a MLIR builtin memref) for bufferization purposes. - - The interface currently has no methods as it is used by types to opt into - being supported by the bufferization procedures. }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the memory space in which data referred to by this buffer resides. + }], + /*retType=*/"::mlir::Attribute", + /*methodName=*/"getMemorySpace" + >, + ]; } #endif // BUFFERIZATION_TYPE_INTERFACES diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h index 78109770efab7..89eb65c4a0942 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h @@ -32,7 +32,7 @@ template struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel : public BufferizableOpInterface::ExternalModel { - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { // Note: The user may want to override this function for OpResults in @@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel // operand types of all forwarded values. If these are all the same type, // take that type. Otherwise, take only the memory space and fall back to a // buffer type with a fully dynamic layout map. - BaseMemRefType bufferType; + BufferLikeType bufferType; auto tensorType = cast(value.getType()); for (OpOperand *opOperand : detail::getCallerOpOperands(cast(value))) { @@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel continue; // Compute the bufferized type of the forwarded operand. - BaseMemRefType callerType; - if (auto memrefType = - dyn_cast(opOperand->get().getType())) { + BufferLikeType callerType; + if (auto bufferLikeType = + dyn_cast(opOperand->get().getType())) { // The operand was already bufferized. Take its type directly. - callerType = memrefType; + callerType = bufferLikeType; } else { - FailureOr maybeCallerType = + FailureOr maybeCallerType = bufferization::getBufferType(opOperand->get(), options, invocationStack); if (failed(maybeCallerType)) @@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel // of the earlier forwarded operands, fall back to a buffer type with a // fully dynamic layout map. #ifndef NDEBUG + assert(mlir::isa(bufferType) && + mlir::isa(callerType) && "expected memrefs"); + auto memrefType = mlir::cast(bufferType); + auto callerMemrefType = mlir::cast(callerType); + if (auto rankedTensorType = dyn_cast(tensorType)) { - assert(bufferType.hasRank() && callerType.hasRank() && + assert(memrefType.hasRank() && callerMemrefType.hasRank() && "expected ranked memrefs"); - assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(), - rankedTensorType.getShape()}) && - "expected same shape"); + assert( + llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(), + rankedTensorType.getShape()}) && + "expected same shape"); } else { - assert(!bufferType.hasRank() && !callerType.hasRank() && + assert(!memrefType.hasRank() && !callerMemrefType.hasRank() && "expected unranked memrefs"); } #endif // NDEBUG @@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel return op->emitOpError("incoming operands of block argument have " "inconsistent memory spaces"); - bufferType = getMemRefTypeWithFullyDynamicLayout( - tensorType, bufferType.getMemorySpace()); + bufferType = + mlir::cast(getMemRefTypeWithFullyDynamicLayout( + tensorType, bufferType.getMemorySpace())); } if (!bufferType) diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index 5e69a98db8f1e..433757192bfd1 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -26,7 +26,7 @@ struct ConstantOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto constantOp = cast(op); - auto type = dyn_cast(constantOp.getType()); + auto type = dyn_cast(constantOp.getType()); // Only ranked tensors are supported. if (!type) @@ -176,7 +176,7 @@ struct SelectOpInterface return success(); } - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto selectOp = cast(op); @@ -195,10 +195,11 @@ struct SelectOpInterface // If the buffers have different types, they differ only in their layout // map. auto memrefType = llvm::cast(*trueType); - return getMemRefTypeWithFullyDynamicLayout( - RankedTensorType::get(memrefType.getShape(), - memrefType.getElementType()), - memrefType.getMemorySpace()); + return mlir::cast( + getMemRefTypeWithFullyDynamicLayout( + RankedTensorType::get(memrefType.getShape(), + memrefType.getElementType()), + memrefType.getMemorySpace())); } }; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 99ffa62c41a4d..82ff1bdfe5fd7 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -206,12 +206,13 @@ FailureOr bufferization::allocateTensorForShapedValue( // Add 'memory_space' attribute. Not needed if 'copy' operand is specified. if (copy) return allocTensorOp.getResult(); - FailureOr copyBufferType = getBufferType(tensor, options); + FailureOr copyBufferType = getBufferType(tensor, options); if (failed(copyBufferType)) return failure(); std::optional memorySpace = copyBufferType->getMemorySpace(); if (!memorySpace) - memorySpace = options.defaultMemorySpaceFn(tensorType); + memorySpace = + options.defaultMemorySpaceFn(mlir::cast(tensorType)); if (memorySpace.has_value()) allocTensorOp.setMemorySpaceAttr(memorySpace.value()); return allocTensorOp.getResult(); @@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( // Find all out-of-place OpOperands. for (OpOperand &opOperand : op->getOpOperands()) { Type operandType = opOperand.get().getType(); + // Note: can only copy TensorType (any other TensorLikeType is rejected) if (!llvm::isa(operandType)) continue; if (state.isInPlace(opOperand)) @@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const { namespace { /// Default function arg type converter: Use a fully dynamic layout map. -BaseMemRefType -defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace, - func::FuncOp funcOp, +bufferization::BufferLikeType +defaultFunctionArgTypeConverter(bufferization::TensorLikeType type, + Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - return getMemRefTypeWithFullyDynamicLayout(type, memorySpace); + return mlir::cast( + getMemRefTypeWithFullyDynamicLayout(mlir::cast(type), + memorySpace)); } /// Default unknown type converter: Use a fully dynamic layout map. -BaseMemRefType +BufferLikeType defaultUnknownTypeConverter(Value value, Attribute memorySpace, const BufferizationOptions &options) { - return getMemRefTypeWithFullyDynamicLayout( - llvm::cast(value.getType()), memorySpace); + return mlir::cast( + getMemRefTypeWithFullyDynamicLayout( + llvm::cast(value.getType()), memorySpace)); } } // namespace @@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const { void BufferizationOptions::setFunctionBoundaryTypeConversion( LayoutMapOption layoutMapOption) { - functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, - func::FuncOp funcOp, + functionArgTypeConverterFn = [=](TensorLikeType tensorType, + Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) - return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, - memorySpace); - return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, - memorySpace); + return mlir::cast( + bufferization::getMemRefTypeWithStaticIdentityLayout( + mlir::cast(tensorType), memorySpace)); + return mlir::cast( + bufferization::getMemRefTypeWithFullyDynamicLayout( + mlir::cast(tensorType), memorySpace)); }; inferFunctionResultLayout = layoutMapOption == LayoutMapOption::InferLayoutMap; @@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const { /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). bool AnalysisState::isValueRead(Value value) const { - assert(llvm::isa(value.getType()) && "expected TensorType"); + assert(llvm::isa(value.getType()) && + "expected TensorLikeType"); SmallVector workingSet; DenseSet visited; for (OpOperand &use : value.getUses()) @@ -663,7 +671,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options) { #ifndef NDEBUG - auto tensorType = llvm::dyn_cast(value.getType()); + auto tensorType = + llvm::dyn_cast(value.getType()); assert(tensorType && "unexpected non-tensor type"); #endif // NDEBUG @@ -674,7 +683,7 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, // Insert to_memref op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, value); - FailureOr memrefType = getBufferType(value, options); + FailureOr memrefType = getBufferType(value, options); if (failed(memrefType)) return failure(); ensureToMemrefOpIsValid(value, *memrefType); @@ -684,18 +693,18 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, } /// Return the buffer type for a given Value (tensor) after bufferization. -FailureOr +FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options) { SmallVector invocationStack; return getBufferType(value, options, invocationStack); } /// Return the buffer type for a given Value (tensor) after bufferization. -FailureOr +FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options, SmallVector &invocationStack) { - assert(llvm::isa(value.getType()) && - "unexpected non-tensor type"); + assert(llvm::isa(value.getType()) && + "unexpected non-tensor-like type"); invocationStack.push_back(value); auto popFromStack = llvm::make_scope_exit([&]() { invocationStack.pop_back(); }); @@ -708,11 +717,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, // Op is not bufferizable. auto memSpace = - options.defaultMemorySpaceFn(cast(value.getType())); + options.defaultMemorySpaceFn(cast(value.getType())); if (!memSpace.has_value()) return op->emitError("could not infer memory space"); - return getMemRefType(value, options, /*layout=*/{}, *memSpace); + return mlir::cast( + getMemRefType(value, options, /*layout=*/{}, *memSpace)); } bool bufferization::hasTensorSemantics(Operation *op) { @@ -732,12 +742,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, SmallVector replacements; for (OpResult opResult : op->getOpResults()) { Value replacement = values[opResult.getResultNumber()]; - if (llvm::isa(opResult.getType())) { - // The OpResult is a tensor. Such values are replaced with memrefs during + if (llvm::isa(opResult.getType())) { + // The OpResult is a tensor. Such values are replaced with buffers during // bufferization. - assert((llvm::isa(replacement.getType()) || - llvm::isa(replacement.getType())) && - "tensor op result should be replaced with a memref value"); + assert(llvm::isa(replacement.getType()) && + "tensor op result should be replaced with a buffer value"); // The existing uses of the OpResult still expect a tensor. Insert a // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually // loose all of its users and eventually DCE away. @@ -789,6 +798,8 @@ BaseMemRefType bufferization::getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout, Attribute memorySpace) { + assert(mlir::isa(value.getType()) && + "expected tensor type in tensor -> memref conversion"); auto tensorType = llvm::cast(value.getType()); // Case 1: Unranked memref type. @@ -807,7 +818,8 @@ BaseMemRefType bufferization::getMemRefType(Value value, memorySpace); } - return options.unknownTypeConverterFn(value, memorySpace, options); + return mlir::cast( + options.unknownTypeConverterFn(value, memorySpace, options)); } BaseMemRefType @@ -928,7 +940,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands( Operation *op = getOwnerOfValue(value); SmallVector result; for (OpOperand &opOperand : op->getOpOperands()) { - if (!llvm::isa(opOperand.get().getType())) + if (!llvm::isa(opOperand.get().getType())) continue; AliasingValueList aliasingValues = state.getAliasingValues(opOperand); for (const auto &it : aliasingValues) @@ -938,14 +950,15 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands( return AliasingOpOperandList(std::move(result)); } -FailureOr bufferization::detail::defaultGetBufferType( +FailureOr bufferization::detail::defaultGetBufferType( Value value, const BufferizationOptions &options, SmallVector &invocationStack) { - assert(llvm::isa(value.getType()) && "expected tensor type"); + assert(llvm::isa(value.getType()) && "expected tensor type"); // No further analysis is possible for a block argument. if (llvm::isa(value)) - return bufferization::getMemRefType(value, options); + return mlir::cast( + bufferization::getMemRefType(value, options)); // Value is an OpResult. Operation *op = getOwnerOfValue(value); @@ -963,11 +976,12 @@ FailureOr bufferization::detail::defaultGetBufferType( // If we do not know the memory space and there is no default memory space, // report a failure. auto memSpace = - options.defaultMemorySpaceFn(cast(value.getType())); + options.defaultMemorySpaceFn(cast(value.getType())); if (!memSpace.has_value()) return op->emitError("could not infer memory space"); - return getMemRefType(value, options, /*layout=*/{}, *memSpace); + return mlir::cast( + getMemRefType(value, options, /*layout=*/{}, *memSpace)); } bool bufferization::detail::defaultIsRepetitiveRegion( @@ -993,7 +1007,7 @@ bufferization::detail::unknownGetAliasingOpOperands(Value value) { // with every OpOperand. AliasingOpOperandList r; for (OpOperand &operand : value.getDefiningOp()->getOpOperands()) - if (isa(operand.get().getType())) + if (isa(operand.get().getType())) r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } @@ -1006,18 +1020,18 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) { // with every OpOperand. AliasingValueList r; for (OpResult result : opOperand.getOwner()->getOpResults()) - if (llvm::isa(result.getType())) + if (llvm::isa(result.getType())) r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false}); for (Region ®ion : opOperand.getOwner()->getRegions()) if (!region.getBlocks().empty()) for (BlockArgument bbArg : region.getBlocks().front().getArguments()) - if (isa(bbArg.getType())) + if (isa(bbArg.getType())) r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } bool bufferization::detail::defaultHasTensorSemantics(Operation *op) { - auto isaTensor = [](Type t) { return isa(t); }; + auto isaTensor = [](Type t) { return isa(t); }; bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) { return any_of(r.getBlocks(), [&](Block &b) { return any_of(b.getArguments(), [&](BlockArgument bbArg) { diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index 6b9253a5d71da..02f9252dcb088 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -62,7 +62,11 @@ struct BuiltinTensorExternalModel template struct BuiltinMemRefExternalModel : BufferLikeType::ExternalModel, - MemRef> {}; + MemRef> { + mlir::Attribute getMemorySpace(mlir::Type type) const { + return mlir::cast(type).getMemorySpace(); + } +}; } // namespace //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 4fce9be390bd6..2ceb6795899c9 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -220,7 +220,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand, return {}; } -FailureOr +FailureOr AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, SmallVector &invocationStack) { assert(value == getResult() && "invalid value"); @@ -235,13 +235,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, if (failed(copyBufferType)) return failure(); memorySpace = copyBufferType->getMemorySpace(); - } else if (auto ms = options.defaultMemorySpaceFn(getType())) { + } else if (auto ms = options.defaultMemorySpaceFn( + mlir::cast(getType()))) { memorySpace = *ms; } else { return getOperation()->emitError("could not infer memory space"); } - return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace); + return mlir::cast( + getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace)); } LogicalResult AllocTensorOp::verify() { @@ -585,7 +587,7 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter, return failure(); buffer = *maybeBuffer; } else { - assert(isa(getDest().getType()) && "expected memref type"); + assert(isa(getDest().getType()) && "expected buffer type"); buffer = getDest(); } auto srcBuffer = getBuffer(rewriter, getSource(), options); @@ -632,7 +634,7 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder, return {}; // Build a bufferization.to_tensor op. - assert(isa(getDest().getType()) && "expected memref type"); + assert(isa(getDest().getType()) && "expected buffer type"); assert(getRestrict() && "expected that ops with memrefs dest have 'restrict'"); setRestrict(false); @@ -667,22 +669,22 @@ bool MaterializeInDestinationOp::operatesOnDisjointSubset( } LogicalResult MaterializeInDestinationOp::verify() { - if (!isa(getDest().getType())) - return emitOpError("'dest' must be a tensor or a memref"); + if (!isa(getDest().getType())) + return emitOpError("'dest' must be a tensor or a buffer"); if (auto destType = dyn_cast(getDest().getType())) { if (getOperation()->getNumResults() != 1) return emitOpError("tensor 'dest' implies exactly one tensor result"); if (destType != getResult().getType()) return emitOpError("result and 'dest' types must match"); } - if (isa(getDest().getType()) && + if (isa(getDest().getType()) && getOperation()->getNumResults() != 0) - return emitOpError("memref 'dest' implies zero results"); - if (getRestrict() && !isa(getDest().getType())) - return emitOpError("'restrict' is valid only for memref destinations"); - if (getWritable() != isa(getDest().getType())) + return emitOpError("buffer 'dest' implies zero results"); + if (getRestrict() && !isa(getDest().getType())) + return emitOpError("'restrict' is valid only for buffer destinations"); + if (getWritable() != isa(getDest().getType())) return emitOpError("'writable' must be specified if and only if the " - "destination is of memref type"); + "destination is of buffer type"); TensorType srcType = getSource().getType(); ShapedType destType = cast(getDest().getType()); if (srcType.hasRank() != destType.hasRank()) @@ -724,7 +726,7 @@ MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() { void MaterializeInDestinationOp::getEffects( SmallVectorImpl> &effects) { - if (isa(getDest().getType())) + if (isa(getDest().getType())) effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(), SideEffects::DefaultResource::get()); } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp new file mode 100644 index 0000000000000..0e973915c6fc9 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp @@ -0,0 +1,21 @@ +//===- BufferizationTypeInterfaces.cpp - Type Interfaces --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" + +//===----------------------------------------------------------------------===// +// Bufferization Type Interfaces +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace bufferization { + +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp.inc" + +} // namespace bufferization +} // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt index 63dcc1eb233e9..5d8f0060f2c3f 100644 --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect BufferizationDialect.cpp BufferViewFlowOpInterface.cpp UnstructuredControlFlow.cpp + BufferizationTypeInterfaces.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp index 72f47b8b468ea..cb9db1288039a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" #include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -93,11 +94,11 @@ void BufferViewFlowAnalysis::build(Operation *op) { // given op as terminals. auto populateTerminalValues = [&](Operation *op) { for (Value v : op->getResults()) - if (isa(v.getType())) + if (isa(v.getType())) this->terminals.insert(v); for (Region &r : op->getRegions()) for (BlockArgument v : r.getArguments()) - if (isa(v.getType())) + if (isa(v.getType())) this->terminals.insert(v); }; @@ -108,12 +109,12 @@ void BufferViewFlowAnalysis::build(Operation *op) { if (auto bufferViewFlowOp = dyn_cast(op)) { bufferViewFlowOp.populateDependencies(registerDependencies); for (Value v : op->getResults()) - if (isa(v.getType()) && + if (isa(v.getType()) && bufferViewFlowOp.mayBeTerminalBuffer(v)) this->terminals.insert(v); for (Region &r : op->getRegions()) for (BlockArgument v : r.getArguments()) - if (isa(v.getType()) && + if (isa(v.getType()) && bufferViewFlowOp.mayBeTerminalBuffer(v)) this->terminals.insert(v); return WalkResult::advance(); @@ -201,7 +202,7 @@ void BufferViewFlowAnalysis::build(Operation *op) { } bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const { - assert(isa(value.getType()) && "expected memref"); + assert(isa(value.getType()) && "expected memref"); return terminals.contains(value); } @@ -240,8 +241,8 @@ static Value getViewBase(Value value) { BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {} std::optional BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) { - assert(isa(v1.getType()) && "expected buffer"); - assert(isa(v2.getType()) && "expected buffer"); + assert(isa(v1.getType()) && "expected buffer"); + assert(isa(v2.getType()) && "expected buffer"); // Skip over all view-like ops. v1 = getViewBase(v1); @@ -275,7 +276,7 @@ std::optional BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) { bool &allAllocs, bool &allAllocsOrFuncEntryArgs) { for (Value v : origin) { - if (isa(v.getType()) && analysis.mayBeTerminalBuffer(v)) { + if (isa(v.getType()) && analysis.mayBeTerminalBuffer(v)) { terminal.insert(v); allAllocs &= hasAllocateSideEffect(v); allAllocsOrFuncEntryArgs &= diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 0b60c44ece5fd..a296b617024d8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -80,14 +80,14 @@ struct OneShotBufferizePass if (mustInferMemorySpace) { opt.defaultMemorySpaceFn = - [](TensorType t) -> std::optional { + [](TensorLikeType t) -> std::optional { return std::nullopt; }; } if (useEncodingForMemorySpace) { opt.defaultMemorySpaceFn = - [](TensorType t) -> std::optional { + [](TensorLikeType t) -> std::optional { if (auto rtt = dyn_cast(t)) return rtt.getEncoding(); return std::nullopt; @@ -113,13 +113,15 @@ struct OneShotBufferizePass const BufferizationOptions &options) { auto tensorType = cast(value.getType()); if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap) - return bufferization::getMemRefTypeWithStaticIdentityLayout( - tensorType, memorySpace); + return mlir::cast( + bufferization::getMemRefTypeWithStaticIdentityLayout( + tensorType, memorySpace)); assert(unknownTypeConversionOption == LayoutMapOption::FullyDynamicLayoutMap && "invalid layout map option"); - return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, - memorySpace); + return mlir::cast( + bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, + memorySpace)); }; // Configure op filter. @@ -407,7 +409,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, continue; } - FailureOr memrefType = + FailureOr memrefType = bufferization::getBufferType(bbArg, options); if (failed(memrefType)) return failure(); @@ -458,7 +460,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, newOperands.push_back(operand); continue; } - FailureOr operandBufferType = + FailureOr operandBufferType = bufferization::getBufferType(operand, options); if (failed(operandBufferType)) return failure(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index c45678f1e4b4d..4d39d9b795bed 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -53,14 +53,14 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { /// Return the index-th bufferized function argument type. This assumes that the /// specified argument is a tensor. If the tensor is ranked, a layout map may be /// specified by the user (as per `options.functionArgTypeConverterFn`). -static BaseMemRefType +static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { auto tensorType = - dyn_cast(funcOp.getFunctionType().getInput(index)); - assert(tensorType && "expected TensorType"); + dyn_cast(funcOp.getFunctionType().getInput(index)); + assert(tensorType && "expected TensorLikeType"); - BaseMemRefType memrefType = options.functionArgTypeConverterFn( + BufferLikeType memrefType = options.functionArgTypeConverterFn( tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); auto layoutAttr = funcOp.getArgAttrOfType( @@ -70,9 +70,9 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, auto rankedMemrefType = dyn_cast(memrefType); assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); - return MemRefType::get( + return mlir::cast(MemRefType::get( rankedMemrefType.getShape(), rankedMemrefType.getElementType(), - layoutAttr.getValue(), rankedMemrefType.getMemorySpace()); + layoutAttr.getValue(), rankedMemrefType.getMemorySpace())); } /// Return the FuncOp called by `callOp`. @@ -195,7 +195,7 @@ struct CallOpInterface return result; } - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto callOp = cast(op); @@ -207,11 +207,11 @@ struct CallOpInterface FunctionType funcType = funcOp.getFunctionType(); Type resultType = funcType.getResult(cast(value).getResultNumber()); - if (auto bufferizedType = dyn_cast(resultType)) + if (auto bufferizedType = dyn_cast(resultType)) return bufferizedType; // Otherwise, call the type converter to compute the bufferized type. - auto tensorType = cast(resultType); + auto tensorType = cast(resultType); return options.functionArgTypeConverterFn( tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); } @@ -233,7 +233,7 @@ struct CallOpInterface } // Returning a memref. - FailureOr resultType = + FailureOr resultType = bufferization::getBufferType(result, options); if (failed(resultType)) return failure(); @@ -263,11 +263,11 @@ struct CallOpInterface // Caller / callee type mismatch is handled with castOrReallocMemRefValue. auto memRefType = funcType.getInput(opOperand.getOperandNumber()); - if (!isa(memRefType)) { + if (!isa(memRefType)) { // The called function was not bufferized yet. This can happen when // there cycles in the function call graph. Compute the bufferized // result type. - FailureOr maybeMemRefType = + FailureOr maybeMemRefType = bufferization::getBufferType( funcOp.getArgument(opOperand.getOperandNumber()), options); if (failed(maybeMemRefType)) @@ -371,7 +371,7 @@ struct FuncOpInterface return getAliasingBranchOpOperands(op, cast(value), state); } - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto funcOp = cast(op); @@ -413,8 +413,8 @@ struct FuncOpInterface // Compute the result types. SmallVector retTypes; for (Type resultType : funcType.getResults()) { - if (auto tensorType = dyn_cast(resultType)) { - BaseMemRefType resultType = options.functionArgTypeConverterFn( + if (auto tensorType = dyn_cast(resultType)) { + BufferLikeType resultType = options.functionArgTypeConverterFn( tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); retTypes.push_back(resultType); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index cf62ee8bc45b5..523ee48be2003 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -102,11 +102,11 @@ struct ConditionOpInterface SmallVector newArgs; for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { Value value = it.value(); - if (isa(value.getType())) { + if (isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options); if (failed(maybeBuffer)) return failure(); - FailureOr resultType = bufferization::getBufferType( + auto resultType = bufferization::getBufferType( whileOp.getAfterArguments()[it.index()], options); if (failed(resultType)) return failure(); @@ -201,7 +201,7 @@ struct ExecuteRegionOpInterface rewriter.setInsertionPointAfter(newOp); SmallVector newResults; for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { - if (isa(it.value())) { + if (isa(it.value())) { newResults.push_back(rewriter.create( executeRegionOp.getLoc(), it.value(), newOp->getResult(it.index()))); @@ -244,7 +244,7 @@ struct IfOpInterface // Compute bufferized result types. SmallVector newTypes; for (Value result : ifOp.getResults()) { - if (!isa(result.getType())) { + if (!isa(result.getType())) { newTypes.push_back(result.getType()); continue; } @@ -270,7 +270,7 @@ struct IfOpInterface return success(); } - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto ifOp = cast(op); @@ -282,10 +282,10 @@ struct IfOpInterface auto opResult = cast(value); auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber()); auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber()); - BaseMemRefType thenBufferType, elseBufferType; - if (isa(thenValue.getType())) { + bufferization::BufferLikeType thenBufferType, elseBufferType; + if (isa(thenValue.getType())) { // True branch was already bufferized. - thenBufferType = cast(thenValue.getType()); + thenBufferType = cast(thenValue.getType()); } else { auto maybeBufferType = bufferization::getBufferType(thenValue, options, invocationStack); @@ -293,9 +293,9 @@ struct IfOpInterface return failure(); thenBufferType = *maybeBufferType; } - if (isa(elseValue.getType())) { + if (isa(elseValue.getType())) { // False branch was already bufferized. - elseBufferType = cast(elseValue.getType()); + elseBufferType = cast(elseValue.getType()); } else { auto maybeBufferType = bufferization::getBufferType(elseValue, options, invocationStack); @@ -313,8 +313,10 @@ struct IfOpInterface return op->emitError("inconsistent memory space on then/else branches"); // Layout maps are different: Promote to fully dynamic layout map. - return getMemRefTypeWithFullyDynamicLayout( - cast(opResult.getType()), thenBufferType.getMemorySpace()); + return mlir::cast( + getMemRefTypeWithFullyDynamicLayout( + cast(opResult.getType()), + thenBufferType.getMemorySpace())); } }; @@ -354,7 +356,7 @@ struct IndexSwitchOpInterface // Compute bufferized result types. SmallVector newTypes; for (Value result : switchOp.getResults()) { - if (!isa(result.getType())) { + if (!isa(result.getType())) { newTypes.push_back(result.getType()); continue; } @@ -384,7 +386,7 @@ struct IndexSwitchOpInterface return success(); } - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto switchOp = cast(op); @@ -392,11 +394,13 @@ struct IndexSwitchOpInterface int64_t resultNum = cast(value).getResultNumber(); // Helper function to get buffer type of a case. - SmallVector yieldedTypes; - auto getYieldedBufferType = [&](Block &b) -> FailureOr { + SmallVector yieldedTypes; + auto getYieldedBufferType = + [&](Block &b) -> FailureOr { auto yieldOp = cast(b.getTerminator()); Value yieldedValue = yieldOp->getOperand(resultNum); - if (auto bufferType = dyn_cast(yieldedValue.getType())) + if (auto bufferType = + dyn_cast(yieldedValue.getType())) return bufferType; auto maybeBufferType = bufferization::getBufferType(yieldedValue, options, invocationStack); @@ -409,7 +413,7 @@ struct IndexSwitchOpInterface auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock()); if (failed(maybeBufferType)) return failure(); - BaseMemRefType bufferType = *maybeBufferType; + auto bufferType = *maybeBufferType; // Compute buffer types of all other cases. for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) { @@ -426,8 +430,9 @@ struct IndexSwitchOpInterface return op->emitError("inconsistent memory space on switch cases"); // Layout maps are different: Promote to fully dynamic layout map. - bufferType = getMemRefTypeWithFullyDynamicLayout( - cast(value.getType()), bufferType.getMemorySpace()); + bufferType = mlir::cast( + getMemRefTypeWithFullyDynamicLayout(cast(value.getType()), + bufferType.getMemorySpace())); } return bufferType; @@ -439,7 +444,7 @@ struct IndexSwitchOpInterface static DenseSet getTensorIndices(ValueRange values) { DenseSet result; for (const auto &it : llvm::enumerate(values)) - if (isa(it.value().getType())) + if (isa(it.value().getType())) result.insert(it.index()); return result; } @@ -452,8 +457,8 @@ DenseSet getEquivalentBuffers(Block::BlockArgListType bbArgs, unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size()); DenseSet result; for (unsigned int i = 0; i < minSize; ++i) { - if (!isa(bbArgs[i].getType()) || - !isa(yieldedValues[i].getType())) + if (!isa(bbArgs[i].getType()) || + !isa(yieldedValues[i].getType())) continue; if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i])) result.insert(i); @@ -468,7 +473,7 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, const BufferizationOptions &options) { SmallVector result; for (OpOperand &opOperand : operands) { - if (isa(opOperand.get().getType())) { + if (isa(opOperand.get().getType())) { FailureOr resultBuffer = getBuffer(rewriter, opOperand.get(), options); if (failed(resultBuffer)) @@ -516,9 +521,11 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, /// If both buffer types are equal, no casts are needed the computed buffer type /// can be used directly. Otherwise, the buffer types can only differ in their /// layout map and a cast must be inserted. -static FailureOr computeLoopRegionIterArgBufferType( - Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue, - const BufferizationOptions &options, SmallVector &invocationStack) { +static FailureOr +computeLoopRegionIterArgBufferType(Operation *loopOp, BlockArgument iterArg, + Value initArg, Value yieldedValue, + const BufferizationOptions &options, + SmallVector &invocationStack) { // Determine the buffer type of the init_arg. auto initArgBufferType = bufferization::getBufferType(initArg, options, invocationStack); @@ -540,10 +547,11 @@ static FailureOr computeLoopRegionIterArgBufferType( } // Compute the buffer type of the yielded value. - BaseMemRefType yieldedValueBufferType; - if (isa(yieldedValue.getType())) { + bufferization::BufferLikeType yieldedValueBufferType; + if (isa(yieldedValue.getType())) { // scf.yield was already bufferized. - yieldedValueBufferType = cast(yieldedValue.getType()); + yieldedValueBufferType = + cast(yieldedValue.getType()); } else { // Note: This typically triggers a recursive call for the buffer type of // the iter_arg. @@ -576,8 +584,9 @@ static FailureOr computeLoopRegionIterArgBufferType( "expected same shape"); } #endif // NDEBUG - return getMemRefTypeWithFullyDynamicLayout( - iterTensorType, yieldedBufferType.getMemorySpace()); + return mlir::cast( + getMemRefTypeWithFullyDynamicLayout(iterTensorType, + yieldedBufferType.getMemorySpace())); } /// Return `true` if the given loop may have 0 iterations. @@ -696,12 +705,13 @@ struct ForOpInterface return success(); } - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto forOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); - assert(isa(value.getType()) && "expected tensor type"); + assert(isa(value.getType()) && + "expected tensor type"); if (auto opResult = dyn_cast(value)) { // The type of an OpResult must match the corresponding iter_arg type. @@ -744,7 +754,7 @@ struct ForOpInterface Value initArg = it.value(); Value result = forOp->getResult(it.index()); // If the type is not a tensor, bufferization doesn't need to touch it. - if (!isa(result.getType())) { + if (!isa(result.getType())) { castedInitArgs.push_back(initArg); continue; } @@ -795,7 +805,7 @@ struct ForOpInterface auto forOp = cast(op); auto yieldOp = cast(forOp.getBody()->getTerminator()); for (OpResult opResult : op->getOpResults()) { - if (!isa(opResult.getType())) + if (!isa(opResult.getType())) continue; // Note: This is overly strict. We should check for aliasing bufferized @@ -920,7 +930,7 @@ struct WhileOpInterface for (int64_t idx = 0; idx < static_cast(conditionOp.getArgs().size()); ++idx) { Value value = conditionOp.getArgs()[idx]; - if (!isa(value.getType()) || + if (!isa(value.getType()) || (equivalentYieldsAfter.contains(idx) && equivalentYieldsBefore.contains(idx))) { beforeYieldValues.push_back(value); @@ -962,7 +972,7 @@ struct WhileOpInterface Value initArg = it.value(); Value beforeArg = whileOp.getBeforeArguments()[it.index()]; // If the type is not a tensor, bufferization doesn't need to touch it. - if (!isa(beforeArg.getType())) { + if (!isa(beforeArg.getType())) { castedInitArgs.push_back(initArg); continue; } @@ -975,7 +985,7 @@ struct WhileOpInterface // The result types of a WhileOp are the same as the "after" bbArg types. SmallVector argsTypesAfter = llvm::to_vector( llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { - if (!isa(bbArg.getType())) + if (!isa(bbArg.getType())) return bbArg.getType(); // TODO: error handling return llvm::cast( @@ -1022,12 +1032,13 @@ struct WhileOpInterface return success(); } - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto whileOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); - assert(isa(value.getType()) && "expected tensor type"); + assert(isa(value.getType()) && + "expected tensor type"); // Case 1: Block argument of the "before" region. if (auto bbArg = dyn_cast(value)) { @@ -1053,9 +1064,9 @@ struct WhileOpInterface llvm_unreachable("invalid value"); } Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum]; - if (!isa(conditionYieldedVal.getType())) { + if (!isa(conditionYieldedVal.getType())) { // scf.condition was already bufferized. - return cast(conditionYieldedVal.getType()); + return cast(conditionYieldedVal.getType()); } return bufferization::getBufferType(conditionYieldedVal, options, invocationStack); @@ -1082,7 +1093,7 @@ struct WhileOpInterface auto conditionOp = whileOp.getConditionOp(); for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { Block *block = conditionOp->getBlock(); - if (!isa(it.value().getType())) + if (!isa(it.value().getType())) continue; if (it.index() >= block->getNumArguments() || !state.areEquivalentBufferizedValues(it.value(), @@ -1095,7 +1106,7 @@ struct WhileOpInterface auto yieldOp = whileOp.getYieldOp(); for (const auto &it : llvm::enumerate(yieldOp.getResults())) { Block *block = yieldOp->getBlock(); - if (!isa(it.value().getType())) + if (!isa(it.value().getType())) continue; if (it.index() >= block->getNumArguments() || !state.areEquivalentBufferizedValues(it.value(), @@ -1154,7 +1165,7 @@ struct YieldOpInterface SmallVector newResults; for (const auto &it : llvm::enumerate(yieldOp.getResults())) { Value value = it.value(); - if (isa(value.getType())) { + if (isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options); if (failed(maybeBuffer)) return failure(); @@ -1162,14 +1173,14 @@ struct YieldOpInterface // We may have to cast the value before yielding it. if (isa( yieldOp->getParentOp())) { - FailureOr resultType = bufferization::getBufferType( + auto resultType = bufferization::getBufferType( yieldOp->getParentOp()->getResult(it.index()), options); if (failed(resultType)) return failure(); buffer = castBuffer(rewriter, buffer, *resultType); } else if (auto whileOp = dyn_cast(yieldOp->getParentOp())) { - FailureOr resultType = bufferization::getBufferType( + auto resultType = bufferization::getBufferType( whileOp.getBeforeArguments()[it.index()], options); if (failed(resultType)) return failure(); @@ -1274,7 +1285,7 @@ struct ForallOpInterface return success(); } - FailureOr + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto forallOp = cast(op); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 6e882a8d0ff30..068c248c1bcd7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -220,8 +220,8 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) { options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap); options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, const BufferizationOptions &options) { - return getMemRefTypeWithStaticIdentityLayout( - cast(value.getType()), memorySpace); + return llvm::cast(getMemRefTypeWithStaticIdentityLayout( + cast(value.getType()), memorySpace)); }; if (analysisOnly) { options.testAnalysisOnly = true; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp index f92382472b478..742a92566a31e 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp @@ -550,8 +550,8 @@ TypedValue sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) { auto tTp = llvm::cast(tensor.getType()); auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType()); - return builder.create(loc, mTp, tensor) - .getResult(); + return llvm::cast>( + builder.create(loc, mTp, tensor).getResult()); } Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 31014172a9555..fb0dd151a4448 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -487,8 +487,7 @@ struct FromElementsOpInterface /*copy=*/false); if (failed(tensorAlloc)) return failure(); - FailureOr memrefType = - bufferization::getBufferType(*tensorAlloc, options); + auto memrefType = bufferization::getBufferType(*tensorAlloc, options); if (failed(memrefType)) return failure(); Value buffer = rewriter.create( @@ -592,7 +591,8 @@ struct GenerateOpInterface auto type = generateOp.getResult().getType(); // TODO: Implement memory space for this op. - if (options.defaultMemorySpaceFn(type) != Attribute()) + if (options.defaultMemorySpaceFn(llvm::cast(type)) != + Attribute()) return op->emitError("memory space not implemented yet"); // Allocate memory. @@ -1031,7 +1031,8 @@ struct SplatOpInterface auto tensorType = cast(tensorAlloc->getType()); // TODO: Implement memory space for this op. - if (options.defaultMemorySpaceFn(tensorType) != Attribute()) + if (options.defaultMemorySpaceFn(llvm::cast(tensorType)) != + Attribute()) return op->emitError("memory space not implemented yet"); auto linalgOp = diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir index e65c5b92949f6..6fb421675fab6 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir @@ -268,4 +268,23 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x %r = tensor.extract %dest_filled[%idx] : tensor<5xf32> return %0, %r : tensor<5xf32>, f32 -} \ No newline at end of file +} + +// ----- + +// CHECK-LABEL: func.func @test_dialect_op( +// CHECK-SAME: %[[ARG:.*]]: !test.test_tensor<[32, 64], f64> +// CHECK-SAME: ) -> !test.test_tensor<[32, 128], f64> { +func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>) + -> !test.test_tensor<[32, 128], f64> { + // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[ARG]] + // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]]) + // CHECK-SAME: : (!test.test_memref<[32, 64], f64>) + // CHECK-SAME: -> !test.test_memref<[32, 128], f64> + // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]] + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[32, 64], f64>) + -> !test.test_tensor<[32, 128], f64> + + // CHECK: return %[[OUT]] + return %out : !test.test_tensor<[32, 128], f64> +} diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir index 2c8807b66de74..86b541d95924b 100644 --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -58,14 +58,14 @@ func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: ten // ----- func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %arg1: vector<5xf32>) { - // expected-error @below{{'dest' must be a tensor or a memref}} + // expected-error @below{{'dest' must be a tensor or a buffer}} bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5xf32>, vector<5xf32>) -> () } // ----- func.func @invalid_materialize_in_destination_result(%arg0: tensor, %arg1: memref) { - // expected-error @below{{memref 'dest' implies zero results}} + // expected-error @below{{buffer 'dest' implies zero results}} bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor, memref) -> (tensor) } @@ -79,14 +79,14 @@ func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor, %arg1: tensor) { - // expected-error @below{{'restrict' is valid only for memref destinations}} + // expected-error @below{{'restrict' is valid only for buffer destinations}} bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor, tensor) -> (tensor) } // ----- func.func @invalid_materialize_in_destination_restrict(%arg0: tensor, %arg1: tensor) { - // expected-error @below{{'writable' must be specified if and only if the destination is of memref type}} + // expected-error @below{{'writable' must be specified if and only if the destination is of buffer type}} bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor, tensor) -> (tensor) } diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp index 2991a3c165ee2..95d6158d7c67f 100644 --- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp +++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp @@ -46,7 +46,9 @@ struct TestTensorCopyInsertionPass options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; if (mustInferMemorySpace) { options.defaultMemorySpaceFn = - [](TensorType t) -> std::optional { return std::nullopt; }; + [](bufferization::TensorLikeType t) -> std::optional { + return std::nullopt; + }; } if (failed(bufferization::insertTensorCopies(getOperation(), options))) signalPassFailure(); diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 454a12bac9ab3..df7586976280c 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -8,6 +8,7 @@ #include "TestDialect.h" #include "TestOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/FunctionImplementation.h" @@ -1386,3 +1387,26 @@ TestMultiSlotAlloca::handleDestructuringComplete( const DestructurableMemorySlot &slot, OpBuilder &builder) { return createNewMultiAllocaWithoutSlot(slot, builder, *this); } + +::mlir::LogicalResult test::TestDummyTensorOp::bufferize( + ::mlir::RewriterBase &rewriter, + const ::mlir::bufferization::BufferizationOptions &options) { + const auto inType = getInput().getType(); + const auto bufferizedInType = test::TestMemrefType::get( + getContext(), inType.getShape(), inType.getElementType(), nullptr); + const auto outType = getOutput().getType(); + const auto bufferizedOutType = test::TestMemrefType::get( + getContext(), outType.getShape(), outType.getElementType(), nullptr); + + // replace op with memref analogy, preserve correct types at the boundaries + auto toMemref = rewriter.create<::mlir::bufferization::ToMemrefOp>( + getLoc(), bufferizedInType, getInput()); + auto dummyMemrefOp = rewriter.create( + getLoc(), bufferizedOutType, toMemref.getResult()); + auto toTensor = rewriter.create<::mlir::bufferization::ToTensorOp>( + getLoc(), outType, dummyMemrefOp.getOutput()); + + rewriter.replaceOp(*this, toTensor); + + return mlir::success(); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h index f070c3bedd92c..ea8867e3fc41d 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -13,6 +13,7 @@ #include "TestInterfaces.h" #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 85a49e05d4c73..976b4963a29f7 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -30,7 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" - +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" // Include the attribute definitions. include "TestAttrDefs.td" @@ -3499,4 +3499,57 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> { }]; } +//===----------------------------------------------------------------------===// +// Test Ops bufferization +//===----------------------------------------------------------------------===// + +def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> { + let arguments = (ins + Arg:$input + ); + let results = (outs + Arg:$output + ); + let extraClassDeclaration = [{ + // BufferizableOpInterface + bool bufferizesToMemoryRead(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + bool bufferizesToMemoryWrite(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&, + const mlir::bufferization::AnalysisState&); + + mlir::LogicalResult bufferize( + mlir::RewriterBase& rewriter, + const mlir::bufferization::BufferizationOptions& options); + }]; + + let extraClassDefinition = [{ + bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return true; + } + bool test::TestDummyTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return true; + } + ::mlir::bufferization::AliasingValueList + test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&, + const ::mlir::bufferization::AnalysisState&) { + return {}; + } + }]; +} + +def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> { + let arguments = (ins + Arg:$input + ); + let results = (outs + Arg:$output + ); +} + #endif // TEST_OPS diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index e9785594d3332..cee6888a7196c 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -446,6 +446,9 @@ def TestMemrefType : Test_Type<"TestMemref", return test::TestMemrefType::get( getContext(), shape.value_or(getShape()), elementType, getMemSpace()); } + + // BufferLikeTypeInterface: + ::mlir::Attribute getMemorySpace() const { return getMemSpace(); } }]; }