diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 5c61f9f27b34..52e207e90e05 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7560,6 +7560,31 @@ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
   }];
 }
 
+def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchOptionalIntType:$storage_offset
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenAsStridedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 799d0f91f5a6..d1f2478a3788 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1088,6 +1088,99 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
 };
 } // namespace
 
+namespace {
+// Lower aten.as_strided by collapsing the input into its linear storage and
+// gathering each output element
+//   out[i0, ..., in] = flat[storage_offset + sum_k(ik * stride[k])]
+// with a linalg.generic. A single tensor.extract_slice cannot express this in
+// general, because as_strided permits overlapping reads of the same element.
+class ConvertAtenAsStridedOp : public OpConversionPattern<AtenAsStridedOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenAsStridedOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = getTypeConverter();
+    Value input = adaptor.self();
+    auto inputType = input.getType().cast<RankedTensorType>();
+    RankedTensorType resultType =
+        typeConverter->convertType(op->getResult(0).getType())
+            .cast<RankedTensorType>();
+    int64_t resultRank = resultType.getRank();
+
+    // Collapse all input dimensions into one so the strides can be applied
+    // against linear storage.
+    SmallVector<ReassociationIndices> flattenShapeIndices(1);
+    for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
+      flattenShapeIndices[0].push_back(i);
+    int64_t flatDim = inputType.hasStaticShape() ? inputType.getNumElements()
+                                                 : kUnknownSize;
+    auto flatType =
+        RankedTensorType::get({flatDim}, inputType.getElementType());
+    Value inputFlattened = rewriter.create<tensor::CollapseShapeOp>(
+        loc, flatType, input, flattenShapeIndices);
+
+    int64_t storageOffset;
+    if (!matchPattern(op.storage_offset(), m_TorchConstantInt(&storageOffset)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: storage_offset must be a constant int");
+    Value offsetIndex =
+        rewriter.create<arith::ConstantIndexOp>(loc, storageOffset);
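+
+    // Convert the requested result sizes and strides from torch int lists
+    // into index values usable by the gather below.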
+    SmallVector<Value> resultSize;
+    if (!getListConstructElements(op.size(), resultSize))
+      return rewriter.notifyMatchFailure(op,
+                                         "unimplemented: the target size is "
+                                         "not constructed from ListConstruct");
+    SmallVector<Value> strides;
+    if (!getListConstructElements(op.stride(), strides))
+      return rewriter.notifyMatchFailure(op,
+                                         "unimplemented: the target stride is "
+                                         "not constructed from ListConstruct");
+    SmallVector<Value> sizeInts =
+        getTypeConvertedValues(rewriter, loc, typeConverter, resultSize);
+    SmallVector<Value> strideInts =
+        getTypeConvertedValues(rewriter, loc, typeConverter, strides);
+    SmallVector<Value> sizeIndices =
+        castIntVectorToIndexVector(rewriter, loc, sizeInts);
+    SmallVector<Value> strideIndices =
+        castIntVectorToIndexVector(rewriter, loc, strideInts);
+
+    // Gather out[i0, ..., in] = flat[storage_offset + sum_k(ik * stride[k])].
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, sizeIndices, resultType.getElementType());
+    SmallVector<AffineMap> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(resultRank)};
+    SmallVector<StringRef> iteratorTypes(resultRank,
+                                         getParallelIteratorTypeName());
+    auto gatherOp = rewriter.create<linalg::GenericOp>(
+        loc, initTensor.getType(), ValueRange{}, initTensor, indexingMaps,
+        iteratorTypes,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          Value linearIndex = offsetIndex;
+          for (auto i : llvm::seq<int64_t>(0, resultRank)) {
+            Value iv = b.create<linalg::IndexOp>(loc, i);
+            Value scaled = b.create<arith::MulIOp>(loc, iv, strideIndices[i]);
+            linearIndex = b.create<arith::AddIOp>(loc, linearIndex, scaled);
+          }
+          Value extracted =
+              b.create<tensor::ExtractOp>(loc, inputFlattened, linearIndex);
+          b.create<linalg::YieldOp>(loc, extracted);
+        });
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
+                                                gatherOp->getResult(0));
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
 public:
@@ -1359,6 +1452,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
   target.addIllegalOp<AtenCatOp>();
   patterns.add<ConvertAtenCatOp>(typeConverter, context);
+  target.addIllegalOp<AtenAsStridedOp>();
+  patterns.add<ConvertAtenAsStridedOp>(typeConverter, context);
   target.addIllegalOp<AtenBroadcastToOp>();
   patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
   target.addIllegalOp<AtenContiguousOp>();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index b0d961ced00f..f1ce0554b705 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -697,7 +697,7 @@ void TypeAnalysis::visitOperation(Operation *op,
       AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
       AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
       AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
-      AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp,
+      AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenAsStridedOp,
       AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
       AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
       AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index babc6dd18565..de83139acbaa 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6045,6 +6045,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
+"    return %arg1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.expand\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index d729f81ae224..d73ae757ab78 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -161,7 +161,7 @@ bool Torch::isViewLikeOp(Operation *op) {
              AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
              AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
              TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
-             AtenNarrowOp, AtenToDeviceOp>(op);
+             AtenNarrowOp, AtenToDeviceOp, AtenAsStridedOp>(op);
 }
 
 Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index a3d4b11cf758..75e993c22605 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -663,6 +663,9 @@ def aten〇repeat(self: List[int], repeats: List[int]) -> List[int]:
 def aten〇roll(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇as_strided(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
+    return size
+
 def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
     return upstream_shape_functions.expand(self, size)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index e71c5ad6f4c9..77cc2279ac54 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -510,6 +510,7 @@ def emit_with_mutating_variants(key, **kwargs):
 
     # Functionalization ops
     emit("aten::alias_copy : (Tensor) -> (Tensor)")
+    emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 3ced7b46b794..4162c548b19a 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2410,6 +2410,29 @@ def forward(self, x, y):
 def CopyWithDifferentSizesModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 1))
 
+# ==============================================================================
+
+
+class AsStridedModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)
+
+
+@register_test_case(module_factory=lambda: AsStridedModule())
+def AsStridedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5))
+
+
+# ==============================================================================
 
 
 class CopyWithDifferentDTypesModule(torch.nn.Module):
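
For intuition, here is what the new lowering and test compute, sketched in plain PyTorch (the arange input and printed values are illustrative, not part of the patch):

import torch

# as_strided indexes the *flattened* storage of its input:
#   out[i][j] = flat[storage_offset + i * stride[0] + j * stride[1]]
x = torch.arange(25.0).reshape(5, 5)
out = torch.as_strided(x, (2, 2), (3, 3), 1)
print(out)  # tensor([[1., 4.], [4., 7.]])

Note that flat[4] is read twice: as_strided permits overlapping views, which is why the conversion gathers element by element from the collapsed input (linalg.generic plus tensor.extract) rather than taking a single tensor.extract_slice.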