diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 1d0edf9ea809d..5c62495110df8 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -697,6 +697,13 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides, const int64_t *strides); +// Creates a strided layout attribute from given strides and offset, +// canonicalizing the 0D and 1D unit stride to contiguous layout attributes. The +// returned value may not be a StridedLayoutAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirStridedLayoutAttrGetCanonical(MlirContext ctx, int64_t offset, + intptr_t numStrides, const int64_t *strides); + // Returns the offset in the given strided layout layout attribute. MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr); @@ -711,6 +718,38 @@ MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, /// Returns the typeID of a StridedLayout attribute. MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void); +//===----------------------------------------------------------------------===// +// Contiguous layout attribute. +//===----------------------------------------------------------------------===// + +// Checks wheather the given attribute is a contiguous layout attribute. +MLIR_CAPI_EXPORTED bool mlirAttributeIsAContiguousLayout(MlirAttribute attr); + +// Creates a contiguous layout attribute from given permutation and offset. +// There must be `rank` values in `permutation`. +MLIR_CAPI_EXPORTED MlirAttribute mlirContiguousLayoutAttrGet( + MlirContext ctx, int64_t offset, intptr_t rank, const int64_t *permutation); + +// Creates a row-major contiguous layout attribute from given offset and rank. +MLIR_CAPI_EXPORTED MlirAttribute mlirContiguousLayoutAttrGetRowMajor( + MlirContext ctx, int64_t offset, int64_t rank); + +// Returns the offset in the given contiguous layout attribute. +MLIR_CAPI_EXPORTED int64_t +mlirContiguousLayoutAttrGetOffset(MlirAttribute attr); + +// Returns the number of permutation entries in the given contiguous layout +// attribute. +MLIR_CAPI_EXPORTED intptr_t mlirContiguousLayoutAttrGetRank(MlirAttribute attr); + +// Returns the pos-th permutation entry stored in the given contiguous layout +// attribute. +MLIR_CAPI_EXPORTED int64_t +mlirContiguousLayoutAttrGetPermutationEntry(MlirAttribute attr, intptr_t pos); + +/// Returns the typeID of a ContiguousLayout attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirContiguousLayoutAttrGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 134cca5800918..121099f3c2590 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -32,8 +32,9 @@ def MemRefTypeAttr class MemRef_Op traits = []> : Op; -// Base class for ops with static/dynamic offset, sizes and strides -// attributes/arguments. +// Base class for ops with static/dynamic offset, sizes and optional strides +// attributes/arguments. When the strides are not specified, this implies a +// contiguous layout. 
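To make the intended use of the new C-API entry points above concrete, here is a minimal sketch of a client exercising them. It assumes only the declarations added in this header plus the standard `mlir-c/IR.h` context helpers; the rank-2 column-major permutation and the offset of 4 are illustrative values, not anything mandated by the patch.

```c++
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"

#include <assert.h>

int main(void) {
  MlirContext ctx = mlirContextCreate();

  // Build contiguous<[1, 0], offset: 4> and read its fields back.
  const int64_t perm[] = {1, 0};
  MlirAttribute layout =
      mlirContiguousLayoutAttrGet(ctx, /*offset=*/4, /*rank=*/2, perm);
  assert(mlirAttributeIsAContiguousLayout(layout));
  assert(mlirContiguousLayoutAttrGetOffset(layout) == 4);
  assert(mlirContiguousLayoutAttrGetRank(layout) == 2);
  assert(mlirContiguousLayoutAttrGetPermutationEntry(layout, 0) == 1);

  // The canonicalizing strided builder folds a single unit stride into a
  // contiguous layout, so the result is not a StridedLayoutAttr.
  const int64_t unitStride[] = {1};
  MlirAttribute canonical =
      mlirStridedLayoutAttrGetCanonical(ctx, /*offset=*/0, 1, unitStride);
  assert(mlirAttributeIsAContiguousLayout(canonical));
  assert(!mlirAttributeIsAStridedLayout(canonical));

  mlirContextDestroy(ctx);
  return 0;
}
```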
class MemRef_OpWithOffsetSizesAndStrides traits = []> : MemRef_Op { diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index 3af89a6ab3799..183bdb005c186 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -178,8 +178,10 @@ LogicalResult reshapeLikeShapesAreCompatible( ArrayRef collapsedShape, ArrayRef expandedShape, ArrayRef reassociationMaps, bool isExpandingReshape); -/// Returns true iff the type is a MemRefType and has a non-identity layout. -bool hasNonIdentityLayout(Type type); +/// Returns true iff the type is a MemRefType and has a layout that is not +/// row-major contiguous - that is, the identity layout with an optional +/// offset. +bool hasNonRowMajorContiguousLayout(Type type); enum class ReshapeOpKind { kExpand, kCollapse }; @@ -197,9 +199,9 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern { ShapedType resultType = reshapeOp.getResultType(); - if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) || - hasNonIdentityLayout(reshapeOp.getSrc().getType()) || - hasNonIdentityLayout(reshapeOp.getResult().getType())) + if (hasNonRowMajorContiguousLayout(srcReshapeOp.getSrc().getType()) || + hasNonRowMajorContiguousLayout(reshapeOp.getSrc().getType()) || + hasNonRowMajorContiguousLayout(reshapeOp.getResult().getType())) return failure(); std::optional> reassociationIndices = @@ -265,9 +267,9 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern { ShapedType srcType = expandOp.getSrcType(); ShapedType resultType = collapseOp.getResultType(); - if (hasNonIdentityLayout(collapseOp.getSrc().getType()) || - hasNonIdentityLayout(expandOp.getSrc().getType()) || - hasNonIdentityLayout(expandOp.getResult().getType())) + if (hasNonRowMajorContiguousLayout(collapseOp.getSrc().getType()) || + hasNonRowMajorContiguousLayout(expandOp.getSrc().getType()) || + hasNonRowMajorContiguousLayout(expandOp.getResult().getType())) return failure(); int64_t srcRank = srcType.getRank(); @@ -331,9 +333,9 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern { ShapedType srcType = collapseOp.getSrcType(); ShapedType resultType = expandOp.getResultType(); - if (hasNonIdentityLayout(expandOp.getSrc().getType()) || - hasNonIdentityLayout(collapseOp.getSrc().getType()) || - hasNonIdentityLayout(collapseOp.getResult().getType())) + if (hasNonRowMajorContiguousLayout(expandOp.getSrc().getType()) || + hasNonRowMajorContiguousLayout(collapseOp.getSrc().getType()) || + hasNonRowMajorContiguousLayout(collapseOp.getResult().getType())) return failure(); int64_t srcRank = srcType.getRank(); @@ -451,7 +453,7 @@ getLinearizedDimensions(ArrayRef reassociationIndices); /// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] : /// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32> /// -/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : +/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : /// tensor<1x1x1x10xf32> into tensor<1x10xf32> /// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] : /// tensor<1x10xf32> into tensor<10x10xf32> diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index 901df3a25a46f..f7b9a78ef4cc9 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -1081,6 +1081,28 @@ inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); } namespace mlir { +/// Given an N-dimensional permutation and 
an offset (which can use +/// ShapedType::kDynamic) to represent a dynamic value), return the +/// N-dimensional map that is permuted according to said permutation and adds +/// the offset to the final output. If the permutation has no outputs (it's a +/// 0-D map), add one result to hold the offset. +/// +/// Examples: +/// ========= +/// +/// offset = 0, permutation = [0, 1, 2] gives +/// [](d0, d1, d2) -> (d0, d1, d2) +/// while offset = 5 gives [](d0, d1, d2) -> (d0, d1, d2 + 5) +/// and offset = ? gives [s0](d0, d1, d2) -> (d0, d1, d2 + s0). +/// +/// offset = ?, permutation = [2, 1, 0] gives +/// [s0](d0, d1, d2) -> (d2, d1, d0 + s0) +/// +/// Finally, offset = 0, permutation = [], gives []() -> (0), while +/// offset = ?, permutation = [] gives [s0]() -> (s0). +AffineMap makePermutedMapWithOffset(ArrayRef permutation, + int64_t offset, MLIRContext *context); + /// Given a list of strides (in which ShapedType::kDynamic /// represents a dynamic value), return the single result AffineMap which /// represents the linearized strided layout map. Dimensions correspond to the diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 6826d1a437775..455864d50cd8e 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -164,7 +164,7 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter< }]; } -def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array", +def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array", [BlobAttrInterface]> { let summary = "A dense array of integer or floating point elements."; let description = [{ @@ -494,7 +494,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", /// when building the attribute. The provided `blobName` is used as a hint /// for the key of the new handle for the `blob` resource, but may be /// changed if necessary to ensure uniqueness during insertion. - /// This base class builder does no element type specific size or alignment + /// This base class builder does no element type specific size or alignment /// checking. Use the typed subclasses for more safety unless if performing /// generic operations. AttrBuilderWithInferredContext<(ins @@ -1051,9 +1051,96 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout", /// Returns true if this layout is static, i.e. the strides and offset all /// have a known value > 0. bool hasStaticLayout() const; + + /// Get a "canonical" strided layout for the given strides. + /// This constructs a strided layout with the given `offset` and `strides`, + /// except that if either the strides are empty or equal to [1], it returns + /// the corresponding ContiguousLayoutAttr in order to guard against multiple + /// representations of the identity layout. + static ::mlir::MemRefLayoutAttrInterface getCanonical(MLIRContext *context, + int64_t offset, ::llvm::ArrayRef strides); }]; } +//===----------------------------------------------------------------------===// +// ContiguousLayoutAttr +//===----------------------------------------------------------------------===// + +def ContiguousLayoutAttr : Builtin_Attr<"ContiguousLayout", "contiguous_layout", + [DeclareAttrInterfaceMethods]> { + let summary = "An Attribute representing a contiguous layout of a shaped type"; + let description = [{ + Syntax: + + ``` + contiguous-layout-attribute ::= `contiguous` `<` maybe-permutation + (`,` `offset` `:` dimension)? 
`>` + maybe-permutation ::= decimal-literal | `[` permutation `]` + permutation ::= decimal-literal (`,` decimal-literal)* + dimension ::= decimal-literal | `?` + ``` + + A contiguous layout is a layout that represents a sequence of dimensions + laid out in linear memory in its canonical form. Specifically, it indicates + that if one permutes the dimensions of a memref according to `permutaton`, + they will be in a row-major contiguos form: that is, the stride (in the + sense of the strided layout) of dimension `permutation[i]` is equal + to the products of the sizes of all dimensions appearing later in the permutation. + + For example, a MxN memref with a `contiguous<[1, 0]>` layout is colmn-major: + advancing in the M dimension requires moving by 1 element in linear memory, + while the N dimension requires moving by M elements. Conversely, + if the layout is `contiguous<[0, 1]>` (which can be written `contiguous<2>` + for brevity and will be omitted from printing without an offset), the stride + of the N dimension will be 1 element while the stride of the M dimension will be + N elements. + + As a more complex example, `memref>` + , where A, B, C, and D are potentially dynamic values, means that + the value at index `[%i, %j, %k]` is located `%k * A * B + %i * B + %j + D` + elements from the beginning of the memory underlying that memref. + + The permutation must contain the integers between 0 and the rank of the memref - 1, + and must have one distinct entry for each memref dimension. The value + `[0, 1, ..., N-1]`, specifying a row-major format, may be printed as `N` + for clarity. + + If an offset is specified, it is a number of elements to move within + the underlying linear memory after the permutation is applied. This offset + may be _dynamic_, meaning that it may not be known at compile time. + A dynamic offset is represented as a `?` in the assembly syntax and as + `ShapedType::kDynamic` in the code. The offset must be non-negative. + + See [Dialects/Builtin.md#memreftype](MemRef type) for more information. + }]; + + let parameters = (ins + "int64_t":$offset, + ArrayRefParameter< + "int64_t", + "permutation (64-bit integer)" + >:$permutation + ); + + let builders = [ + // Builder for row-major contiguous attribute. + AttrBuilder<(ins "int64_t":$offset, "int64_t":$rank)> + ]; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + /// Print the attribute to the given output stream. + void print(raw_ostream &os) const; + + /// Returns true if this layout is static, i.e. the offset has a static value. + bool hasStaticLayout() const; + + /// Return true if this layout has a row-major permutation - that is, the + /// dimensions of the shape are not permuted. + bool isRowMajor() const; + }]; +} //===----------------------------------------------------------------------===// // StringAttr diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index af474b3e3ec47..db21dfd656161 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -585,20 +585,27 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ layout must avoid internal aliasing, i.e., two distinct tuples of _in-bounds_ indices must be pointing to different elements in memory. The layout is an attribute that implements `MemRefLayoutAttrInterface`. The - bulitin dialect offers two kinds of layouts: strided and affine map, each - of which is available as an attribute. 
Other attributes may be used to - represent the layout as long as they can be converted to a + bulitin dialect offers three kinds of layouts: contiguous, strided and + affine map, each of which is available as an attribute. Other attributes may be + used to represent the layout as long as they can be converted to a [semi-affine map](Affine.md/#semi-affine-maps) and implement the required interface. Users of memref are expected to fallback to the affine representation when handling unknown memref layouts. Multi-dimensional affine forms are interpreted in _row-major_ fashion. In absence of an explicit layout, a memref is considered to have a - multi-dimensional identity affine map layout. Identity layout maps do not - contribute to the MemRef type identification and are discarded on - construction. That is, a type with an explicit identity map is + row-major contiguous layout with an offset of 0, which is equivalent + to a multi-dimensional identity map. For backwards compatibility, + identity layout maps do not contribute to the MemRef type identification and + are discarded on construction. That is, a type with an explicit identity map is `memref(i,j)>` is strictly the same as the one without a - layout, `memref`. + layout, `memref`, which, written explicitly, has the layout + `memref>`. + + The built-in layouts form a hierarchy: all contiguous layuts are strided layouts, + and all strided layouts are affine map layouts, but the reverse is not true. + Using a more specific layout may permit a greater degree of optimization in + the generated code. ##### Affine Map Layout @@ -656,6 +663,37 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ Therefore, it is never subject to the implicit row-major layout interpretation. + ### Contiguous layout + + The most restricted of the built-in layouts is the _contiguous_ layout, which + expresses the fact that the in-memory layout of the memref would be row-major + without padding after the associated permutation is applied. Equivalently, + a contigous layout is a strided layout where the strides are implicitly computed + from the (permuted) sizes of the memref. + + This layout is necessary to allow optimizations during lowering passes in the + presence of dynamic sizes, since + `memref>` doesn't specify if it's + dimensions have padding in between tem or not - the two non-1 strides are + dynamic. By contrast, `contiguous<3, offset: ?>` indiates a row-major layout + with an offset, while `contiguous<[2, 1, 0], offset: ?>` indicates a + column-major layout. While this scheme could be expressed with an affine map, + some operations expect memrefs to be in a form compatible with the `strided` + layout, which can be difficult to detect from analyzing an affine expression. + + In general, the layout `contiguous<[p0, p1, ..., pN], offset: V>` + corresponds to the affine map + + ```mlir + affine_map<(d0, ..., dN) -> (d[p0], d[p1], ... + d[pN] + V)> + ``` + + where `V` is either `s0` if it is dynamic or some constant value. + + For convenience, the layout `contigous<[0, 1, ..., N], offset: V>` is printed + as `contigous`, and the `, offset: V` segment is omitted if `V` + is `0`. + ##### Codegen of Unranked Memref Using unranked memref in codegen besides the case mentioned above is highly @@ -815,6 +853,10 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ /// considering both _all_ and _only_ the trailing 3 dims, /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when /// considering the trailing 3 dims. 
+ /// - memref> is contiguous when + /// considering all dimensions. + /// - memref> is + /// _only_ contiguous when considering the trailing 2 dimensions. /// bool areTrailingDimsContiguous(int64_t n); @@ -830,8 +872,8 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ /// Returns the strides of the MemRef if the layout map is in strided form. /// MemRefs with a layout map in strided form include: - /// 1. empty or identity layout map, in which case the stride information - /// is the canonical form computed from sizes; + /// 1. the empty layout, the identity layout affine map, and any ContigousLayoutAttr, + /// in which case the stride information is the canonical form computed from sizes; /// 2. a StridedLayoutAttr layout; /// 3. any other layout that be converted into a single affine map layout /// of the form `K + k0 * d0 + ... kn * dn`, where K and ki's are diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 2013d3623711b..0175c6cc3d093 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -156,6 +156,10 @@ Attribute Parser::parseAttribute(Type type) { case Token::kw_strided: return parseStridedLayoutAttr(); + // Parse a contiguous layout attribute. + case Token::kw_contiguous: + return parseContiguousLayoutAttr(); + // Parse a distinct attribute. case Token::kw_distinct: return parseDistinctAttr(type); @@ -1100,6 +1104,88 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { return getChecked(loc, type, indices, values); } +Attribute Parser::parseContiguousLayoutAttr() { + // Callback for error emissing at the keyword token location. + llvm::SMLoc loc = getToken().getLoc(); + auto errorEmitter = [&] { return emitError(loc); }; + + consumeToken(Token::kw_contiguous); + if (failed(parseToken(Token::less, "expected '<' after 'contiguous'"))) + return nullptr; + + auto parseNonNegativeInteger = [&]() -> std::optional { + Token curTok = getToken(); + if (!consumeIf(Token::integer)) + return std::nullopt; + + std::optional parsedVal = curTok.getUInt64IntegerValue(); + if (parsedVal.has_value() && + *parsedVal <= std::numeric_limits::max()) + return *parsedVal; + return std::nullopt; + }; + SmallVector permutation; + Token permStart = getToken(); + if (permStart.getKind() == Token::integer) { + std::optional rowMajorRank = parseNonNegativeInteger(); + if (!rowMajorRank) { + emitError(permStart.getLoc()) + << "expected short-form permutation rank to fit within 64 bits"; + return nullptr; + } + llvm::append_range(permutation, llvm::iota_range( + 0, *rowMajorRank, /*Inclusive=*/false)); + } else if (permStart.getKind() == Token::l_square) { + auto parsedPerm = parseCommaSeparatedList(Delimiter::Square, [&]() { + std::optional elem = parseNonNegativeInteger(); + if (!elem.has_value()) { + emitError(getToken().getLoc()) << "expected non-negative integer"; + return ParseResult::failure(); + } + permutation.push_back(*elem); + return ParseResult::success(); + }); + if (failed(parsedPerm)) + return nullptr; + } else { + emitError(permStart.getLoc()) << "expected non-negative integer or '['"; + return nullptr; + } + + // Fast path in absence of offset. 
+ if (consumeIf(Token::greater)) { + if (failed(ContiguousLayoutAttr::verify(errorEmitter, + /*offset=*/0, permutation))) + return nullptr; + return ContiguousLayoutAttr::get(getContext(), /*offset=*/0, permutation); + } + + if (failed(parseToken(Token::comma, "expected ',' or '>'")) || + failed(parseToken(Token::kw_offset, "expected 'offset' after comma")) || + failed(parseToken(Token::colon, "expected ':' after 'offset'"))) + return nullptr; + + std::optional offset; + if (consumeIf(Token::question)) + offset = ShapedType::kDynamic; + else + offset = parseNonNegativeInteger(); + + if (!offset) { + emitError(getToken().getLoc(), + "expected non-negative integer or '?' after ':'"); + return nullptr; + } + if (failed(parseToken(Token::greater, "expected '>'"))) + return nullptr; + + if (failed(ContiguousLayoutAttr::verify(errorEmitter, *offset, permutation))) + return nullptr; + return ContiguousLayoutAttr::get(getContext(), *offset, permutation); + // return getChecked(loc,getContext(), *offset, + // permutation); +} + Attribute Parser::parseStridedLayoutAttr() { // Callback for error emissing at the keyword token location. llvm::SMLoc loc = getToken().getLoc(); diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index ecc128cf767b3..760424d99824a 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -273,6 +273,9 @@ class Parser { /// Parse an attribute dictionary. ParseResult parseAttributeDict(NamedAttrList &attributes); + /// Parse a contiguous layout attribute. + Attribute parseContiguousLayoutAttr(); + /// Parse a distinct attribute. Attribute parseDistinctAttr(Type type); diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def index 49da8c3dea5fa..211ad3684edbc 100644 --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -87,6 +87,7 @@ TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(complex) +TOK_KEYWORD(contiguous) TOK_KEYWORD(dense) TOK_KEYWORD(dense_resource) TOK_KEYWORD(distinct) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 12725a0ed0939..be1c87f8d3e3c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1699,6 +1699,60 @@ class PyStridedLayoutAttribute } }; +/// Contiguous layout attribute subclass. 
+class PyContiguousLayoutAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAContiguousLayout; + static constexpr const char *pyClassName = "ContiguousLayoutAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirContiguousLayoutAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int64_t offset, const std::vector permutation, + DefaultingPyMlirContext ctx) { + MlirAttribute attr = mlirContiguousLayoutAttrGet( + ctx->get(), offset, permutation.size(), permutation.data()); + return PyContiguousLayoutAttribute(ctx->getRef(), attr); + }, + nb::arg("offset"), nb::arg("permutation"), + nb::arg("context").none() = nb::none(), + "Gets a contiguous layout attribute."); + c.def_static( + "get_row_major", + [](int64_t offset, int64_t rank, DefaultingPyMlirContext ctx) { + MlirAttribute attr = + mlirContiguousLayoutAttrGetRowMajor(ctx->get(), offset, rank); + return PyContiguousLayoutAttribute(ctx->getRef(), attr); + }, + nb::arg("offset"), nb::arg("rank"), + nb::arg("context").none() = nb::none(), + "Gets a contiguous layout attribute with the given offset and a " + "row-major layout (the identity permutation)"); + c.def_prop_ro( + "offset", + [](PyContiguousLayoutAttribute &self) { + return mlirContiguousLayoutAttrGetOffset(self); + }, + "Returns the value of the offset"); + c.def_prop_ro( + "permutation", + [](PyContiguousLayoutAttribute &self) { + intptr_t rank = mlirContiguousLayoutAttrGetRank(self); + std::vector permutation(rank); + for (intptr_t i = 0; i < rank; ++i) { + permutation[i] = + mlirContiguousLayoutAttrGetPermutationEntry(self, i); + } + return permutation; + }, + "Returns the value of the permutation in this contiguous layout"); + } +}; + nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); @@ -1808,4 +1862,5 @@ void mlir::python::populateIRAttributes(nb::module_ &m) { PyUnitAttribute::bind(m); PyStridedLayoutAttribute::bind(m); + PyContiguousLayoutAttribute::bind(m); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 8d57ab6b59e79..bb8564f291dad 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -981,6 +981,13 @@ MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, ArrayRef(strides, numStrides))); } +MlirAttribute mlirStridedLayoutAttrGetCanonical(MlirContext ctx, int64_t offset, + intptr_t numStrides, + const int64_t *strides) { + return wrap(StridedLayoutAttr::getCanonical( + unwrap(ctx), offset, ArrayRef(strides, numStrides))); +} + int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getOffset(); } @@ -997,3 +1004,42 @@ int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { return wrap(StridedLayoutAttr::getTypeID()); } + +//===----------------------------------------------------------------------===// +// Contiguous layout attribute. 
+//===----------------------------------------------------------------------===// + +bool mlirAttributeIsAContiguousLayout(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirContiguousLayoutAttrGet(MlirContext ctx, int64_t offset, + intptr_t rank, + const int64_t *permutation) { + return wrap(ContiguousLayoutAttr::get(unwrap(ctx), offset, + ArrayRef(permutation, rank))); +} + +MlirAttribute mlirContiguousLayoutAttrGetRowMajor(MlirContext ctx, + int64_t offset, + intptr_t rank) { + return wrap(ContiguousLayoutAttr::get(unwrap(ctx), offset, rank)); +} + +int64_t mlirContiguousLayoutAttrGetOffset(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getOffset(); +} + +intptr_t mlirContiguousLayoutAttrGetRank(MlirAttribute attr) { + return static_cast( + llvm::cast(unwrap(attr)).getPermutation().size()); +} + +int64_t mlirContiguousLayoutAttrGetPermutationEntry(MlirAttribute attr, + intptr_t pos) { + return llvm::cast(unwrap(attr)).getPermutation()[pos]; +} + +MlirTypeID mlirContiguousLayoutAttrGetTypeID(void) { + return wrap(ContiguousLayoutAttr::getTypeID()); +} diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 2b2a167b90c82..f8867176dc3b5 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -75,10 +75,17 @@ static FailureOr getFatRawBufferTypeLike(MemRefType source, amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer)); MemRefLayoutAttrInterface layout = source.getLayout(); if (resetOffset && !layout.isIdentity()) { - auto stridedLayout = dyn_cast(layout); - if (!stridedLayout) + MemRefLayoutAttrInterface newLayout; + if (auto stridedLayout = dyn_cast(layout)) { + newLayout = + StridedLayoutAttr::get(ctx, /*offset=*/0, stridedLayout.getStrides()); + } else if (auto contiguousLayout = dyn_cast(layout)) { + newLayout = ContiguousLayoutAttr::get(ctx, /*offset=*/0, + contiguousLayout.getPermutation()); + } + if (!newLayout) return failure(); - mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides())); + mb.setLayout(newLayout); } return (MemRefType)(mb); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index dd539ff685653..3479b87ebedb3 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -2096,10 +2096,12 @@ static LogicalResult generateCopy( // Check if a buffer was already created. bool existingBuf = fastBufferMap.count(memref) > 0; if (!existingBuf) { - AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank); + Attribute fastMemorySpace; + if (copyOptions.fastMemorySpace != 0) + fastMemorySpace = prologue.getI64IntegerAttr(copyOptions.fastMemorySpace); auto fastMemRefType = MemRefType::get(fastBufferShape, memRefType.getElementType(), - fastBufferLayout, copyOptions.fastMemorySpace); + MemRefLayoutAttrInterface{}, fastMemorySpace); // Create the fast memory space buffer just before the 'affine.for' // operation. 
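Before moving on to the transform changes, it may help to spell out the permutation/offset-to-affine-map correspondence that the attribute documentation above describes. The following is a small sketch against the new C++ API; the permutation `[2, 0, 1]` and offset 4 are illustrative values.

```c++
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

#include <cassert>

using namespace mlir;

// contiguous<[2, 0, 1], offset: 4> corresponds to the affine map
//   (d0, d1, d2) -> (d2, d0, d1 + 4),
// i.e. the same map makePermutedMapWithOffset builds directly.
static void checkContiguousLayoutMap() {
  MLIRContext context;
  auto layout = ContiguousLayoutAttr::get(&context, /*offset=*/4,
                                          /*permutation=*/{2, 0, 1});

  AffineMap viaAttr = layout.getAffineMap();
  AffineMap direct =
      makePermutedMapWithOffset({2, 0, 1}, /*offset=*/4, &context);
  assert(viaAttr == direct);

  // Helper queries added by this patch.
  assert(!layout.isRowMajor());
  assert(layout.hasStaticLayout());
  assert(!layout.isIdentity());
}
```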
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 2723cff6900d0..7288557f3e9ca 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1925,8 +1925,7 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) { auto newMemRefType = MemRefType::Builder(memrefType) .setShape(newShape) - .setLayout(AffineMapAttr::get( - AffineMap::getMultiDimIdentityMap(newRank, context))); + .setLayout(ContiguousLayoutAttr::get(context, /*offset=*/0, newRank)); return newMemRefType; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index e5a0c3c45b09e..798f992e6815f 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -96,9 +96,9 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute( return success(); } if (attr.getName() == kBufferLayoutAttrName) { - if (!llvm::isa(attr.getValue())) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << kBufferLayoutAttrName - << "' is expected to be a affine map attribute"; + << "' is expected to be a memref layout attribute"; } if (!isa(op)) return op->emitError() << "expected '" << kBufferLayoutAttrName diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index c45678f1e4b4d..86d06677520b5 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -63,16 +63,23 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, BaseMemRefType memrefType = options.functionArgTypeConverterFn( tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); - auto layoutAttr = funcOp.getArgAttrOfType( + auto layoutAttr = funcOp.getArgAttrOfType( index, BufferizationDialect::kBufferLayoutAttrName); if (!layoutAttr) return memrefType; + if (auto affineLayoutAttr = dyn_cast(layoutAttr)) { + // Support using an identity affine map as a synonym for row-major + // contiguous layouts. + if (affineLayoutAttr.isIdentity()) + layoutAttr = ContiguousLayoutAttr::get(funcOp.getContext(), /*offset=*/0, + tensorType.getRank()); + } auto rankedMemrefType = dyn_cast(memrefType); assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); - return MemRefType::get( - rankedMemrefType.getShape(), rankedMemrefType.getElementType(), - layoutAttr.getValue(), rankedMemrefType.getMemorySpace()); + return MemRefType::get(rankedMemrefType.getShape(), + rankedMemrefType.getElementType(), layoutAttr, + rankedMemrefType.getMemorySpace()); } /// Return the FuncOp called by `callOp`. 
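The DecomposeMemRefs change that follows folds a contiguous layout directly into the flattened index computation instead of going through strides. The arithmetic that `linearizeContiguousIndex` assembles out of `memref.extract_strided_metadata` and `affine.linearize_index` reduces, in scalar form, to the sketch below. This is a model only; it mirrors the `contiguous<[2, 0, 1], offset: D>` example from the MemRef type documentation, with concrete sizes chosen for the check.

```c++
#include "llvm/ADT/ArrayRef.h"

#include <cassert>
#include <cstdint>

// Scalar model of the offset a contiguous layout implies: walk the permutation
// from its last entry (the fastest-varying dimension) to its first, exactly as
// the permuted-basis affine.linearize_index does.
static int64_t linearizeContiguous(llvm::ArrayRef<int64_t> indices,
                                   llvm::ArrayRef<int64_t> sizes,
                                   llvm::ArrayRef<int64_t> permutation,
                                   int64_t layoutOffset) {
  int64_t result = layoutOffset;
  int64_t stride = 1;
  for (int64_t i = static_cast<int64_t>(permutation.size()) - 1; i >= 0; --i) {
    int64_t dim = permutation[i];
    result += indices[dim] * stride;
    stride *= sizes[dim];
  }
  return result;
}

// For memref<AxBxCxf32, contiguous<[2, 0, 1], offset: D>>, index [i, j, k]
// lands at k*A*B + i*B + j + D elements from the base, as documented above.
static void checkExample() {
  int64_t A = 5, B = 7, C = 3, D = 11;
  assert(linearizeContiguous({2, 4, 1}, {A, B, C}, {2, 0, 1}, D) ==
         1 * A * B + 2 * B + 4 + D);
}
```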
diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp index a64dc7f74a19c..d13a48125f81e 100644 --- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp @@ -34,9 +34,9 @@ static MemRefType inferCastResultType(Value source, OpFoldResult offset) { SmallVector staticOffsets; SmallVector dynamicOffsets; dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets); - auto stridedLayout = - StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {}); - return MemRefType::get({}, sourceType.getElementType(), stridedLayout, + auto contiguousLayout = + ContiguousLayoutAttr::get(source.getContext(), staticOffsets.front(), {}); + return MemRefType::get({}, sourceType.getElementType(), contiguousLayout, sourceType.getMemorySpace()); } @@ -52,6 +52,52 @@ static bool isInsideLaunch(Operation *op) { return op->getParentOfType(); } +static std::tuple +linearizeContiguousIndex(OpBuilder &rewriter, Location loc, Value source, + ValueRange offsets) { + auto sourceType = cast(source.getType()); + auto contigLayout = dyn_cast(sourceType.getLayout()); + if (!contigLayout) + return std::make_tuple(Value{}, OpFoldResult{}); + auto sourceRank = static_cast(sourceType.getRank()); + + memref::ExtractStridedMetadataOp newExtractStridedMetadata; + { + OpBuilder::InsertionGuard g(rewriter); + setInsertionPointToStart(rewriter, source); + newExtractStridedMetadata = + rewriter.create(loc, source); + } + + auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult { + return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal) + : rewriter.getIndexAttr(dim); + }; + OpFoldResult origOffset = + getDim(contigLayout.getOffset(), newExtractStridedMetadata.getOffset()); + + ValueRange dynSizes = newExtractStridedMetadata.getSizes(); + SmallVector basis; + basis.reserve(sourceRank); + ArrayRef shape = sourceType.getShape(); + SmallVector permutedOffsets; + permutedOffsets.reserve(sourceRank); + for (int64_t dim : contigLayout.getPermutation()) { + basis.push_back(getDim(shape[dim], dynSizes[dim])); + permutedOffsets.push_back(offsets[dim]); + } + OpFoldResult newOffset = + rewriter.createOrFold( + loc, permutedOffsets, basis, /*disjoint=*/true); + if (contigLayout.getOffset() == 0) + return {newExtractStridedMetadata.getBaseBuffer(), newOffset}; + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply( + rewriter, loc, s0 + s1, {newOffset, origOffset}); + return {newExtractStridedMetadata.getBaseBuffer(), totalOffset}; +} + static std::tuple> getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, ArrayRef subOffsets, @@ -106,9 +152,15 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source, ValueRange offsets) { - SmallVector offsetsTemp = getAsOpFoldResult(offsets); - auto &&[base, offset, ignore] = - getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp); + Value base; + OpFoldResult offset; + std::tie(base, offset) = + linearizeContiguousIndex(rewriter, loc, source, offsets); + if (!offset) { + SmallVector offsetsTemp = getAsOpFoldResult(offsets); + std::tie(base, offset, std::ignore) = + getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp); + } MemRefType retType = inferCastResultType(base, offset); return rewriter.create(loc, retType, base, offset, std::nullopt, std::nullopt); @@ 
-122,7 +174,7 @@ static bool needFlatten(Value val) { static bool checkLayout(Value val) { auto type = cast(val.getType()); return type.getLayout().isIdentity() || - isa(type.getLayout()); + isa(type.getLayout()); } namespace { diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 59434dccc117b..9e47262f31753 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -23,6 +24,7 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::memref; @@ -1846,7 +1848,7 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - auto stridedLayout = StridedLayoutAttr::get( + auto stridedLayout = StridedLayoutAttr::getCanonical( b.getContext(), staticOffsets.front(), staticStrides); auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(), stridedLayout, sourceType.getMemorySpace()); @@ -2225,9 +2227,35 @@ SmallVector ExpandShapeOp::getReassociationExprs() { /// Compute the layout map after expanding a given source MemRef type with the /// specified reassociation indices. -static FailureOr +static FailureOr computeExpandedLayoutMap(MemRefType srcType, ArrayRef resultShape, ArrayRef reassociation) { + // Special case: expanding the dimensions of a contiguous shape creates a + // contiguous shape by applying the permutation to the reassociation maps. + if (auto contigLayout = dyn_cast(srcType.getLayout())) { + int64_t srcOffset = contigLayout.getOffset(); + SmallVector srcDimsBySpeed = + invertPermutationVector(contigLayout.getPermutation()); + SmallVector resultPerm(resultShape.size(), -1); + // Invert the permutation to order the source dimensions by where + // they appear if you order them in a row-major sense, then expand that + // to construct the new permutation. + int64_t nextIndex = 0; + for (int64_t srcDim : srcDimsBySpeed) { + for (int64_t reassoc : ArrayRef(reassociation[srcDim])) { + resultPerm[reassoc] = nextIndex++; + } + } + // Fill in any dimensions that we're sneaking in to the end + for (int64_t &permEntry : resultPerm) { + if (permEntry == -1) + permEntry = nextIndex++; + } + + return cast( + ContiguousLayoutAttr::get(srcType.getContext(), srcOffset, resultPerm)); + } + int64_t srcOffset; SmallVector srcStrides; if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset))) @@ -2262,7 +2290,8 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef resultShape, } auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides)); resultStrides.resize(resultShape.size(), 1); - return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides); + return StridedLayoutAttr::getCanonical(srcType.getContext(), srcOffset, + resultStrides); } FailureOr ExpandShapeOp::computeExpandedType( @@ -2277,7 +2306,7 @@ FailureOr ExpandShapeOp::computeExpandedType( } // Source may not be contiguous. Compute the layout map. 
- FailureOr computedLayout = + FailureOr computedLayout = computeExpandedLayoutMap(srcType, resultShape, reassociation); if (failed(computedLayout)) return failure(); @@ -2420,13 +2449,49 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, /// not possible to check this by inspecting a MemRefType in the general case. /// If non-contiguity cannot be checked statically, the collapse is assumed to /// be valid (and thus accepted by this function) unless `strict = true`. -static FailureOr +static FailureOr computeCollapsedLayoutMap(MemRefType srcType, ArrayRef reassociation, bool strict = false) { + auto srcShape = srcType.getShape(); + // Special case for contiguous layouts. + if (auto contigLayout = dyn_cast(srcType.getLayout())) { + int64_t srcOffset = contigLayout.getOffset(); + ArrayRef srcPerm = contigLayout.getPermutation(); + // Store (smallest permutation in group, reassoc index) so we know + // which reassociation is innermost, outermost, etc. This is because we + // want to preserve the permutation of the dimensions that aren't being + // collapsed together. For example, we can have memref, which is column-major, being collapsed by [[0], + // [2, 1], [3]], which should produce a memref, because dim 3 is the slowest-moving one and dim 0 is fastest-moving. + SmallVector> minPermAndResLogicalIdxs; + + for (auto [resultIdx, reassoc] : llvm::enumerate(reassociation)) { + ArrayRef ref = llvm::ArrayRef(reassoc); + while (srcShape[ref.back()] == 1 && ref.size() > 1) + ref = ref.drop_back(); + int64_t minPerm = srcPerm[ref.front()]; + if (!llvm::all_of(llvm::enumerate(ref), [&](auto e) { + return srcPerm[e.value()] == + minPerm + static_cast(e.index()); + })) { + return failure(); + } + minPermAndResLogicalIdxs.emplace_back(minPerm, resultIdx); + } + llvm::sort(minPermAndResLogicalIdxs); + SmallVector resultPerm(reassociation.size(), -1); + for (auto [permRes, srcMinPermAndLogicalIdxRes] : + llvm::enumerate(minPermAndResLogicalIdxs)) { + resultPerm[std::get<1>(srcMinPermAndLogicalIdxRes)] = permRes; + } + return cast(ContiguousLayoutAttr::get( + contigLayout.getContext(), srcOffset, resultPerm)); + } + int64_t srcOffset; SmallVector srcStrides; - auto srcShape = srcType.getShape(); if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset))) return failure(); @@ -2481,7 +2546,8 @@ computeCollapsedLayoutMap(MemRefType srcType, return failure(); } } - return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides); + return StridedLayoutAttr::getCanonical(srcType.getContext(), srcOffset, + resultStrides); } bool CollapseShapeOp::isGuaranteedCollapsible( @@ -2517,7 +2583,7 @@ MemRefType CollapseShapeOp::computeCollapsedType( // Source may not be fully contiguous. Compute the layout map. // Note: Dimensions that are collapsed into a single dim are assumed to be // contiguous. - FailureOr computedLayout = + FailureOr computedLayout = computeCollapsedLayoutMap(srcType, reassociation); assert(succeeded(computedLayout) && "invalid source layout map or collapsing non-contiguous dims"); @@ -2567,7 +2633,7 @@ LogicalResult CollapseShapeOp::verify() { // Source may not be fully contiguous. Compute the layout map. // Note: Dimensions that are collapsed into a single dim are assumed to be // contiguous. 
- FailureOr computedLayout = + FailureOr computedLayout = computeCollapsedLayoutMap(srcType, getReassociationIndices()); if (failed(computedLayout)) return emitOpError( @@ -2738,10 +2804,11 @@ MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType, } // The type is now known. - return MemRefType::get(staticSizes, sourceMemRefType.getElementType(), - StridedLayoutAttr::get(sourceMemRefType.getContext(), - targetOffset, targetStrides), - sourceMemRefType.getMemorySpace()); + return MemRefType::get( + staticSizes, sourceMemRefType.getElementType(), + StridedLayoutAttr::getCanonical(sourceMemRefType.getContext(), + targetOffset, targetStrides), + sourceMemRefType.getMemorySpace()); } MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType, @@ -2780,18 +2847,36 @@ MemRefType SubViewOp::inferRankReducedResultType( assert(dimsToProject.has_value() && "invalid rank reduction"); // Compute the layout and result type. - auto inferredLayout = llvm::cast(inferredType.getLayout()); + + int64_t offset = 0; SmallVector rankReducedStrides; rankReducedStrides.reserve(resultShape.size()); - for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) { - if (!dimsToProject->contains(idx)) - rankReducedStrides.push_back(value); - } - return MemRefType::get(resultShape, inferredType.getElementType(), - StridedLayoutAttr::get(inferredLayout.getContext(), - inferredLayout.getOffset(), - rankReducedStrides), - inferredType.getMemorySpace()); + llvm::TypeSwitch(inferredType.getLayout()) + .Case([&](StridedLayoutAttr inferredLayout) { + for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) { + if (!dimsToProject->contains(idx)) + rankReducedStrides.push_back(value); + } + offset = inferredLayout.getOffset(); + }) + .Case([&](ContiguousLayoutAttr inferredLayout) { + assert(inferredLayout.getPermutation().size() <= 1 && + "Only 0- and 1-D values can be identity-like in subviews"); + // The result shape has no strides at all (0D) or a stride of length 1 + // (1D), if it got sent into this case. 
+ rankReducedStrides.assign(resultShape.size(), 1ll); + offset = inferredLayout.getOffset(); + }) + .Default([](MemRefLayoutAttrInterface) { + llvm_unreachable( + "unexpected non-stride-like layout in subview type inference"); + }); + + return MemRefType::get( + resultShape, inferredType.getElementType(), + StridedLayoutAttr::getCanonical(inferredType.getContext(), offset, + rankReducedStrides), + inferredType.getMemorySpace()); } MemRefType SubViewOp::inferRankReducedResultType( @@ -3080,25 +3165,45 @@ static MemRefType getCanonicalSubViewResultType( if (failed(unusedDims)) return nullptr; - auto layout = llvm::cast(nonRankReducedType.getLayout()); + int64_t offset = 0; SmallVector shape, strides; unsigned numDimsAfterReduction = nonRankReducedType.getRank() - unusedDims->count(); shape.reserve(numDimsAfterReduction); strides.reserve(numDimsAfterReduction); - for (const auto &[idx, size, stride] : - llvm::zip(llvm::seq(0, nonRankReducedType.getRank()), - nonRankReducedType.getShape(), layout.getStrides())) { - if (unusedDims->test(idx)) - continue; - shape.push_back(size); - strides.push_back(stride); - } + llvm::TypeSwitch(nonRankReducedType.getLayout()) + .Case([&](StridedLayoutAttr layout) { + offset = layout.getOffset(); + for (const auto &[idx, size, stride] : + llvm::zip(llvm::seq(0, nonRankReducedType.getRank()), + nonRankReducedType.getShape(), layout.getStrides())) { + if (unusedDims->test(idx)) + continue; + shape.push_back(size); + strides.push_back(stride); + } + }) + .Case([&](ContiguousLayoutAttr layout) { + assert(nonRankReducedType.getRank() <= 1 && + "Only 0D and 1D memrefs can have contiguous non-rank-reduced " + "layout in subview type inference"); + offset = layout.getOffset(); + if (nonRankReducedType.getRank() == 1 && !unusedDims->test(0)) { + shape.push_back(nonRankReducedType.getShape().front()); + strides.push_back(1); + } + // Otherwise, either it's a 0D memref or we're not using the one + // dimension. + }) + .Default([](MemRefLayoutAttrInterface) { + llvm_unreachable( + "unexpected layout kind in rank-reduced subview inference"); + }); - return MemRefType::get(shape, nonRankReducedType.getElementType(), - StridedLayoutAttr::get(sourceType.getContext(), - layout.getOffset(), strides), - nonRankReducedType.getMemorySpace()); + return MemRefType::get( + shape, nonRankReducedType.getElementType(), + StridedLayoutAttr::getCanonical(sourceType.getContext(), offset, strides), + nonRankReducedType.getMemorySpace()); } Value mlir::memref::createCanonicalRankReducingSubViewOp( @@ -3277,10 +3382,11 @@ struct SubViewReturnTypeCanonicalizer { targetShape.push_back(nonReducedType.getDimSize(i)); } - return MemRefType::get(targetShape, nonReducedType.getElementType(), - StridedLayoutAttr::get(nonReducedType.getContext(), - offset, targetStrides), - nonReducedType.getMemorySpace()); + return MemRefType::get( + targetShape, nonReducedType.getElementType(), + StridedLayoutAttr::getCanonical(nonReducedType.getContext(), offset, + targetStrides), + nonReducedType.getMemorySpace()); } }; @@ -3302,12 +3408,17 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { MemRefType sourceMemrefType = getSource().getType(); MemRefType resultMemrefType = getResult().getType(); - auto resultLayout = - dyn_cast_if_present(resultMemrefType.getLayout()); + // We assume that if the layout isn't strided, someone isn't using + // subview to manipulate a dynamic offset. 
+ bool hasStaticOffset = + llvm::TypeSwitch( + resultMemrefType.getLayout()) + .Case([](StridedLayoutAttr a) { return a.hasStaticLayout(); }) + .Case([](ContiguousLayoutAttr a) { return a.hasStaticLayout(); }) + .Default(true); if (resultMemrefType == sourceMemrefType && - resultMemrefType.hasStaticShape() && - (!resultLayout || resultLayout.hasStaticLayout())) { + resultMemrefType.hasStaticShape() && (hasStaticOffset)) { return getViewSource(); } @@ -3345,17 +3456,28 @@ void TransposeOp::getAsmResultNames( static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap) { auto originalSizes = memRefType.getShape(); - auto [originalStrides, offset] = memRefType.getStridesAndOffset(); - assert(originalStrides.size() == static_cast(memRefType.getRank())); - - // Compute permuted sizes and strides. + // Compute permuted sizes. auto sizes = applyPermutationMap(permutationMap, originalSizes); - auto strides = applyPermutationMap(permutationMap, originalStrides); - return MemRefType::Builder(memRefType) - .setShape(sizes) - .setLayout( - StridedLayoutAttr::get(memRefType.getContext(), offset, strides)); + MemRefLayoutAttrInterface newLayout; + if (auto contigLayout = + dyn_cast(memRefType.getLayout())) { + int64_t srcOffset = contigLayout.getOffset(); + auto permutation = applyPermutationMap( + permutationMap, contigLayout.getPermutation()); + newLayout = ContiguousLayoutAttr::get(memRefType.getContext(), srcOffset, + permutation); + } else { + auto [originalStrides, offset] = memRefType.getStridesAndOffset(); + assert(originalStrides.size() == + static_cast(memRefType.getRank())); + + auto strides = + applyPermutationMap(permutationMap, originalStrides); + newLayout = + StridedLayoutAttr::get(memRefType.getContext(), offset, strides); + } + return MemRefType::Builder(memRefType).setShape(sizes).setLayout(newLayout); } void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index f58385a7777db..377181d94f051 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -642,23 +642,20 @@ void memref::populateMemRefNarrowTypeEmulationConversions( if (!newElemTy) return nullptr; - StridedLayoutAttr layoutAttr; - // If the offset is 0, we do not need a strided layout as the stride is - // 1, so we only use the strided layout if the offset is not 0. - if (offset != 0) { - if (offset == ShapedType::kDynamic) { - layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, - ArrayRef{1}); - } else { - // Check if the number of bytes are a multiple of the loadStoreWidth - // and if so, divide it by the loadStoreWidth to get the offset. - if ((offset * width) % loadStoreWidth != 0) - return std::nullopt; - offset = (offset * width) / loadStoreWidth; - - layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, - ArrayRef{1}); - } + int64_t newRank = std::min(ty.getRank(), (int64_t)1); + ContiguousLayoutAttr layoutAttr; + if (offset == ShapedType::kDynamic) { + layoutAttr = + ContiguousLayoutAttr::get(ty.getContext(), offset, newRank); + } else { + // Check if the number of bytes are a multiple of the loadStoreWidth + // and if so, divide it by the loadStoreWidth to get the offset. 
+ if ((offset * width) % loadStoreWidth != 0) + return std::nullopt; + offset = (offset * width) / loadStoreWidth; + + layoutAttr = + ContiguousLayoutAttr::get(ty.getContext(), offset, newRank); } return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth), diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 4ac6eca586961..e2b51d4de5b78 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -190,7 +190,7 @@ struct CollapseShapeOpInterface return failure(); resultType = MemRefType::get( {}, tensorResultType.getElementType(), - StridedLayoutAttr::get(op->getContext(), offset, {}), + ContiguousLayoutAttr::get(op->getContext(), offset, {}), bufferType.getMemorySpace()); } diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 0336423c57b1d..36269d80e199a 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -259,9 +259,9 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible( return success(); } -bool mlir::hasNonIdentityLayout(Type type) { +bool mlir::hasNonRowMajorContiguousLayout(Type type) { if (auto memrefType = dyn_cast(type)) - return !memrefType.getLayout().isIdentity(); + return !memrefType.areTrailingDimsContiguous(memrefType.getRank()); return false; } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 3000204c8ce17..ada9ff6862383 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -140,22 +140,25 @@ struct RemoveStrideFromGatherSource : OpRewritePattern { return failure(); // Get strides - auto layout = subview.getResult().getType().getLayout(); - auto stridedLayoutAttr = llvm::dyn_cast(layout); - if (!stridedLayoutAttr) - return failure(); - - // TODO: Allow the access to be strided in multiple dimensions. - if (stridedLayoutAttr.getStrides().size() != 1) - return failure(); - int64_t srcTrailingDim = sourceType.getShape().back(); - - // Assume that the stride matches the trailing dimension of the source - // memref. - // TODO: Relax this assumption. - if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim) + auto layout = subview.getResult().getType().getLayout(); + if (auto contigLayoutAttr = dyn_cast(layout)) { + // TODO: relax this to N-D layouts. + if (contigLayoutAttr.getPermutation() != ArrayRef{0, 1}) + return failure(); + } else if (auto stridedLayoutAttr = dyn_cast(layout)) { + // TODO: Allow the access to be strided in multiple dimensions. + if (stridedLayoutAttr.getStrides().size() != 1) + return failure(); + + // Assume that the stride matches the trailing dimension of the source + // memref. + // TODO: Relax this assumption. + if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim) + return failure(); + } else { return failure(); + } // 1. Collapse the input memref so that it's "flat". 
SmallVector reassoc = {{0, 1}}; @@ -209,12 +212,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern { // vector.load requires the most minor memref dim to have unit stride // (unless reading exactly 1 element) if (auto memType = dyn_cast(base.getType())) { - if (auto stridesAttr = - dyn_cast_if_present(memType.getLayout())) { - if (stridesAttr.getStrides().back() != 1 && - resultTy.getNumElements() != 1) - return failure(); - } + if (!memType.areTrailingDimsContiguous(1) && + resultTy.getNumElements() != 1) + return failure(); } Value indexVec = rewriter.createOrFold( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 62dfd439b0ad1..1b6f897ffe8ca 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -543,39 +543,15 @@ static SmallVector getCollapsedIndices(RewriterBase &rewriter, return indicesAfterCollapsing; } - // Compute the remaining trailing index/offset required for reading from - // the collapsed memref: - // - // offset = 0 - // for (i = firstDimToCollapse; i < outputRank; ++i) - // offset += sourceType.getDimSize(i) * transferReadOp.indices[i] - // - // For this example: - // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) : - // memref<1x43x2xi32>, vector<1x2xi32> - // which would be collapsed to: - // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) : - // memref<1x86xi32>, vector<2xi32> - // one would get the following offset: - // %offset = %arg0 * 43 - OpFoldResult collapsedOffset = - rewriter.create(loc, 0).getResult(); - - auto collapsedStrides = computeSuffixProduct( - ArrayRef(shape.begin() + firstDimToCollapse, shape.end())); - - // Compute the collapsed offset. 
- auto &&[collapsedExpr, collapsedVals] = - computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse); - collapsedOffset = affine::makeComposedFoldedAffineApply( - rewriter, loc, collapsedExpr, collapsedVals); - - if (auto value = dyn_cast(collapsedOffset)) { - indicesAfterCollapsing.push_back(value); - } else { - indicesAfterCollapsing.push_back(rewriter.create( - loc, *getConstantIntValue(collapsedOffset))); - } + ArrayRef collapseBasis = shape.take_back(indicesToCollapse.size()); + // If the outermost of the collapsing dimensions is dynamic, + // we can still do this rewrite, as the transfer_* is known to be `inbounds` + if (!collapseBasis.empty() && ShapedType::isDynamic(collapseBasis.front())) + collapseBasis = collapseBasis.drop_front(); + + Value collapsed = rewriter.create( + loc, indicesToCollapse, collapseBasis, /*isDisjoint=*/true); + indicesAfterCollapsing.push_back(collapsed); return indicesAfterCollapsing; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5b5ec841917e7..49fc4ef541d43 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2508,6 +2508,9 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr, } } else if (auto stridedLayoutAttr = llvm::dyn_cast(attr)) { stridedLayoutAttr.print(os); + } else if (auto contiguousLayoutAttr = + llvm::dyn_cast(attr)) { + contiguousLayoutAttr.print(os); } else if (auto denseArrayAttr = llvm::dyn_cast(attr)) { os << "array<"; printType(denseArrayAttr.getElementType()); @@ -2801,7 +2804,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { os << 'x'; printType(memrefTy.getElementType()); MemRefLayoutAttrInterface layout = memrefTy.getLayout(); - if (!llvm::isa(layout) || !layout.isIdentity()) { + if (!llvm::isa(layout) || !layout.isIdentity()) { os << ", "; printAttribute(memrefTy.getLayout(), AttrTypeElision::May); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 112e3f376bd41..3e75c5d652002 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -48,6 +48,88 @@ void BuiltinDialect::registerAttributes() { addAttributes(); } +//===----------------------------------------------------------------------===// +// ContiguousLayoutAttr +//===----------------------------------------------------------------------===// + +/// Build a row-major contiguous layout +ContiguousLayoutAttr ContiguousLayoutAttr::get(MLIRContext *context, + int64_t offset, int64_t rank) { + SmallVector identityPerm = + llvm::to_vector(llvm::iota_range(0, rank, /*inclusive=*/false)); + return get(context, offset, identityPerm); +} + +bool ContiguousLayoutAttr::isRowMajor() const { + return llvm::all_of(llvm::enumerate(getPermutation()), [](auto e) { + return static_cast(e.index()) == e.value(); + }); +} + +/// Prints a contiguous layout attribute. +void ContiguousLayoutAttr::print(llvm::raw_ostream &os) const { + os << "contiguous<"; + if (isRowMajor()) { + os << getPermutation().size(); + } else { + os << "["; + llvm::interleaveComma(getPermutation(), os); + os << "]"; + } + + if (getOffset() != 0) { + os << ", offset: "; + if (ShapedType::isDynamic(getOffset())) + os << "?"; + else + os << getOffset(); + } + os << ">"; +} + +/// Returns true if this layout is static, i.e. the offset has a known value. 
+bool ContiguousLayoutAttr::hasStaticLayout() const {
+  return !ShapedType::isDynamic(getOffset());
+}
+
+bool ContiguousLayoutAttr::isIdentity() const {
+  return getOffset() == 0 && isRowMajor();
+}
+
+/// Returns the contiguous layout attr as an affine map.
+AffineMap ContiguousLayoutAttr::getAffineMap() const {
+  return makePermutedMapWithOffset(getPermutation(), getOffset(), getContext());
+}
+
+/// Checks that the type-agnostic contiguous layout invariants are satisfied.
+LogicalResult
+ContiguousLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                             int64_t offset, ArrayRef<int64_t> permutation) {
+  int64_t rank = permutation.size();
+  llvm::SmallBitVector isPresent(rank, false);
+  for (int64_t idx : permutation) {
+    if (idx < 0 || idx >= rank)
+      return emitError() << "permutation element " << idx
+                         << " is not a non-negative number less than " << rank;
+    isPresent.set(idx);
+  }
+  if (!isPresent.all())
+    return emitError() << "permutation does not contain 0 up to " << rank
+                       << " exactly once";
+  return success();
+}
+
+/// Checks that the type-specific contiguous layout invariants are satisfied.
+LogicalResult ContiguousLayoutAttr::verifyLayout(
+    ArrayRef<int64_t> shape,
+    function_ref<InFlightDiagnostic()> emitError) const {
+  if (shape.size() != getPermutation().size())
+    return emitError() << "expected the rank of the permutation to match the "
+                          "rank of the memref";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
@@ -209,6 +291,19 @@ DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
 // StridedLayoutAttr
 //===----------------------------------------------------------------------===//
 
+/// Gets a strided layout attribute, canonicalizing to the contiguous layout
+/// for 0-D memrefs and 1-D memrefs with unit stride, so that the identity
+/// layout does not have multiple distinct forms.
+MemRefLayoutAttrInterface
+StridedLayoutAttr::getCanonical(MLIRContext *context, int64_t offset,
+                                ArrayRef<int64_t> strides) {
+  if (strides.empty())
+    return ContiguousLayoutAttr::get(context, offset, 0);
+  if (strides.size() == 1 && strides.back() == 1)
+    return ContiguousLayoutAttr::get(context, offset, 1);
+  return get(context, offset, strides);
+}
+
 /// Prints a strided layout attribute.
 void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
   auto printIntOrQuestion = [&](int64_t value) {
@@ -1797,6 +1892,26 @@ Attribute DistinctAttr::getReferencedAttr() const {
 // Attribute Utilities
 //===----------------------------------------------------------------------===//
 
+AffineMap mlir::makePermutedMapWithOffset(ArrayRef<int64_t> permutation,
+                                          int64_t offset,
+                                          MLIRContext *context) {
+  SmallVector<AffineExpr> exprs;
+  int64_t rank = permutation.size();
+  exprs.reserve(rank);
+  for (int64_t comesFrom : permutation)
+    exprs.push_back(getAffineDimExpr(comesFrom, context));
+  bool hasDynamicOffset = ShapedType::isDynamic(offset);
+  AffineExpr offsetExpr = hasDynamicOffset
+                              ? getAffineSymbolExpr(0, context)
+                              : getAffineConstantExpr(offset, context);
+  if (exprs.empty()) {
+    exprs.push_back(offsetExpr);
+  } else if (offset != 0) {
+    exprs.back() = exprs.back() + offsetExpr;
+  }
+  return AffineMap::get(rank, hasDynamicOffset ? 1 : 0, exprs, context);
+}
+
 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                            int64_t offset,
                                            MLIRContext *context) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..db2524b15e2d4 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -8,6 +8,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 
 #include "TypeDetail.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -519,8 +520,8 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                            Attribute memorySpace) {
   // Use default layout for empty attribute.
   if (!layout)
-    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
-        shape.size(), elementType.getContext()));
+    layout = ContiguousLayoutAttr::get(elementType.getContext(), /*offset=*/0,
+                                       /*rank=*/shape.size());
 
   // Drop default memory space value and replace it with empty attribute.
   memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -535,8 +536,8 @@ MemRefType MemRefType::getChecked(
 
   // Use default layout for empty attribute.
   if (!layout)
-    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
-        shape.size(), elementType.getContext()));
+    layout = ContiguousLayoutAttr::get(elementType.getContext(), /*offset=*/0,
+                                       /*rank=*/shape.size());
 
   // Drop default memory space value and replace it with empty attribute.
   memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -548,13 +549,14 @@ MemRefType MemRefType::getChecked(
 
 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                            AffineMap map, Attribute memorySpace) {
-  // Use default layout for empty map.
-  if (!map)
-    map = AffineMap::getMultiDimIdentityMap(shape.size(),
-                                            elementType.getContext());
-
-  // Wrap AffineMap into Attribute.
-  auto layout = AffineMapAttr::get(map);
+  MemRefLayoutAttrInterface layout;
+  if (map)
+    // Wrap AffineMap into Attribute.
+    layout = AffineMapAttr::get(map);
+  else
+    // Represent the default identity map as a contiguous layout.
+    layout = ContiguousLayoutAttr::get(elementType.getContext(), /*offset=*/0,
+                                       /*rank=*/shape.size());
 
   // Drop default memory space value and replace it with empty attribute.
   memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -567,14 +569,14 @@ MemRefType MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                        ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                        Attribute memorySpace) {
-
-  // Use default layout for empty map.
-  if (!map)
-    map = AffineMap::getMultiDimIdentityMap(shape.size(),
-                                            elementType.getContext());
-
-  // Wrap AffineMap into Attribute.
-  auto layout = AffineMapAttr::get(map);
+  MemRefLayoutAttrInterface layout;
+  if (map)
+    // Wrap AffineMap into Attribute.
+    layout = AffineMapAttr::get(map);
+  else
+    // Represent the default identity map as a contiguous layout.
+    layout = ContiguousLayoutAttr::get(elementType.getContext(), /*offset=*/0,
+                                       /*rank=*/shape.size());
 
   // Drop default memory space value and replace it with empty attribute.
   memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -586,13 +588,14 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
 
 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                            AffineMap map, unsigned memorySpaceInd) {
-  // Use default layout for empty map.
-  if (!map)
-    map = AffineMap::getMultiDimIdentityMap(shape.size(),
-                                            elementType.getContext());
-
-  // Wrap AffineMap into Attribute.
-  auto layout = AffineMapAttr::get(map);
+  MemRefLayoutAttrInterface layout;
+  if (map)
+    // Wrap AffineMap into Attribute.
+    layout = AffineMapAttr::get(map);
+  else
+    // Represent the default identity map as a contiguous layout.
+    layout = ContiguousLayoutAttr::get(elementType.getContext(), /*offset=*/0,
+                                       /*rank=*/shape.size());
 
   // Convert deprecated integer-like memory space to Attribute.
   Attribute memorySpace =
@@ -606,14 +609,14 @@ MemRefType MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                        ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                        unsigned memorySpaceInd) {
-
-  // Use default layout for empty map.
-  if (!map)
-    map = AffineMap::getMultiDimIdentityMap(shape.size(),
-                                            elementType.getContext());
-
-  // Wrap AffineMap into Attribute.
-  auto layout = AffineMapAttr::get(map);
+  MemRefLayoutAttrInterface layout;
+  if (map)
+    // Wrap AffineMap into Attribute.
+    layout = AffineMapAttr::get(map);
+  else
+    // Represent the default identity map as a contiguous layout.
+    layout = ContiguousLayoutAttr::get(elementType.getContext(), /*offset=*/0,
+                                       /*rank=*/shape.size());
 
   // Convert deprecated integer-like memory space to Attribute.
   Attribute memorySpace =
@@ -649,13 +652,27 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
   if (!isLastDimUnitStride())
     return false;
 
+  if (auto contiguousLayout =
+          mlir::dyn_cast<ContiguousLayoutAttr>(getLayout())) {
+    ArrayRef<int64_t> perm = contiguousLayout.getPermutation();
+    int64_t expectedValue = perm.size() - 1;
+    for (auto [iter, permVal] : llvm::enumerate(llvm::reverse(perm))) {
+      if (static_cast<int64_t>(iter) >= n)
+        return true;
+      if (permVal != expectedValue)
+        return false;
+      expectedValue--;
+    }
+    return true;
+  }
   auto memrefShape = getShape().take_back(n);
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
 
   if (getLayout().isIdentity())
     return true;
 
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+
   int64_t offset;
   SmallVector<int64_t> stridesFull;
   if (!succeeded(getStridesAndOffset(stridesFull, offset)))
@@ -677,31 +694,158 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
   return llvm::equal(strides, llvm::reverse(flattenedDims));
 }
 
+/// If `layout` is some permutation of the identity layout with an offset
+/// applied to the last dimension (that is, if it has the form (d0, d1, ...,
+/// dN) -> (dX, dY, ... dZ + E) for some symbol or constant E), succeed and
+/// populate `perm` and `offset` with the discovered values.
+static LogicalResult asOffsetPermutation(MemRefLayoutAttrInterface layout,
+                                         ArrayRef<int64_t> shape,
+                                         SmallVectorImpl<int64_t> &perm,
+                                         int64_t &offset) {
+  if (auto contiguousLayout = mlir::dyn_cast<ContiguousLayoutAttr>(layout)) {
+    perm.assign(contiguousLayout.getPermutation().begin(),
+                contiguousLayout.getPermutation().end());
+    offset = contiguousLayout.getOffset();
+    return success();
+  }
+  if (auto stridedLayout = mlir::dyn_cast<StridedLayoutAttr>(layout)) {
+    // We can't reason about dynamic strides.
+    if (llvm::any_of(stridedLayout.getStrides(), ShapedType::isDynamic))
+      return failure();
+
+    int64_t suffixProd = 1;
+    bool isRowMajor = true;
+    for (auto [stride, dim] : llvm::zip(
+             llvm::reverse(stridedLayout.getStrides()), llvm::reverse(shape))) {
+      if (stride != suffixProd) {
+        isRowMajor = false;
+        break;
+      }
+      suffixProd *= dim;
+    }
+    if (isRowMajor) {
+      llvm::append_range(perm, llvm::iota_range<int64_t>(0, shape.size(),
+                                                         /*Inclusive=*/false));
+      offset = stridedLayout.getOffset();
+      return success();
+    }
+
+    SmallVector<std::pair<int64_t, int64_t>> stridesAndLocs;
+    for (auto [idx, stride] : llvm::enumerate(stridedLayout.getStrides()))
+      stridesAndLocs.emplace_back(stride, static_cast<int64_t>(idx));
+    // Sort by increasing stride, ties broken by appearing later in the memref.
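+    // A tie (equal strides) can only pass the contiguity check below when all
+    // but one of the tied dimensions have extent 1; preferring the later
+    // dimension keeps the resulting permutation deterministic.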
+    llvm::sort(stridesAndLocs, [](auto a, auto b) {
+      if (a.first == b.first)
+        return a.second > b.second;
+      return a.first < b.first;
+    });
+    int64_t expectedStride = 1;
+    for (auto [stride, loc] : stridesAndLocs) {
+      if (stride != expectedStride)
+        return failure();
+      expectedStride *= shape[loc];
+    }
+    perm = llvm::map_to_vector(stridesAndLocs, [](auto x) { return x.second; });
+    offset = stridedLayout.getOffset();
+    return success();
+  }
+
+  auto pullOffset = [&](AffineExpr e) -> bool {
+    if (isa<AffineDimExpr>(e))
+      return false;
+    if (auto constExpr = mlir::dyn_cast<AffineConstantExpr>(e)) {
+      offset = constExpr.getValue();
+    } else {
+      offset = ShapedType::kDynamic;
+    }
+    return true;
+  };
+
+  AffineMap m = layout.getAffineMap();
+  if (m.getNumDims() == 0 && m.getNumResults() == 1) {
+    if (pullOffset(m.getResult(0)))
+      return success();
+    return failure();
+  }
+
+  int64_t rank = shape.size();
+  if (m.getNumResults() != rank || m.getNumDims() != rank)
+    return failure();
+
+  llvm::SmallBitVector seen(rank, false);
+  for (AffineExpr e : llvm::drop_end(m.getResults())) {
+    auto dimE = dyn_cast<AffineDimExpr>(e);
+    if (!dimE)
+      return failure();
+    seen.set(dimE.getPosition());
+    perm.push_back(dimE.getPosition());
+  }
+  AffineDimExpr lastDim = dyn_cast<AffineDimExpr>(m.getResults().back());
+  if (!lastDim) {
+    auto sum = dyn_cast<AffineBinaryOpExpr>(m.getResults().back());
+    if (!sum || sum.getKind() != AffineExprKind::Add)
+      return failure();
+    if (!(pullOffset(sum.getLHS()) &&
+          (lastDim = dyn_cast<AffineDimExpr>(sum.getRHS()))) &&
+        !(pullOffset(sum.getRHS()) &&
+          (lastDim = dyn_cast<AffineDimExpr>(sum.getLHS()))))
+      return failure();
+  } else {
+    offset = 0;
+  }
+  seen.set(lastDim.getPosition());
+  perm.push_back(lastDim.getPosition());
+  if (!seen.all())
+    return failure();
+  return success();
+}
+
+static SmallVector<int64_t>
+computeStridesFromPermutedShape(ArrayRef<int64_t> shape,
+                                ArrayRef<int64_t> perm) {
+  assert(shape.size() == perm.size() &&
+         "shape and permutation have same length");
+  int64_t rank = shape.size();
+  SmallVector<int64_t> strides(rank, ShapedType::kDynamic);
+  strides.reserve(rank);
+
+  // invertPermutationVector() is not used here because calling into Utils
+  // from IR could create a circular dependency.
+  SmallVector<int64_t> strideOrder(rank, -1);
+  for (auto [idx, dim] : llvm::enumerate(perm)) {
+    strideOrder[dim] = static_cast<int64_t>(idx);
+  }
+  SaturatedInteger strideAccum = SaturatedInteger::wrap(1);
+  for (int64_t i = rank - 1; i >= 0; --i) {
+    strides[strideOrder[i]] = strideAccum.asInteger();
+    strideAccum = strideAccum * SaturatedInteger::wrap(shape[strideOrder[i]]);
+  }
+  return strides;
+}
+
 MemRefType MemRefType::canonicalizeStridedLayout() {
+  MemRefLayoutAttrInterface layout = getLayout();
+  if (mlir::isa<ContiguousLayoutAttr>(layout))
+    return *this;
+
+  SmallVector<int64_t> maybePerm;
+  int64_t maybeOffset;
+  if (succeeded(
+          asOffsetPermutation(layout, getShape(), maybePerm, maybeOffset))) {
+    return MemRefType::Builder(*this).setLayout(
+        ContiguousLayoutAttr::get(getContext(), maybeOffset, maybePerm));
+  }
+
   AffineMap m = getLayout().getAffineMap();
 
-  // Already in canonical form.
+  // Identity affine maps that are not a contiguous<> layout are not canonical.
   if (m.isIdentity())
-    return *this;
+    return MemRefType::Builder(*this).setLayout({});
 
   // Can't reduce to canonical identity form, return in canonical form.
   if (m.getNumResults() > 1)
     return *this;
 
-  // Corner-case for 0-D affine maps.
-  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
-    if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
-      if (cst.getValue() == 0)
-        return MemRefType::Builder(*this).setLayout({});
-    return *this;
-  }
-
-  // 0-D corner case for empty shape that still have an affine map. Example:
-  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
-  // offset needs to remain, just return t.
-  if (getShape().empty())
-    return *this;
-
   // If the canonical strided layout for the sizes of `t` is equal to the
   // simplified layout of `t` we can just return an empty layout. Otherwise,
   // just simplify the existing layout.
@@ -794,7 +938,7 @@ static LogicalResult getStridesAndOffset(MemRefType t,
                                          AffineExpr &offset) {
   AffineMap m = t.getLayout().getAffineMap();
 
-  if (m.getNumResults() != 1 && !m.isIdentity())
+  if (m.getNumResults() != 1)
     return failure();
 
   auto zero = getAffineConstantExpr(0, t.getContext());
@@ -842,6 +986,25 @@ LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
     return success();
   }
 
+  // Somewhat happy path: the type uses the contiguous layout, so we only need
+  // to compute the strides.
+  if (auto contiguous = llvm::dyn_cast<ContiguousLayoutAttr>(getLayout())) {
+    strides.append(computeStridesFromPermutedShape(
+        getShape(), contiguous.getPermutation()));
+    offset = contiguous.getOffset();
+    return success();
+  }
+
+  SmallVector<int64_t> maybePermutation;
+  int64_t maybeOffset;
+  if (succeeded(asOffsetPermutation(getLayout(), getShape(), maybePermutation,
+                                    maybeOffset))) {
+    strides.append(
+        computeStridesFromPermutedShape(getShape(), maybePermutation));
+    offset = maybeOffset;
+    return success();
+  }
+
   // Otherwise, defer to the affine fallback as layouts are supposed to be
   // convertible to affine maps.
   AffineExpr offsetExpr;
@@ -878,6 +1041,10 @@ bool MemRefType::isStrided() {
 }
 
 bool MemRefType::isLastDimUnitStride() {
+  if (auto contiguousLayout = mlir::dyn_cast<ContiguousLayoutAttr>(getLayout()))
+    return getRank() == 0 ||
+           contiguousLayout.getPermutation().back() == getRank() - 1;
+
   int64_t offset;
   SmallVector<int64_t> strides;
   auto successStrides = getStridesAndOffset(strides, offset);
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index c60ff72ff9fd4..ff344fe841f80 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -77,6 +77,7 @@ __all__ = [
     "BoolAttr",
     "ComplexType",
     "Context",
+    "ContiguousLayoutAttr",
    "DenseBoolArrayAttr",
     "DenseBoolArrayIterator",
     "DenseElementsAttr",
@@ -1017,6 +1018,40 @@ class Context:
         Gets a container for accessing dialects by name
         """
 
+class ContiguousLayoutAttr(Attribute):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(
+        offset: int, permutation: list[int], context: Context | None = None
+    ) -> ContiguousLayoutAttr:
+        """
+        Gets a contiguous layout attribute.
+        """
+    @staticmethod
+    def get_row_major(
+        offset: int, rank: int, context: Context | None = None
+    ) -> ContiguousLayoutAttr:
+        """
+        Gets a row-major contiguous layout attribute with the given offset and rank.
+        """
+    @staticmethod
+    def isinstance(other: Attribute) -> bool: ...
+    def __init__(self, cast_from_attr: Attribute) -> None: ...
+    @property
+    def offset(self) -> int:
+        """
+        Returns the offset in the given contiguous layout attribute.
+        """
+    @property
+    def permutation(self) -> list[int]:
+        """
+        Returns the permutation in the given contiguous layout attribute.
+        """
+    @property
+    def type(self) -> Type: ...
+    @property
+    def typeid(self) -> TypeID: ...
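+    # Illustrative usage (assuming an active MLIR Context):
+    #   transposed = ContiguousLayoutAttr.get(0, [2, 1, 0])
+    #   row_major = ContiguousLayoutAttr.get_row_major(4, 2)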
+ class DenseBoolArrayAttr(Attribute): @staticmethod def get( diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index b875d639e9d40..60ac134271f61 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -3,12 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from functools import partial -from typing import Optional, List +from typing import Optional, List, Union from ..ir import ( Attribute, + AffineMapAttr, BF16Type, ComplexType, + ContiguousLayoutAttr, F16Type, F32Type, F64Type, @@ -152,7 +154,9 @@ def memref( *shape, element_type: Type = None, memory_space: Optional[int] = None, - layout: Optional[StridedLayoutAttr] = None, + layout: Optional[ + Union[ContiguousLayoutAttr, StridedLayoutAttr, AffineMapAttr] + ] = None, ): if memory_space is not None: memory_space = Attribute.parse(str(memory_space)) diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 68da79f69cc0a..d86a1dd927923 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1369,6 +1369,33 @@ int printBuiltinAttributes(MlirContext ctx) { if (!mlirAttributeIsALocation(locAttr)) return 24; + int64_t contiguousPerm[3] = {2, 1, 0}; + MlirAttribute contiguousLayoutAttr = + mlirContiguousLayoutAttrGet(ctx, 42, 3, contiguousPerm); + + // CHECK: contiguous<[2, 1, 0], offset: 42> + mlirAttributeDump(contiguousLayoutAttr); + + if (mlirContiguousLayoutAttrGetOffset(contiguousLayoutAttr) != 42 || + mlirContiguousLayoutAttrGetRank(contiguousLayoutAttr) != 3 || + mlirContiguousLayoutAttrGetPermutationEntry(contiguousLayoutAttr, 0) != + 2 || + mlirContiguousLayoutAttrGetPermutationEntry(contiguousLayoutAttr, 1) != + 1 || + mlirContiguousLayoutAttrGetPermutationEntry(contiguousLayoutAttr, 2) != 0) + return 25; + + MlirAttribute rowMajorContiguous = + mlirContiguousLayoutAttrGetRowMajor(ctx, 42, 2); + + // CHECK: contiguous<2, offset: 42> + mlirAttributeDump(rowMajorContiguous); + if (mlirContiguousLayoutAttrGetOffset(rowMajorContiguous) != 42 || + mlirContiguousLayoutAttrGetRank(rowMajorContiguous) != 2 || + mlirContiguousLayoutAttrGetPermutationEntry(rowMajorContiguous, 0) != 0 || + mlirContiguousLayoutAttrGetPermutationEntry(rowMajorContiguous, 1) != 1) + return 26; + return 0; } diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index 5517eafb588e8..fb82c1141c097 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -329,7 +329,7 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x // CHECK-LABEL: func @subview_rank_reducing_leading_operands( // CHECK: %[[MEM:.*]]: memref -func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memref<3xf32, strided<[1], offset: 3>> { +func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memref<3xf32, contiguous<1, offset: 3>> { // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64 // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64 @@ -345,9 +345,9 @@ func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memre // CHECK: %[[CST_STRIDE0:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST_STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, 
ptr, i64, array<1 x i64>, array<1 x i64>)> - %1 = memref.subview %0[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, strided<[1], offset: 3>> + %1 = memref.subview %0[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, contiguous<1, offset: 3>> - return %1 : memref<3xf32, strided<[1], offset: 3>> + return %1 : memref<3xf32, contiguous<1, offset: 3>> } // ----- @@ -656,10 +656,10 @@ func.func @expand_shape_dynamic_with_non_identity_layout( // ----- // CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout -func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> { +func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, contiguous<1, offset: ?>> { // CHECK-NOT: memref.collapse_shape - %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>> - return %1 : memref<64xf32, strided<[1], offset: ?>> + %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, contiguous<1, offset: ?>> + return %1 : memref<64xf32, contiguous<1, offset: ?>> } // ----- diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 523e894aaef8d..3a21d827076a0 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -175,6 +175,21 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>> return } + +// ----- + +// CHECK-LABEL: func @assume_alignment_w_offset_contiguous +func.func @assume_alignment_w_offset_contiguous(%0 : memref<4x4xf16, contiguous<2, offset: ?>>) { + // CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-DAG: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16 + // CHECK-DAG: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 + // CHECK-DAG: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64 + // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[BUFF_ADDR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1 + memref.assume_alignment %0, 16 : memref<4x4xf16, contiguous<2, offset: ?>> + return +} + // ----- // CHECK-LABEL: func @dim_of_unranked @@ -243,6 +258,30 @@ func.func @transpose(%arg0: memref>) { // ----- +// CHECK-LABEL: func @transpose_contiguous +// CHECK: llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: 
llvm.insertvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +func.func @transpose_contiguous(%arg0: memref>) { + %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref> to memref> + return +} + +// ----- + // CHECK: llvm.mlir.global external @gv0() {addr_space = 0 : i32} : !llvm.array<2 x f32> { // CHECK-NEXT: %0 = llvm.mlir.undef : !llvm.array<2 x f32> // CHECK-NEXT: llvm.return %0 : !llvm.array<2 x f32> @@ -394,15 +433,15 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv // ----- -func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>, %ival : i32, %i : index) { - memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32, strided<[1], offset: 5>>) -> i32 +func.func @atomic_rmw_with_offset(%I : memref<10xi32, contiguous<1, offset: 5>>, %ival : i32, %i : index) { + memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32, contiguous<1, offset: 5>>) -> i32 return } // CHECK-LABEL: func @atomic_rmw_with_offset -// CHECK-SAME: %[[ARG0:.+]]: memref<10xi32, strided<[1], offset: 5>> +// CHECK-SAME: %[[ARG0:.+]]: memref<10xi32, contiguous<1, offset: 5>> // CHECK-SAME: %[[ARG1:.+]]: i32 // CHECK-SAME: %[[ARG2:.+]]: index -// CHECK-DAG: %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-DAG: %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, contiguous<1, offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64 // CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64 @@ -526,9 +565,9 @@ func.func @memref_copy_contiguous(%in: memref<16x4xi32>, %offset: index) { // CHECK-LABEL: func @memref_copy_0d_offset func.func @memref_copy_0d_offset(%in: memref<2xi32>) { %buf = memref.alloc() : memref - %sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>> - %scalar = memref.collapse_shape %sub [] : memref<1xi32, strided<[1], offset: 1>> into memref> - memref.copy %scalar, %buf : memref> to memref + %sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, contiguous<1, offset: 1>> + %scalar = memref.collapse_shape %sub [] : memref<1xi32, contiguous<1, offset: 1>> into memref> + memref.copy %scalar, %buf : memref> to memref // CHECK: llvm.intr.memcpy return } diff --git a/mlir/test/Dialect/Affine/dma.mlir 
b/mlir/test/Dialect/Affine/dma.mlir index 7a15206bff872..7171b8ece0a4b 100644 --- a/mlir/test/Dialect/Affine/dma.mlir +++ b/mlir/test/Dialect/Affine/dma.mlir @@ -5,7 +5,7 @@ // Test with loop IVs. func.func @test0(%arg0 : index, %arg1 : index) { %0 = memref.alloc() : memref<100x100xf32> - %1 = memref.alloc() : memref<100x100xf32, affine_map<(d0, d1) -> (d0, d1)>, 2> + %1 = memref.alloc() : memref<100x100xf32, contiguous<2>, 2> %2 = memref.alloc() : memref<1xi32> %c0 = arith.constant 0 : index %c64 = arith.constant 64 : index @@ -26,7 +26,7 @@ func.func @test0(%arg0 : index, %arg1 : index) { // Test with loop IVs and optional stride arguments. func.func @test1(%arg0 : index, %arg1 : index) { %0 = memref.alloc() : memref<100x100xf32> - %1 = memref.alloc() : memref<100x100xf32, affine_map<(d0, d1) -> (d0, d1)>, 2> + %1 = memref.alloc() : memref<100x100xf32, contiguous<2>, 2> %2 = memref.alloc() : memref<1xi32> %c0 = arith.constant 0 : index %c64 = arith.constant 64 : index @@ -49,7 +49,7 @@ func.func @test1(%arg0 : index, %arg1 : index) { // Test with loop IVs and symbols (without symbol keyword). func.func @test2(%arg0 : index, %arg1 : index) { %0 = memref.alloc() : memref<100x100xf32> - %1 = memref.alloc() : memref<100x100xf32, affine_map<(d0, d1) -> (d0, d1)>, 2> + %1 = memref.alloc() : memref<100x100xf32, contiguous<2>, 2> %2 = memref.alloc() : memref<1xi32> %c0 = arith.constant 0 : index %c64 = arith.constant 64 : index @@ -71,7 +71,7 @@ func.func @test2(%arg0 : index, %arg1 : index) { // Test with loop IVs and symbols (with symbol keyword). func.func @test3(%arg0 : index, %arg1 : index) { %0 = memref.alloc() : memref<100x100xf32> - %1 = memref.alloc() : memref<100x100xf32, affine_map<(d0, d1) -> (d0, d1)>, 2> + %1 = memref.alloc() : memref<100x100xf32, contiguous<2>, 2> %2 = memref.alloc() : memref<1xi32> %c0 = arith.constant 0 : index %c64 = arith.constant 64 : index diff --git a/mlir/test/Dialect/Affine/pipeline-data-transfer.mlir b/mlir/test/Dialect/Affine/pipeline-data-transfer.mlir index 9ea282b1dc858..95f0285be1651 100644 --- a/mlir/test/Dialect/Affine/pipeline-data-transfer.mlir +++ b/mlir/test/Dialect/Affine/pipeline-data-transfer.mlir @@ -8,8 +8,8 @@ // CHECK-LABEL: func @loop_nest_dma() { func.func @loop_nest_dma() { - %A = memref.alloc() : memref<256 x f32, affine_map<(d0) -> (d0)>, 0> - %Ah = memref.alloc() : memref<32 x f32, affine_map<(d0) -> (d0)>, 1> + %A = memref.alloc() : memref<256 x f32, contiguous<1>, 0> + %Ah = memref.alloc() : memref<32 x f32, contiguous<1>, 1> %tag = memref.alloc() : memref<1 x f32> @@ -19,15 +19,15 @@ func.func @loop_nest_dma() { affine.for %i = 0 to 8 { affine.dma_start %A[%i], %Ah[%i], %tag[%zero], %num_elts : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> affine.dma_wait %tag[%zero], %num_elts : memref<1 x f32> - %v = affine.load %Ah[%i] : memref<32 x f32, affine_map<(d0) -> (d0)>, 1> + %v = affine.load %Ah[%i] : memref<32 x f32, contiguous<1>, 1> %r = "compute"(%v) : (f32) -> (f32) - affine.store %r, %Ah[%i] : memref<32 x f32, affine_map<(d0) -> (d0)>, 1> + affine.store %r, %Ah[%i] : memref<32 x f32, contiguous<1>, 1> affine.for %j = 0 to 32 { "do_more_compute"(%i, %j) : (index, index) -> () } } memref.dealloc %tag : memref<1 x f32> - memref.dealloc %Ah : memref<32 x f32, affine_map<(d0) -> (d0)>, 1> + memref.dealloc %Ah : memref<32 x f32, contiguous<1>, 1> return } // CHECK: %{{.*}} = memref.alloc() : memref<256xf32> @@ -353,8 +353,8 @@ func.func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>, %Av: memref (d0)>, 0> - %Ah = 
memref.alloc() : memref<32 x f32, affine_map<(d0) -> (d0)>, 1> + %A = memref.alloc() : memref<256 x f32, contiguous<1>, 0> + %Ah = memref.alloc() : memref<32 x f32, contiguous<1>, 1> %tag = memref.alloc() : memref<1 x f32> %zero = arith.constant 0 : index %num_elts = arith.constant 32 : index @@ -364,11 +364,11 @@ func.func @escaping_and_indexed_use_mix() { affine.dma_start %A[%i], %Ah[%i], %tag[%zero], %num_elts : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> affine.dma_wait %tag[%zero], %num_elts : memref<1 x f32> "compute"(%Ah) : (memref<32 x f32, 1>) -> () - %v = affine.load %Ah[%i] : memref<32 x f32, affine_map<(d0) -> (d0)>, 1> + %v = affine.load %Ah[%i] : memref<32 x f32, contiguous<1>, 1> "foo"(%v) : (f32) -> () } - memref.dealloc %A : memref<256 x f32, affine_map<(d0) -> (d0)>, 0> - memref.dealloc %Ah : memref<32 x f32, affine_map<(d0) -> (d0)>, 1> + memref.dealloc %A : memref<256 x f32, contiguous<1>, 0> + memref.dealloc %Ah : memref<32 x f32, contiguous<1>, 1> return } // No replacement. diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir index 8249d59b2374e..42c0de7306c53 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -24,8 +24,8 @@ func.func @buffer_forwarding_conflict( %f = linalg.fill ins(%f0 : f32) outs(%a : tensor) -> tensor // CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref to memref - // CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref to memref> - // CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref to memref> + // CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref to memref + // CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref to memref %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor into tensor // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir index c26f1681e4d96..fa8aab25bd584 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir @@ -90,8 +90,8 @@ func.func @alloc_tesor_copy_from_non_default_space_no_cast(%arg0: tensor<128xf32 // CHECK: %[[v3:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 2> to tensor<128xf32, 1 : i64> // CHECK: %[[alloc_0:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 1> // CHECK: memref.copy %[[v1]], %[[alloc_0]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 1> -// CHECK: %[[subview:.+]] = memref.subview %[[alloc_0]][0] [4] [1] : memref<128xf32, 1> to memref<4xf32, strided<[1]>, 1> -// CHECK: memref.copy %[[v0]], %[[subview]] : memref<4xf32, strided<[?], offset: ?>, 1> to memref<4xf32, strided<[1]>, 1> +// CHECK: %[[subview:.+]] = memref.subview %[[alloc_0]][0] [4] [1] : memref<128xf32, 1> to memref<4xf32, 1> +// CHECK: memref.copy %[[v0]], %[[subview]] : memref<4xf32, strided<[?], offset: ?>, 1> to memref<4xf32, 1> // CHECK: return %[[v3]] : tensor<128xf32, 1 : i64> // ----- diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir 
b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index e7797d4bc50a9..09df8e665e423 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -543,7 +543,6 @@ func.func @entry(%A : tensor {bufferization.buffer_layout = affine_map<(i // %A, %B and %C are not inplaceable. This test case shows that this kind of // conflict detection has a "transitive" nature. // CHECK-DAG: %[[ALLOC_A:.*]] = memref.alloc -// CHECK-DAG: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]] // CHECK-DAG: %[[ALLOC_B:.*]] = memref.alloc // CHECK-DAG: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]] // CHECK-DAG: %[[ALLOC_C:.*]] = memref.alloc @@ -551,7 +550,7 @@ func.func @entry(%A : tensor {bufferization.buffer_layout = affine_map<(i // CHECK-DAG: memref.copy %[[A]], %[[ALLOC_A]] // CHECK-DAG: memref.copy %[[B]], %[[ALLOC_B]] // CHECK-DAG: memref.copy %[[C]], %[[ALLOC_C]] -// CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]]) +// CHECK-NEXT: call @callee(%[[ALLOC_A]], %[[CASTED_B]], %[[CASTED_C]]) call @callee(%A, %B, %C) : (tensor, tensor, tensor) -> () return } @@ -789,8 +788,8 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> { // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?> return %0 : tensor<5xf32> ^bb2: - // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>> - // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>> + // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, contiguous<1, offset: 2>> + // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, contiguous<1, offset: 2>> to memref<5xf32, strided<[?], offset: ?>> %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32> // CHECK: return %[[cast1]] : memref<5xf32, strided<[?], offset: ?>> return %1 : tensor<5xf32> diff --git a/mlir/test/Dialect/GPU/decompose-memrefs.mlir b/mlir/test/Dialect/GPU/decompose-memrefs.mlir index 1a19221948451..99450cb26de5f 100644 --- a/mlir/test/Dialect/GPU/decompose-memrefs.mlir +++ b/mlir/test/Dialect/GPU/decompose-memrefs.mlir @@ -1,14 +1,13 @@ // RUN: mlir-opt -gpu-decompose-memrefs -allow-unregistered-dialect -split-input-file %s | FileCheck %s -// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)> // CHECK: @decompose_store // CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref) // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]] // CHECK: gpu.launch // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in -// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]] -// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref to memref> -// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref> +// CHECK: %[[IDX:.*]] = affine.linearize_index disjoint [%[[TX]], %[[TY]], %[[TZ]]] by (%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2) +// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref to memref> +// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref> func.func @decompose_store(%arg0 : f32, %arg1 : memref) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -26,6 +25,31 @@ func.func 
@decompose_store(%arg0 : f32, %arg1 : memref) { // ----- +// CHECK: @decompose_store_column_major +// CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref>) +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]] +// CHECK: gpu.launch +// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in +// CHECK: %[[IDX:.*]] = affine.linearize_index disjoint [%[[TZ]], %[[TY]], %[[TX]]] by (%[[SIZES]]#2, %[[SIZES]]#1, %[[SIZES]]#0) +// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref to memref> +// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref> +func.func @decompose_store_column_major(%arg0 : f32, %arg1 : memref>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %block_dim0 = memref.dim %arg1, %c0 : memref> + %block_dim1 = memref.dim %arg1, %c1 : memref> + %block_dim2 = memref.dim %arg1, %c2 : memref> + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) { + memref.store %arg0, %arg1[%tx, %ty, %tz] : memref> + gpu.terminator + } + return +} + +// ----- + // CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)> // CHECK: @decompose_store_strided // CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref>) @@ -33,8 +57,8 @@ func.func @decompose_store(%arg0 : f32, %arg1 : memref) { // CHECK: gpu.launch // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2] -// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref to memref> -// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref> +// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref to memref> +// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref> func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -52,26 +76,27 @@ func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref (s0 * s1 + s2 * s3 + s4)> +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> // CHECK: @decompose_load -// CHECK-SAME: (%[[MEM:.*]]: memref) +// CHECK-SAME: (%[[MEM:.*]]: memref>) // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]] // CHECK: gpu.launch // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in -// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]] -// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref to memref> -// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref> +// CHECK: %[[IDX:.*]] = affine.linearize_index disjoint [%[[TX]], %[[TY]], %[[TZ]]] by (%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2) +// CHECK: %[[NEW_OFF:.*]] = affine.apply #[[MAP]]()[%[[IDX]], %[[OFFSET]]] +// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[NEW_OFF]]], sizes: [], strides: [] : memref to memref> +// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref> // CHECK: "test.test"(%[[RES]]) : (f32) -> () -func.func @decompose_load(%arg0 : memref) { +func.func @decompose_load(%arg0 : memref>) { %c0 = 
arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %block_dim0 = memref.dim %arg0, %c0 : memref - %block_dim1 = memref.dim %arg0, %c1 : memref - %block_dim2 = memref.dim %arg0, %c2 : memref + %block_dim0 = memref.dim %arg0, %c0 : memref> + %block_dim1 = memref.dim %arg0, %c1 : memref> + %block_dim2 = memref.dim %arg0, %c2 : memref> gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) { - %res = memref.load %arg0[%tx, %ty, %tz] : memref + %res = memref.load %arg0[%tx, %ty, %tz] : memref> "test.test"(%res) : (f32) -> () gpu.terminator } diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index 3256daa8e0b59..824c1805eefe5 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -972,8 +972,8 @@ func.func @drop_all_loops(%arg0 : memref<1x1xf32, 3>) -> memref<1x1xf32, 3> // CHECK: linalg.generic{{.*}}memref // CHECK-SLICES-LABEL: func @drop_all_loops -// CHECK-SLICES: memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref, 3> -// CHECK-SLICES: linalg.generic{{.*}}memref, 3> +// CHECK-SLICES: memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref +// CHECK-SLICES: linalg.generic{{.*}}memref // ----- diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 176e55e3e6c4a..50bb846cb121b 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -1,10 +1,10 @@ // RUN: mlir-opt %s -transform-interpreter -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s -func.func @dot(%x: memref>, - %y: memref>, +func.func @dot(%x: memref>, + %y: memref>, %v: memref) { - linalg.dot ins(%x, %y: memref>, - memref>) + linalg.dot ins(%x, %y: memref>, + memref>) outs(%v: memref) return } @@ -26,12 +26,12 @@ module attributes {transform.with_named_sequence} { // ----- func.func @matvec(%A: memref>, - %x: memref>, - %y: memref>) { + %x: memref>, + %y: memref>) { linalg.matvec ins(%A, %x: memref>, - memref>) - outs(%y: memref>) + memref>) + outs(%y: memref>) return } @@ -50,8 +50,8 @@ module attributes {transform.with_named_sequence} { // CHECK: scf.for {{.*}} step %[[c5]] // CHECK: scf.for {{.*}} step %[[c6]] // CHECK: linalg.matvec -// CHECK: ins({{.*}}: memref>, memref>) -// CHECK: outs({{.*}}: memref>) +// CHECK: ins({{.*}}: memref>, memref>) +// CHECK: outs({{.*}}: memref>) // ----- @@ -157,11 +157,11 @@ module attributes {transform.with_named_sequence} { // ----- func.func @matvec_perm(%A: memref>, - %x: memref>, - %y: memref>) { + %x: memref>, + %y: memref>) { linalg.matvec ins(%A, %x: memref>, - memref>) - outs(%y: memref>) + memref>) + outs(%y: memref>) return } @@ -180,8 +180,8 @@ module attributes {transform.with_named_sequence} { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]] // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]] // CHECK: linalg.matvec -// CHECK: ins({{.*}}: memref>, memref>) -// CHECK: outs({{.*}}: memref>) +// CHECK: ins({{.*}}: memref>, memref>) +// CHECK: outs({{.*}}: memref>) // ----- diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 02110bc2892d0..e55b72becaa30 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ 
b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -133,7 +133,7 @@ func.func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>, // CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1] // CHECK-SAME: : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>> // CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1] -// CHECK-SAME: : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref> +// CHECK-SAME: : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref> // ----- @@ -149,7 +149,7 @@ func.func @multiple_reducing_dims_dynamic(%arg0 : memref, // CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1] // CHECK-SAME: : memref to memref<1x?xf32, strided<[?, 1], offset: ?>> // CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1] -// CHECK-SAME: : memref<1x?xf32, strided<[?, 1], offset: ?>> to memref> +// CHECK-SAME: : memref<1x?xf32, strided<[?, 1], offset: ?>> to memref> // ----- @@ -466,6 +466,23 @@ func.func @compose_collapse_of_collapse(%arg0 : memref) // ----- +func.func @compose_collapse_of_expand_offset_layout( + %arg0: memref>, %sz0: index, %sz1: index) + -> memref> { + %1 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1] : + memref> into + memref> + %2 = memref.collapse_shape %1 [[0, 1, 2]] : + memref> into + memref> + return %2 : memref> +} +// CHECK-LABEL: func @compose_collapse_of_expand_offset_layout +// CHECK: memref.collapse_shape +// CHECK-SAME: memref> into memref> + +// ----- + func.func @do_not_compose_collapse_of_expand_non_identity_layout( %arg0: memref>, %sz0: index, %sz1: index) -> memref> { @@ -645,6 +662,16 @@ func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> me // ----- +func.func @no_fold_subview_with_non_zero_offset_contiguous(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, contiguous<2, offset: 1>> { + %0 = memref.subview %arg0[0, 1] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, contiguous<2, offset: 1>> + return %0 : memref<20x42xf32, contiguous<2, offset: 1>> +} +// CHECK-LABEL: func @no_fold_subview_with_non_zero_offset_contiguous( +// CHECK: %[[SUBVIEW:.+]] = memref.subview +// CHECK: return %[[SUBVIEW]] + +// ----- + func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 2]>> { %0 = memref.subview %arg0[0, 0] [20, 42] [1, 2] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 2]>> return %0 : memref<20x42xf32, strided<[42, 2]>> @@ -981,7 +1008,7 @@ func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>, // CHECK-SAME: %[[ARG0:.+]]: memref<8x?xf32> // CHECK-SAME: %[[ARG1:.+]]: index // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0] [1, %[[ARG1]]] [1, 1] -// CHECK-SAME: memref<8x?xf32> to memref> +// CHECK-SAME: memref<8x?xf32> to memref // ----- @@ -1034,12 +1061,12 @@ func.func @fold_trivial_subviews(%m: memref>, // ----- // CHECK-LABEL: func @load_store_nontemporal( -func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) { +func.func @load_store_nontemporal(%input : memref<32xf32, contiguous<1>>, %output : memref<32xf32, contiguous<1>>) { %1 = arith.constant 7 : index // CHECK: memref.load %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32> - %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + %2 = memref.load %input[%1] {nontemporal = true} : 
memref<32xf32, contiguous<1>> // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32> - memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, contiguous<1>> func.return } @@ -1096,24 +1123,24 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index) // CHECK-LABEL: func @fold_double_transpose( // CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32> -func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> { +func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, contiguous<[4, 2, 1, 3, 0]>> { // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0) - %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> - %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> + %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, contiguous<[1, 0, 4, 3, 2]>> + %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, contiguous<[1, 0, 4, 3, 2]>> to memref<5x3x2x4x1xf32, contiguous<[4, 2, 1, 3, 0]>> // CHECK: return %[[ONETRANSPOSE]] - return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> + return %1 : memref<5x3x2x4x1xf32, contiguous<[4, 2, 1, 3, 0]>> } // ----- // CHECK-LABEL: func @fold_double_transpose2( // CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32> -func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> { +func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, contiguous<[4, 2, 1, 3, 0]>> { // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0) - %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> - %1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> + %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, contiguous<[0, 1, 4, 3, 2]>> + %1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, contiguous<[0, 1, 4, 3, 2]>> to memref<5x3x2x4x1xf32, contiguous<[4, 2, 1, 3, 0]>> // CHECK: return %[[ONETRANSPOSE]] - return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> + return %1 : memref<5x3x2x4x1xf32, contiguous<[4, 2, 1, 3, 0]>> } // ----- @@ -1121,15 +1148,15 @@ func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x // CHECK-LABEL: func @fold_identity_transpose( // CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32> func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> { - %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> - %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 
20]>> to memref<1x2x3x4x5xf32> + %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, contiguous<[1, 0, 4, 3, 2]>> + %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, contiguous<[1, 0, 4, 3, 2]>> to memref<1x2x3x4x5xf32> // CHECK: return %[[arg0]] return %1 : memref<1x2x3x4x5xf32> } // ----- -#transpose_map = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)> +#transpose_map = affine_map<(d0, d1)-> (d1, d0)> // CHECK-LABEL: func @cannot_fold_transpose_cast( // CHECK-SAME: %[[arg0:.*]]: memref diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir index 1d6cbfa343ba5..5462d229100fd 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -163,19 +163,19 @@ func.func @rank_zero_memref() -> i4 { func.func @memref_strided_i4(%idx : index) -> i4 { %arr = memref.alloc() : memref<128xi4> - %subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>> - %1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>> + %subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, contiguous<1, offset:32>> + %1 = memref.load %subview[%idx] : memref<32xi4, contiguous<1, offset:32>> return %1 : i4 } // CHECK-LABEL: func @memref_strided_i4 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8> -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, contiguous<1, offset: 16>> // CHECK: %[[LOAD:.+]] = memref.load %[[SUBVIEW]] // CHECK32-LABEL: func @memref_strided_i4 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32> -// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>> +// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, contiguous<1, offset: 4>> // CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]] // ----- @@ -192,13 +192,13 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 { // CHECK-LABEL: func.func @memref_subview_dynamic_offset_i4( // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8> // CHECK: %[[IDX:.*]] = affine.apply -// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>> +// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, contiguous<1, offset: ?>> // CHECK: memref.load %[[SUBVIEW]] // CHECK32-LABEL: func.func @memref_subview_dynamic_offset_i4( // CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32> // CHECK32: %[[IDX:.*]] = affine.apply -// CHECK32: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>> +// CHECK32: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, contiguous<1, offset: ?>> // CHECK32: memref.load %[[SUBVIEW]] // ----- @@ -238,8 +238,8 @@ func.func @reinterpret_cast_memref_load_0D() -> i4 { func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 { %0 = memref.alloc() : memref<5x5xi4> - %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [8], sizes: [25], 
strides: [1] : memref<5x5xi4> to memref<25xi4, strided<[1], offset:8>> - %1 = memref.load %reinterpret_cast_0[%arg0] : memref<25xi4, strided<[1], offset:8>> + %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [8], sizes: [25], strides: [1] : memref<5x5xi4> to memref<25xi4, contiguous<1, offset:8>> + %1 = memref.load %reinterpret_cast_0[%arg0] : memref<25xi4, contiguous<1, offset:8>> return %1 : i4 } // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 2)> @@ -247,9 +247,9 @@ func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 { // CHECK: func @reinterpret_cast_memref_load_1D( // CHECK-SAME: %[[ARG0:.+]]: index // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<13xi8> -// CHECK: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [4], sizes: [13], strides: [1] : memref<13xi8> to memref<13xi8, strided<[1], offset: 4>> +// CHECK: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [4], sizes: [13], strides: [1] : memref<13xi8> to memref<13xi8, contiguous<1, offset: 4>> // CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK: %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<13xi8, strided<[1], offset: 4>> +// CHECK: %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<13xi8, contiguous<1, offset: 4>> // CHECK: %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]] // CHECK: %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i8 // CHECK: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i8 @@ -261,9 +261,9 @@ func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 { // CHECK32: func @reinterpret_cast_memref_load_1D( // CHECK32-SAME: %[[ARG0:.+]]: index // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32> -// CHECK32: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [1], sizes: [4], strides: [1] : memref<4xi32> to memref<4xi32, strided<[1], offset: 1>> +// CHECK32: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [1], sizes: [4], strides: [1] : memref<4xi32> to memref<4xi32, contiguous<1, offset: 1>> // CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK32: %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<4xi32, strided<[1], offset: 1>> +// CHECK32: %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<4xi32, contiguous<1, offset: 1>> // CHECK32: %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]] // CHECK32: %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i32 // CHECK32: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32 diff --git a/mlir/test/Dialect/MemRef/make-loop-independent.mlir b/mlir/test/Dialect/MemRef/make-loop-independent.mlir index dca7bc1e67586..1aa12aa3bd055 100644 --- a/mlir/test/Dialect/MemRef/make-loop-independent.mlir +++ b/mlir/test/Dialect/MemRef/make-loop-independent.mlir @@ -12,22 +12,21 @@ func.func @make_alloca_loop_independent(%lb: index, %ub: index, %step: index) { scf.for %i = %lb to %ub step %step { // CHECK: %[[sz:.*]] = affine.apply #[[$map]]()[%[[ub]]] // CHECK: %[[alloca:.*]] = memref.alloca(%[[sz]]) - // CHECK: %[[subview:.*]] = memref.subview %[[alloca]][0] [%[[iv]]] [1] : memref to memref> - // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[subview]] : memref> to memref + // CHECK: %[[subview:.*]] = memref.subview %[[alloca]][0] [%[[iv]]] [1] : memref to memref %alloc = memref.alloca(%i) : memref // memref.subview has special handling. 
- // CHECK: %[[subview2:.*]] = memref.subview %[[subview]][1] [5] [1] : memref> to memref<5xf32, strided<[1], offset: 1>> - %view = memref.subview %alloc[1][5][1] : memref to memref<5xf32, strided<[1], offset: 1>> + // CHECK: %[[subview2:.*]] = memref.subview %[[subview]][1] [5] [1] : memref to memref<5xf32, contiguous<1, offset: 1>> + %view = memref.subview %alloc[1][5][1] : memref to memref<5xf32, contiguous<1, offset: 1>> // This op takes a memref but does not produce one. The new alloc is used // directly. // CHECK: "test.some_use"(%[[subview2]]) - "test.some_use"(%view) : (memref<5xf32, strided<[1], offset: 1>>) -> () + "test.some_use"(%view) : (memref<5xf32, contiguous<1, offset: 1>>) -> () // This op produces a memref, so the new alloc cannot be used directly. // It is wrapped in a unrealized_conversion_cast. - // CHECK: "test.another_use"(%[[cast]]) : (memref) -> memref + // CHECK: "test.another_use"(%[[subview]]) : (memref) -> memref "test.another_use"(%alloc) : (memref) -> (memref) // CHECK: memref.store %{{.*}}, %[[subview]] @@ -57,7 +56,7 @@ func.func @make_alloca_loop_independent_static(%step: index) { %sz = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%i)[%ub] // CHECK: %[[alloca:.*]] = memref.alloca() : memref<128xf32> - // CHECK: %[[subview:.*]] = memref.subview %[[alloca]][0] [%[[sz]]] [1] : memref<128xf32> to memref> + // CHECK: %[[subview:.*]] = memref.subview %[[alloca]][0] [%[[sz]]] [1] : memref<128xf32> to memref %alloc = memref.alloca(%sz) : memref // CHECK: memref.store %{{.*}}, %[[subview]] diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 7038a6ff744e4..6d3dfc586574a 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -1,21 +1,21 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 + s0, d1)> +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 + s0, d1)> // CHECK-LABEL: func @alloc() { func.func @alloc() { ^bb0: // Test simple alloc. // CHECK: %{{.*}} = memref.alloc() : memref<1024x64xf32, 1> - %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %0 = memref.alloc() : memref<1024x64xf32, contiguous<2>, 1> %c0 = "arith.constant"() {value = 0: index} : () -> index %c1 = "arith.constant"() {value = 1: index} : () -> index // Test alloc with dynamic dimensions. // CHECK: %{{.*}} = memref.alloc(%{{.*}}, %{{.*}}) : memref - %1 = memref.alloc(%c0, %c1) : memref (d0, d1)>, 1> + %1 = memref.alloc(%c0, %c1) : memref, 1> // Test alloc with no dynamic dimensions and one symbol. // CHECK: %{{.*}} = memref.alloc()[%{{.*}}] : memref<2x4xf32, #[[$MAP]], 1> @@ -30,6 +30,14 @@ func.func @alloc() { // CHECK: %{{.*}} = memref.alloc() : memref<2xi32> %4 = memref.alloc() : memref<2 x i32> + // Alloc with affine map + // CHECK: %{{.*}} = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %5 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + + // Test alloc with no dynamic dimensions and one offset. + // CHECK: %{{.*}} = memref.alloc()[%{{.*}}] : memref<2x4xf32, contiguous<2, offset: ?>, 1> + %6 = memref.alloc()[%c0] : memref<2x4xf32, contiguous<2, offset: ?>, 1> + // CHECK: return return } @@ -39,14 +47,14 @@ func.func @alloca() { ^bb0: // Test simple alloc. 
// CHECK: %{{.*}} = memref.alloca() : memref<1024x64xf32, 1> - %0 = memref.alloca() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %0 = memref.alloca() : memref<1024x64xf32, contiguous<2>, 1> %c0 = "arith.constant"() {value = 0: index} : () -> index %c1 = "arith.constant"() {value = 1: index} : () -> index // Test alloca with dynamic dimensions. // CHECK: %{{.*}} = memref.alloca(%{{.*}}, %{{.*}}) : memref - %1 = memref.alloca(%c0, %c1) : memref (d0, d1)>, 1> + %1 = memref.alloca(%c0, %c1) : memref, 1> // Test alloca with no dynamic dimensions and one symbol. // CHECK: %{{.*}} = memref.alloca()[%{{.*}}] : memref<2x4xf32, #[[$MAP]], 1> @@ -60,6 +68,14 @@ func.func @alloca() { // CHECK: %{{.*}} = memref.alloca() {alignment = 64 : i64} : memref<2xi32> %4 = memref.alloca() {alignment = 64} : memref<2 x i32> + // Test alloca with affine map. + // CHECK: %{{.*}} = memref.alloca() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %5 = memref.alloca() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + + // Test alloca with no dynamic dimensions and one offset. + // CHECK: %{{.*}} = memref.alloca()[%{{.*}}] : memref<2x4xf32, contiguous<2, offset: ?>, 1> + %6 = memref.alloca()[%c0] : memref<2x4xf32, contiguous<2, offset: ?>, 1> + return } @@ -67,26 +83,44 @@ func.func @alloca() { func.func @dealloc() { ^bb0: // CHECK: %{{.*}} = memref.alloc() : memref<1024x64xf32> - %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 0> + %0 = memref.alloc() : memref<1024x64xf32, contiguous<2>, 0> // CHECK: memref.dealloc %{{.*}} : memref<1024x64xf32> - memref.dealloc %0 : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 0> + memref.dealloc %0 : memref<1024x64xf32, contiguous<2>, 0> return } -// CHECK-LABEL: func @load_store +// CHECK-LABEL: @load_store func.func @load_store() { ^bb0: // CHECK: %{{.*}} = memref.alloc() : memref<1024x64xf32, 1> - %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %0 = memref.alloc() : memref<1024x64xf32, contiguous<2>, 1> %1 = arith.constant 0 : index %2 = arith.constant 1 : index // CHECK: %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x64xf32, 1> - %3 = memref.load %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %3 = memref.load %0[%1, %2] : memref<1024x64xf32, contiguous<2>, 1> // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x64xf32, 1> + memref.store %3, %0[%1, %2] : memref<1024x64xf32, contiguous<2>, 1> + + return +} + +// CHECK-LABEL: func @load_store_affine +func.func @load_store_affine() { +^bb0: + // CHECK: %{{.*}} = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + + %1 = arith.constant 0 : index + %2 = arith.constant 1 : index + + // CHECK: %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %3 = memref.load %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + + // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> memref.store %3, %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> return @@ -98,8 +132,8 @@ func.func @dma_ops() { %stride = arith.constant 32 : index %elt_per_stride = arith.constant 16 : index - %A = memref.alloc() : memref<256 x f32, affine_map<(d0) -> (d0)>, 0> - %Ah = memref.alloc() : memref<256 x f32, affine_map<(d0) -> (d0)>, 1> + %A = 
memref.alloc() : memref<256 x f32, contiguous<1>, 0> + %Ah = memref.alloc() : memref<256 x f32, contiguous<1>, 1> %tag = memref.alloc() : memref<1 x f32> %num_elements = arith.constant 256 : index @@ -226,6 +260,13 @@ func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref, %arg2 : memr // CHECK: memref.cast %{{.*}} : memref<*xf32> to memref<4xf32> %5 = memref.cast %4 : memref<*xf32> to memref<4xf32> + + // CHECK: memref.cast %{{.*}} : memref<4xf32> to memref> + %6 = memref.cast %arg0 : memref<4xf32> to memref> + + // CHECK: memref.cast {{%.*}} : memref> to memref<4xf32> + %7 = memref.cast %6 : memref> to memref<4xf32> + return } @@ -369,7 +410,7 @@ func.func @expand_collapse_shape_static( // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]] %r8 = memref.collapse_shape %arg8 [[0, 1, 2]] : memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>> into - memref<1024xi8, strided<[1], offset: 0>> + memref<1024xi8> // CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]] %r9 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] : @@ -453,15 +494,15 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, memref> // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]] -// CHECK-SAME: memref> into memref> +// CHECK-SAME: memref> into memref %3 = memref.collapse_shape %arg3 [[0, 1]] : memref> into - memref> + memref // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] output_shape [%arg6, 42] -// CHECK-SAME: memref> into memref +// CHECK-SAME: memref into memref %r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] : - memref> into memref + memref into memref // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2] @@ -495,6 +536,8 @@ func.func @collapse_shape_to_dynamic func.func @expand_collapse_shape_transposed_layout( %m0: memref>, %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>, + %m2: memref>, + %m3: memref<4x5x6xf32, contiguous<[2, 1, 0]>>, %sz0: index, %sz1: index) { @@ -511,6 +554,17 @@ func.func @expand_collapse_shape_transposed_layout( %rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] : memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>> into memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> + + %r2 = memref.expand_shape %m2 [[0], [1, 2]] output_shape [%sz0, %sz1, 5] : + memref> into + memref> + %rr2 = memref.collapse_shape %r2 [[0], [1, 2]] : + memref> into + memref> + + %r3 = memref.expand_shape %m3 [[0, 1], [2], [3, 4]] output_shape [2, 2, 5, 2, 3] : + memref<4x5x6xf32, contiguous<[2, 1, 0]>> into + memref<2x2x5x2x3xf32, contiguous<[3, 4, 2, 0, 1]>> return } @@ -606,7 +660,7 @@ func.func @memref_memory_space_cast(%src : memref) -> memref { } // CHECK-LABEL: func @memref_transpose_map -func.func @memref_transpose_map(%src : memref) -> memref (d1 * s0 + d0)>> { - %dst = memref.transpose %src (i, j) -> (j, i) : memref to memref (d1 * s0 + d0)>> - return %dst : memref (d1 * s0 + d0)>> +func.func @memref_transpose_map(%src : memref) -> memref> { + %dst = memref.transpose %src (i, j) -> (j, i) : memref to memref> + return %dst : memref> } diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index acab37e482cfe..b217ebfcd64fd 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -51,9 +51,9 @@ func.func @multi_buffer(%in: memref<16xf32>) { // CHECK: scf.for %[[IV:.*]] = %[[C0]] scf.for %i0 = %c0 to %c16 step %c4 { // CHECK: %[[I:.*]] = affine.apply 
#[[$MAP0]](%[[IV]]) - // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, contiguous<1, offset: ?>> %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> - // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, contiguous<1, offset: ?>> memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32> "some_use"(%tmp) : (memref<4xf32>) ->() @@ -88,9 +88,9 @@ func.func @multi_buffer_on_affine_loop(%in: memref<16xf32>) { // CHECK: affine.for %[[IV:.*]] = 0 affine.for %i0 = 0 to 16 step 4 { // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) - // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, contiguous<1, offset: ?>> %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> - // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, contiguous<1, offset: ?>> memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32> "some_use"(%tmp) : (memref<4xf32>) ->() @@ -209,9 +209,9 @@ func.func @multi_buffer_one_alloc_with_use_outside_of_loop(%in: memref<16xf32>) // CHECK: scf.for %[[IV:.*]] = %[[C0]] scf.for %i0 = %c0 to %c16 step %c4 { // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) - // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, contiguous<1, offset: ?>> %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> - // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, contiguous<1, offset: ?>> memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32> "some_use"(%tmp) : (memref<4xf32>) ->() @@ -249,7 +249,7 @@ func.func @multi_buffer_no_analysis(%in: memref<16xf32>) { // CHECK: scf.for %[[IV:.*]] = %[[C0]] scf.for %i0 = %c0 to %c16 step %c4 { // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) - // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, contiguous<1, offset: ?>> "some_write_read"(%tmp) : (memref<4xf32>) ->() } return @@ -284,7 +284,7 @@ func.func @multi_buffer_dealloc(%in: memref<16xf32>) { // CHECK: scf.for %[[IV:.*]] = %[[C0]] scf.for %i0 = %c0 to %c16 step %c4 { // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) - // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] 
[1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, contiguous<1, offset: ?>> "some_write_read"(%tmp) : (memref<4xf32>) ->() } diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index c1beed95f2006..ff3dc6f15d910 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -324,8 +324,8 @@ func.func @tensor.insert_slice_rank_reducing_1( -> tensor { // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref - // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref to memref> - // CHECK: memref.copy {{.*}} : memref to memref> + // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref to memref> + // CHECK: memref.copy {{.*}} : memref to memref> %0 = tensor.insert_slice %f into %t1[%idx1, %idx2][1, 1][1, 1] : tensor into tensor return %0 : tensor @@ -402,9 +402,9 @@ func.func @tensor.expand_shape_of_slice( func.func @tensor.expand_shape_of_scalar_slice( %t1: tensor, %o1: index, %s1: index) -> tensor<1xf32> { // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : tensor to memref - // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref to memref> + // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref to memref> %0 = tensor.extract_slice %t1[%o1][1][1] : tensor to tensor - // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref into memref<1xf32, strided<[1], offset: ?>> + // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref> into memref<1xf32, contiguous<1, offset: ?>> %1 = tensor.expand_shape %0 [] output_shape [1] : tensor into tensor<1xf32> // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]] // CHECK: return %[[r]] @@ -459,9 +459,9 @@ func.func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor) -> tensor { - // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>> + // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, contiguous<1, offset: 1>> %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32> - // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, strided<[1], offset: 1>> into memref> + // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, contiguous<1, offset: 1>> into memref> %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor return %1 : tensor } @@ -643,8 +643,8 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten %result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<4xf32> { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<4xf32> to tensor<1xf32> scf.forall.in_parallel { - // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>> - // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>> + // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, contiguous<1, offset: ?>> + // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, contiguous<1, offset: ?>> tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : tensor<1xf32> into tensor<4xf32> } diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index 2983cd30258a5..b5b6df8822893 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ 
b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -116,8 +116,8 @@ func.func @insert_slice_fun_not_inplace( { // CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) {alignment = 64 : i64} : memref // CHECK: memref.copy %[[A]], %[[ALLOC]] : memref - // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref to memref<4xf32, strided<[1]>> - // CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, strided{{.*}}> to memref<4xf32, strided<[1]>> + // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref to memref<4xf32> + // CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, strided{{.*}}> to memref<4xf32> %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor // CHECK: return %{{.*}} : memref @@ -257,7 +257,7 @@ func.func @pad_memory_space(%t: tensor, %h1: index, %f: f32, %pos: index) // CHECK: outs(%[[padded_alloc]] : memref<15xf32, 3>) // CHECK: linalg.yield %{{.*}} // CHECK: } - // CHECK: %[[subview:.*]] = memref.subview {{.*}} : memref<15xf32, 3> to memref, 3> + // CHECK: %[[subview:.*]] = memref.subview {{.*}} : memref<15xf32, 3> to memref, 3> // CHECK: memref.copy %[[alloc_tensor]], %[[subview]] %1 = tensor.pad %0 low[2] high[%h1] { ^bb0(%arg0: index): @@ -332,9 +332,9 @@ func.func @dim_not_reading(%t: tensor, %f: f32, %pos: index) // CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)> // CHECK-LABEL: func.func @cast_retains_buffer_layout( -// CHECK-SAME: %[[t:.*]]: memref, %[[sz:.*]]: index) -> memref> { +// CHECK-SAME: %[[t:.*]]: memref, %[[sz:.*]]: index) -> memref> { // CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref to memref<10xf32, #[[$map]]> -// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref> +// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref> // CHECK: return %[[slice]] func.func @cast_retains_buffer_layout( %t: tensor diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir index f78b4b6f6798c..d5384ed92fe8d 100644 --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -260,9 +260,9 @@ module { // CHECK-NOT: memref.copy func.func @canonicalization_and_cse(%m: memref<5xf32>) { %c2 = arith.constant 2 : index - %s0 = memref.subview %m[1] [2] [1] : memref<5xf32> to memref<2xf32, strided<[1], offset: 1>> - %s1 = memref.subview %m[1] [%c2] [1] : memref<5xf32> to memref> - memref.copy %s0, %s1 : memref<2xf32, strided<[1], offset: 1>> to memref> + %s0 = memref.subview %m[1] [2] [1] : memref<5xf32> to memref<2xf32, contiguous<1, offset: 1>> + %s1 = memref.subview %m[1] [%c2] [1] : memref<5xf32> to memref> + memref.copy %s0, %s1 : memref<2xf32, contiguous<1, offset: 1>> to memref> return } diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index cd56c1bf9695b..8d9a706a5036d 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -127,8 +127,8 @@ func.func @contiguous_inner_most_zero_idx_in_bounds(%src: memref<16x1xf32>, %i:i // CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> { // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to 
memref<16xf32, strided<[1]>> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32> +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32>, vector<8xf32> // CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32> // The index to be dropped is == 0, so it's safe to collapse. The "out of @@ -144,8 +144,8 @@ func.func @contiguous_inner_most_zero_idx_out_of_bounds(%src: memref<16x1xf32>, // CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> { // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32> +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32>, vector<8xf32> // CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32> // The index to be dropped is unknown, but since it's "in bounds", it has to be @@ -159,8 +159,8 @@ func.func @contiguous_inner_most_non_zero_idx_in_bounds(%src: memref<16x1xf32>, // CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> { // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32> +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32>, vector<8xf32> // CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32> // Same as the top example within this split, but with the outer vector @@ -176,8 +176,8 @@ func.func @contiguous_inner_most_non_zero_idx_in_bounds_scalable(%src: memref<16 // CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<[8]x1xf32> { // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<[8]xf32> +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32>, vector<[8]xf32> // CHECK: vector.shape_cast %[[READ]] : vector<[8]xf32> to vector<[8]x1xf32> // The index to be dropped is unknown and "out of bounds" - not safe to @@ -435,9 +435,9 @@ func.func @contiguous_inner_most_zero_idx_in_bounds(%dest: memref<16x1xf32>, %v: // CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>, // CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>, // CHECK-SAME: %[[IDX:.*]]: index) { -// CHECK: 
%[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32> -// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>> +// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32> // The index to be dropped is == 0, so it's safe to collapse. The "out of // bounds" attribute is too conservative and will be folded to "in bounds" @@ -451,9 +451,9 @@ func.func @contiguous_inner_most_zero_idx_out_of_bounds(%dest: memref<16x1xf32>, // CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>, // CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>, // CHECK-SAME: %[[IDX:.*]]: index) { -// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32> -// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>> +// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32> // The index to be dropped is unknown, but since it's "in bounds", it has to be // == 0. It's safe to collapse the corresponding dim. @@ -465,9 +465,9 @@ func.func @contiguous_inner_most_dim_non_zero_idx_in_bounds(%dest: memref<16x1xf // CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>, // CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>, // CHECK-SAME: %[[IDX:.*]]: index) { -// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32> -// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>> +// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32> // Same as the top example within this split, but with the outer vector // dim scalable. Note that this example only makes sense when "8 = [8]" (i.e. @@ -481,9 +481,9 @@ func.func @contiguous_inner_most_non_zero_idx_in_bounds_scalable(%dest: memref<1 // CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>, // CHECK-SAME: %[[VEC:.*]]: vector<[8]x1xf32> // CHECK-SAME: %[[IDX:.*]]: index) { -// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> +// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<[8]x1xf32> to vector<[8]xf32> -// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<[8]xf32>, memref<16xf32, strided<[1]>> +// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<[8]xf32>, memref<16xf32> // The index to be dropped is unknown and "out of bounds" - not safe to // collapse. 
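// A minimal illustrative sketch (not from the patch; the function name is a
// placeholder) of the rank-1 layouts these CHECK lines rely on under the new
// canonical form: a unit-stride subview at offset 0 carries no explicit
// layout, while a unit-stride subview at a dynamic offset is printed as
// contiguous<1, offset: ?> rather than strided<[1], offset: ?>.
func.func @contiguous_subview_sketch(%mem: memref<16xf32>, %idx: index)
    -> memref<8xf32, contiguous<1, offset: ?>> {
  // Offset 0, unit stride: plain memref<8xf32>, no layout attribute needed.
  %a = memref.subview %mem[0] [8] [1] : memref<16xf32> to memref<8xf32>
  // Dynamic offset, unit stride: canonically a contiguous layout with offset.
  %b = memref.subview %mem[%idx] [8] [1]
      : memref<16xf32> to memref<8xf32, contiguous<1, offset: ?>>
  return %b : memref<8xf32, contiguous<1, offset: ?>>
}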
@@ -513,9 +513,9 @@ func.func @contiguous_inner_most_dim_with_subview(%dest: memref<1000x1xf32>, %i: // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, // CHECK-SAME: %[[VEC:.*]]: vector<4x1xf32>) { // CHECK: %[[SV_1:.*]] = memref.subview %[[MEM]]{{\[}}%[[IDX_1]], 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[SV_2:.*]] = memref.subview %[[SV_1]][0, 0] [40, 1] [1, 1] : memref<40x1xf32, strided<[1, 1], offset: ?>> to memref<40xf32, strided<[1], offset: ?>> +// CHECK: %[[SV_2:.*]] = memref.subview %[[SV_1]][0, 0] [40, 1] [1, 1] : memref<40x1xf32, strided<[1, 1], offset: ?>> to memref<40xf32, contiguous<1, offset: ?>> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<4x1xf32> to vector<4xf32> -// CHECK: vector.transfer_write %[[SC]], %[[SV_2]]{{\[}}%[[IDX_2]]] {in_bounds = [true]} : vector<4xf32>, memref<40xf32, strided<[1], offset: ?>> +// CHECK: vector.transfer_write %[[SC]], %[[SV_2]]{{\[}}%[[IDX_2]]] {in_bounds = [true]} : vector<4xf32>, memref<40xf32, contiguous<1, offset: ?>> // Same as the top example within this split, but with the outer vector // dim scalable. Note that this example only makes sense when "4 = [4]" (i.e. @@ -534,9 +534,9 @@ func.func @contiguous_inner_most_dim_with_subview_scalable_inner_dim(%dest: memr // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, // CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>) { // CHECK: %[[SV_1:.*]] = memref.subview %[[MEM]]{{\[}}%[[IDX_1]], 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[SV_2:.*]] = memref.subview %[[SV_1]][0, 0] [40, 1] [1, 1] : memref<40x1xf32, strided<[1, 1], offset: ?>> to memref<40xf32, strided<[1], offset: ?>> +// CHECK: %[[SV_2:.*]] = memref.subview %[[SV_1]][0, 0] [40, 1] [1, 1] : memref<40x1xf32, strided<[1, 1], offset: ?>> to memref<40xf32, contiguous<1, offset: ?>> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32> -// CHECK: vector.transfer_write %[[SC]], %[[SV_2]]{{\[}}%[[IDX_2]]] {in_bounds = [true]} : vector<[4]xf32>, memref<40xf32, strided<[1], offset: ?>> +// CHECK: vector.transfer_write %[[SC]], %[[SV_2]]{{\[}}%[[IDX_2]]] {in_bounds = [true]} : vector<[4]xf32>, memref<40xf32, contiguous<1, offset: ?>> // ----- @@ -552,9 +552,9 @@ func.func @contiguous_inner_most_dim_with_subview_2d(%dest: memref<1000x1x1xf32> // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, // CHECK-SAME: %[[VEC:.*]]: vector<4x1x1xf32>) { // CHECK: %[[SV_1:.*]] = memref.subview %[[MEM]]{{\[}}%[[IDX_1]], 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>> -// CHECK: %[[SV_2:.*]] = memref.subview %[[SV_1]][0, 0, 0] [40, 1, 1] [1, 1, 1] : memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>> to memref<40xf32, strided<[1], offset: ?>> +// CHECK: %[[SV_2:.*]] = memref.subview %[[SV_1]][0, 0, 0] [40, 1, 1] [1, 1, 1] : memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>> to memref<40xf32, contiguous<1, offset: ?>> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<4x1x1xf32> to vector<4xf32> -// CHECK: vector.transfer_write %[[SC]], %[[SV_2]]{{\[}}%[[IDX_2]]] {in_bounds = [true]} : vector<4xf32>, memref<40xf32, strided<[1], offset: ?>> +// CHECK: vector.transfer_write %[[SC]], %[[SV_2]]{{\[}}%[[IDX_2]]] {in_bounds = [true]} : vector<4xf32>, memref<40xf32, contiguous<1, offset: ?>> // Same as the top example within this split, but with the outer vector // dim scalable. 
Note that this example only makes sense when "4 = [4]" (i.e. @@ -572,9 +572,9 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable(%dest: memref<1000 // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, // CHECK-SAME: %[[VEC:.*]]: vector<[4]x1x1xf32>) { // CHECK: %[[SV_1:.*]] = memref.subview %[[MEM]]{{\[}}%[[IDX_1]], 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>> -// CHECK: %[[SV_2:.*]] = memref.subview %[[SV_1]][0, 0, 0] [40, 1, 1] [1, 1, 1] : memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>> to memref<40xf32, strided<[1], offset: ?>> +// CHECK: %[[SV_2:.*]] = memref.subview %[[SV_1]][0, 0, 0] [40, 1, 1] [1, 1, 1] : memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>> to memref<40xf32, contiguous<1, offset: ?>> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<[4]x1x1xf32> to vector<[4]xf32> -// CHECK: vector.transfer_write %[[SC]], %[[SV_2]]{{\[}}%[[IDX_2]]] {in_bounds = [true]} : vector<[4]xf32>, memref<40xf32, strided<[1], offset: ?>> +// CHECK: vector.transfer_write %[[SC]], %[[SV_2]]{{\[}}%[[IDX_2]]] {in_bounds = [true]} : vector<[4]xf32>, memref<40xf32, contiguous<1, offset: ?>> // ----- diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index e840dc6bbf224..ac279f2811c59 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -87,8 +87,8 @@ func.func @transfer_read_dims_mismatch_contiguous( // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> { // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8 // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>> -// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8> +// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, contiguous<1, offset: ?>> +// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, contiguous<1, offset: ?>>, vector<4xi8> // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8> // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8> @@ -110,15 +110,13 @@ func.func @transfer_read_dims_mismatch_non_zero_indices( return %res : vector<1x2x6xi32> } -// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)> - // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices( // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, // CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32> // CHECK: %[[C_0:.*]] = arith.constant 0 : i32 // CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index // CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32> -// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]] +// CHECK: %[[COLLAPSED_IDX:.*]] = affine.linearize_index disjoint [%[[IDX_1]], %[[IDX_2]], %[[C_0_IDX]]] by (43, 4, 6) // CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, 
vector<12xi32> // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices( @@ -142,12 +140,10 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( return %res : vector<2x2xf32> } -// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> - // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] // CHECK-SAME: : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> -// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]() +// CHECK: %[[APPLY:.*]] = affine.linearize_index disjoint [%{{.+}}, %{{.+}}] by (3, 2) // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( // CHECK-128B: memref.collapse_shape @@ -188,27 +184,37 @@ func.func @transfer_read_leading_dynamic_dims( // ----- -// One of the dims to be flattened is dynamic - not supported ATM. +// One of the dims to be flattened is dynamic and the layout is contiguous -func.func @negative_transfer_read_dynamic_dim_to_flatten( +func.func @transfer_read_dynamic_dim_to_flatten( %idx_1: index, %idx_2: index, - %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> { + %mem: memref<1x?x6x2xi32>) -> vector<1x3x2xi32> { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 { in_bounds = [true, true, true] - } : memref<1x?x4x6xi32>, vector<1x2x6xi32> - return %res : vector<1x2x6xi32> + } : memref<1x?x6x2xi32>, vector<1x3x2xi32> + return %res : vector<1x3x2xi32> } -// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten -// CHECK-NOT: memref.collapse_shape -// CHECK-NOT: vector.shape_cast +// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten +// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index, %[[MEM:.+]]: memref<1x?x6x2xi32> +// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32 +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}} +// CHECK-SAME: : memref<1x?x6x2xi32> into memref<1x?xi32> +// CHECK: %[[IDX:.+]] = affine.linearize_index disjoint [%[[IDX_1]], %[[IDX_2]], %[[C0]]] by (6, 2) +// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]] +// CHECK-SAME: [%[[C0]], %[[IDX]]], %[[C0_I32]] +// CHECK-SAME: {in_bounds = [true]} +// CHECK-SAME: : memref<1x?xi32>, vector<6xi32> +// CHECK: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<6xi32> to vector<1x3x2xi32> +// CHECK: return %[[RES]] : vector<1x3x2xi32> -// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten -// CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten +// CHECK-128B: memref.collapse_shape // ----- @@ -355,9 +361,9 @@ func.func @transfer_write_dims_mismatch_contiguous( // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, // CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) { // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>> +// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, contiguous<1, offset: ?>> // CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8> -// 
CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>> +// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, contiguous<1, offset: ?>> // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous( // CHECK-128B: memref.collapse_shape @@ -377,14 +383,12 @@ func.func @transfer_write_dims_mismatch_non_zero_indices( return } -// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)> - // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices( // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, // CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>, // CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]] +// CHECK-DAG: %[[IDX:.*]] = affine.linearize_index disjoint [%[[IDX_1]], %[[IDX_2]], %[[C0]]] by (43, 4, 6) // CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32> // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32> // CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32> @@ -409,10 +413,9 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( return } -// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> - // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( -// CHECK-DAG: %[[APPLY:.*]] = affine.apply #[[$MAP]]() +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[IDX:.*]] = affine.linearize_index disjoint [%{{.*}}, %[[C0]]] by (3, 2) // CHECK-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( @@ -451,27 +454,35 @@ func.func @transfer_write_leading_dynamic_dims( // ----- -// One of the dims to be flattened is dynamic - not supported ATM. 
+// One of the dims to be flattened is dynamic -func.func @negative_transfer_write_dynamic_to_flatten( +func.func @transfer_write_dynamic_to_flatten( %idx_1: index, %idx_2: index, - %vec : vector<1x2x6xi32>, - %mem: memref<1x?x4x6xi32>) { + %vec : vector<1x3x2xi32>, + %mem: memref<1x?x6x2xi32>) { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} : - vector<1x2x6xi32>, memref<1x?x4x6xi32> + vector<1x3x2xi32>, memref<1x?x6x2xi32> return } -// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten -// CHECK-NOT: memref.collapse_shape -// CHECK-NOT: vector.shape_cast +// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten +// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index, %[[VEC:.+]]: vector<1x3x2xi32>, %[[MEM:.+]]: memref<1x?x6x2xi32> +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}} +// CHECK-SAME: : memref<1x?x6x2xi32> into memref<1x?xi32> +// CHECK: %[[IDX:.+]] = affine.linearize_index disjoint [%[[IDX_1]], %[[IDX_2]], %[[C0]]] by (6, 2) +// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x3x2xi32> to vector<6xi32> +// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] +// CHECK-SAME: [%[[C0]], %[[IDX]]] +// CHECK-SAME: {in_bounds = [true]} +// CHECK-SAME: : vector<6xi32>, memref<1x?xi32> -// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten -// CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten +// CHECK-128B: memref.collapse_shape // ----- diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir index 977aec2536b1e..5c87ef6cf90d7 100644 --- a/mlir/test/IR/affine-map.mlir +++ b/mlir/test/IR/affine-map.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s -// Identity maps used in trivial compositions in MemRefs are optimized away. + #map0 = affine_map<(i, j) -> (i, j)> #map1 = affine_map<(i, j)[s0] -> (i, j)> @@ -207,12 +207,12 @@ // CHECK: #map{{[0-9]*}} = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0)> #map64 = affine_map<(i0, i1)[mod] -> (i0 + i1 + mod)> -// Single identity maps are removed. -// CHECK: @f0(memref<2x4xi8, 1>) +// Single identity maps are not removed anymore (migrate to contiguous). +// CHECK: @f0(memref<2x4xi8, #map{{[0-9]*}}, 1>) func.func private @f0(memref<2x4xi8, #map0, 1>) -// Single identity maps are removed. -// CHECK: @f1(memref<2x4xi8, 1>) +// Single identity maps are not removed anymore (migrate to contiguous). 
+// CHECK: @f1(memref<2x4xi8, affine_map<(d0, d1)[s0] -> (d0, d1)>, 1>) func.func private @f1(memref<2x4xi8, #map1, 1>) // CHECK: @f2(memref) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index cace1fefa43d6..07bfd3ca22510 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -6,14 +6,14 @@ // CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d4, d3)> #map = affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d4, d3)> +// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> + // CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0) -> (d0)> -#map1 = affine_map<(d0) -> (d0)> +#map2 = affine_map<(d0) -> (d0)> // CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #map{{[0-9]*}} = affine_map<()[s0] -> (0, s0 - 1)> #inline_map_minmax_loop1 = affine_map<()[s0] -> (0, s0 - 1)> @@ -83,23 +83,17 @@ func.func private @tensor_encoding(tensor<16x32xf64, "sparse">) // CHECK: func private @large_shape_dimension(tensor<9223372036854775807xf32>) func.func private @large_shape_dimension(tensor<9223372036854775807xf32>) -// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map>, memref<8xi8>) -> (), () -> ()) -func.func private @functions((memref<1x?x4x?x?xi32, #map, 0>, memref<8xi8, #map1, 0>) -> (), ()->()) +// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map>, memref<8xi8, #map2>) -> (), () -> ()) +func.func private @functions((memref<1x?x4x?x?xi32, #map, 0>, memref<8xi8, #map2, 0>) -> (), ()->()) + +// CHECK: func private @functions_elide_layout((memref<1x?x4x?x?xi32, #map>, memref<8xi8>) -> (), () -> ()) +func.func private @functions_elide_layout((memref<1x?x4x?x?xi32, #map, 0>, memref<8xi8, contiguous<1>, 0>) -> (), ()->()) // CHECK: func private @memrefs2(memref<2x4x8xi8, 1>) -func.func private @memrefs2(memref<2x4x8xi8, #map2, 1>) +func.func private @memrefs2(memref<2x4x8xi8, contiguous<3>, 1>) // CHECK: func private @memrefs3(memref<2x4x8xi8>) -func.func private @memrefs3(memref<2x4x8xi8, affine_map<(d0, d1, d2) -> (d0, d1, d2)>>) - -// CHECK: func private @memrefs_drop_triv_id_inline(memref<2xi8>) -func.func private @memrefs_drop_triv_id_inline(memref<2xi8, affine_map<(d0) -> (d0)>>) - -// CHECK: func private @memrefs_drop_triv_id_inline0(memref<2xi8>) -func.func private @memrefs_drop_triv_id_inline0(memref<2xi8, affine_map<(d0) -> (d0)>, 0>) - -// CHECK: func private @memrefs_drop_triv_id_inline1(memref<2xi8, 1>) -func.func private @memrefs_drop_triv_id_inline1(memref<2xi8, affine_map<(d0) -> (d0)>, 1>) +func.func private @memrefs3(memref<2x4x8xi8, contiguous<3>>) // Test memref with custom memory space @@ -107,25 +101,40 @@ func.func private @memrefs_drop_triv_id_inline1(memref<2xi8, affine_map<(d0) -> func.func private @memrefs_nomap_nospace(memref<5x6x7xf32>) // CHECK: func private @memrefs_map_nospace(memref<5x6x7xf32, #map{{[0-9]*}}>) -func.func private @memrefs_map_nospace(memref<5x6x7xf32, #map3>) +func.func private @memrefs_map_nospace(memref<5x6x7xf32, #map1>) // CHECK: func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) func.func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) // CHECK: func private @memrefs_map_intspace(memref<5x6x7xf32, #map{{[0-9]*}}, 5>) -func.func private 
@memrefs_map_intspace(memref<5x6x7xf32, #map3, 5>) +func.func private @memrefs_map_intspace(memref<5x6x7xf32, #map1, 5>) // CHECK: func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) func.func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) // CHECK: func private @memrefs_map_strspace(memref<5x6x7xf32, #map{{[0-9]*}}, "private">) -func.func private @memrefs_map_strspace(memref<5x6x7xf32, #map3, "private">) +func.func private @memrefs_map_strspace(memref<5x6x7xf32, #map1, "private">) // CHECK: func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1 : i64}>) func.func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1}>) // CHECK: func private @memrefs_map_dictspace(memref<5x6x7xf32, #map{{[0-9]*}}, {memSpace = "special", subIndex = 3 : i64}>) -func.func private @memrefs_map_dictspace(memref<5x6x7xf32, #map3, {memSpace = "special", subIndex = 3}>) +func.func private @memrefs_map_dictspace(memref<5x6x7xf32, #map1, {memSpace = "special", subIndex = 3}>) + +// CHECK: func private @memrefs_contiguous_attr(memref<5x6x7xf32>) +func.func private @memrefs_contiguous_attr(memref<5x6x7xf32, contiguous<3, offset: 0>>) + +// CHECK: func private @memrefs_contiguous_attr_long_perm(memref<5x6x7xf32>) +func.func private @memrefs_contiguous_attr_long_perm(memref<5x6x7xf32, contiguous<[0, 1, 2], offset: 0>>) + +// CHECK: func private @memrefs_contiguous_attr_static_offset(memref<5x6x7xf32, contiguous<3, offset: 5>>) +func.func private @memrefs_contiguous_attr_static_offset(memref<5x6x7xf32, contiguous<3, offset: 5>>) + +// CHECK: func private @memrefs_contiguous_attr_dynamic_offset(memref<5x6x7xf32, contiguous<3, offset: ?>>) +func.func private @memrefs_contiguous_attr_dynamic_offset(memref<5x6x7xf32, contiguous<3, offset: ?>>) + +// CHECK: func private @memrefs_contiguous_attr_0d_as_list(memref) +func.func private @memrefs_contiguous_attr_0d_as_list(memref>) // CHECK: func private @complex_types(complex) -> complex func.func private @complex_types(complex) -> complex @@ -386,13 +395,13 @@ func.func @attributes() { "foo"() {a = 1, b = -423, c = [true, false], d = 16.0 } : () -> () // CHECK: "foo"() {map1 = #map{{[0-9]*}}} - "foo"() {map1 = #map1} : () -> () + "foo"() {map1 = #map2} : () -> () // CHECK: "foo"() {map2 = #map{{[0-9]*}}} "foo"() {map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>} : () -> () // CHECK: "foo"() {map12 = [#map{{[0-9]*}}, #map{{[0-9]*}}]} - "foo"() {map12 = [#map1, #map2]} : () -> () + "foo"() {map12 = [#map2, #map3]} : () -> () // CHECK: "foo"() {set1 = #set{{[0-9]*}}} "foo"() {set1 = #set1} : () -> () diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir index 37eff99d3cebf..1822944a16877 100644 --- a/mlir/test/IR/print-attr-type-aliases.mlir +++ b/mlir/test/IR/print-attr-type-aliases.mlir @@ -61,8 +61,7 @@ // Check that we don't print aliases for things that aren't printed.
// CHECK: = loc(fused -// CHECK-NOT: #map -"test.op"() {alias_test = loc(fused (d0)>>>["test.mlir":10:8])} : () -> () +"test.op"() {alias_test = loc(fused>>["test.mlir":10:8])} : () -> () // ----- diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py index b91fdc367cf30..d974659c8cdcb 100644 --- a/mlir/test/python/dialects/memref.py +++ b/mlir/test/python/dialects/memref.py @@ -217,21 +217,10 @@ def check_strides_offset(memref, np_view): check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :]) # fmt: on - # default strides and offset means no stridedlayout attribute means affinemap layout + # default strides and offset means no stridedlayout attribute means contiguous layout assert memref.subview( mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1) - ).type.layout == AffineMapAttr.get( - AffineMap.get( - 4, - 0, - [ - AffineDimExpr.get(0), - AffineDimExpr.get(1), - AffineDimExpr.get(2), - AffineDimExpr.get(3), - ], - ) - ) + ).type.layout == ContiguousLayoutAttr.get_row_major(0, 4) shape = (7, 22, 30, 44) golden_mem = np.zeros(shape, dtype=np.int32) diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py index 2f3c4460d3f59..cf677af87575a 100644 --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -670,6 +670,39 @@ def testStridedLayoutAttr(): print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}") +# CHECK-LABEL: TEST: testContiguousLayoutAttr +@run +def testContiguousLayoutAttr(): + with Context(): + attr = ContiguousLayoutAttr.get(42, [2, 1, 0]) + # CHECK: contiguous<[2, 1, 0], offset: 42> + print(attr) + # CHECK: 42 + print(attr.offset) + # CHECK: 3 + print(len(attr.permutation)) + # CHECK: 2 + print(attr.permutation[0]) + # CHECK: 1 + print(attr.permutation[1]) + # CHECK: 0 + print(attr.permutation[2]) + + dynamic = ShapedType.get_dynamic_stride_or_offset() + attr = ContiguousLayoutAttr.get_row_major(dynamic, 3) + # CHECK: contiguous<3, offset: ?> + print(attr) + # CHECK: offset is dynamic: True + print(f"offset is dynamic: {attr.offset == dynamic}") + # CHECK: rank: 3 + print(f"rank: {len(attr.permutation)}") + # CHECK: 0 + print(attr.permutation[0]) + # CHECK: 1 + print(attr.permutation[1]) + # CHECK: 2 + print(attr.permutation[2]) + # CHECK-LABEL: TEST: testConcreteTypesRoundTrip @run def testConcreteTypesRoundTrip(): diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index 6ce0fc12d8082..2685664d6c012 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -487,7 +487,7 @@ def testMemRefType(): memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref_f32) - # CHECK: memref layout: AffineMapAttr(affine_map<(d0, d1) -> (d0, d1)>) + # CHECK: memref layout: ContiguousLayoutAttr(contiguous<2>) print("memref layout:", repr(memref_f32.layout)) # CHECK: memref affine map: (d0, d1) -> (d0, d1) print("memref affine map:", memref_f32.affine_map) diff --git a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp index 3937095c119c3..b5ff79393cf64 100644 --- a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp +++ b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp @@ -24,7 +24,7 @@ TEST(InferShapeTest, inferRankReducedShapeIdentity) { /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1}); auto expectedType = MemRefType::get( 
{2}, b.getIndexType(), - StridedLayoutAttr::get(&ctx, /*offset=*/13, /*strides=*/{1})); + ContiguousLayoutAttr::get(&ctx, /*offset=*/13, /*rank=*/1)); EXPECT_EQ(reducedType, expectedType); } @@ -40,7 +40,7 @@ TEST(InferShapeTest, inferRankReducedShapeNonIdentity) { /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1}); auto expectedType = MemRefType::get( {2}, b.getIndexType(), - StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{1})); + ContiguousLayoutAttr::get(&ctx, /*offset=*/2003, /*rank=*/1)); EXPECT_EQ(reducedType, expectedType); } @@ -55,6 +55,6 @@ TEST(InferShapeTest, inferRankReducedShapeToScalar) { /*resultShape=*/{}, sourceMemref, {2, 3}, {1, 1}, {1, 1}); auto expectedType = MemRefType::get( {}, b.getIndexType(), - StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{})); + ContiguousLayoutAttr::get(&ctx, /*offset=*/2003, /*permutation=*/{})); EXPECT_EQ(reducedType, expectedType); }
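// An illustrative recap (placeholder function names, not part of the patch) of
// the textual forms of the contiguous layout exercised by the tests above:
// a bare rank, a rank with a static or dynamic offset, and an explicit
// dimension permutation.
func.func private @contiguous_rank_form(memref<2x4x8xi8, contiguous<3>>)
func.func private @contiguous_rank_form_static_offset(memref<5x6x7xf32, contiguous<3, offset: 5>>)
func.func private @contiguous_rank_form_dynamic_offset(memref<5x6x7xf32, contiguous<3, offset: ?>>)
func.func private @contiguous_permutation_form(memref<4x5x6xf32, contiguous<[2, 1, 0]>>)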