[MLIR][TORCH] Add linalg support for aten::as_strided #1656

Closed · wants to merge 1 commit
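For context, aten::as_strided reinterprets a tensor's underlying storage with an explicit result shape, per-dimension strides, and a storage offset: element (i, j) of a rank-2 result reads storage[offset + i*stride[0] + j*stride[1]]. A minimal PyTorch sketch of the semantics this lowering targets:

import torch

x = torch.arange(25.0)  # contiguous storage: [0., 1., ..., 24.]
# size=(2, 2), stride=(3, 3), storage_offset=1, so
# result[i][j] = storage[1 + 3*i + 3*j]
y = torch.as_strided(x, (2, 2), (3, 3), 1)
print(y)  # tensor([[1., 4.], [4., 7.]])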
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7560,6 +7560,31 @@ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
  }];
}

def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
    AllowsTypeRefinement,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$size,
    AnyTorchListOfTorchIntType:$stride,
    AnyTorchOptionalIntType:$storage_offset
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 4, 1);
    }
    void AtenAsStridedOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
}

def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
    AllowsTypeRefinement,
    HasValueSemantics,
108 changes: 108 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1088,6 +1090,110 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
};
} // namespace

namespace {
class ConvertAtenAsStridedOp : public OpConversionPattern<AtenAsStridedOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenAsStridedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op.getLoc();
    TypeConverter *typeConverter = getTypeConverter();

    Value input = adaptor.self();
    int64_t inputRank = input.getType().cast<RankedTensorType>().getRank();
    RankedTensorType resultType =
        typeConverter->convertType(op->getResult(0).getType())
            .cast<RankedTensorType>();

    // as_strided indexes the underlying storage directly, so collapse all
    // input dimensions into a single rank-1 tensor of dynamic size.
    SmallVector<ReassociationIndices> flattenShapeIndices;
    flattenShapeIndices.emplace_back();
    for (auto i : llvm::seq<int64_t>(0, inputRank))
      flattenShapeIndices[0].push_back(i);

    SmallVector<int64_t> flattenShape{kUnknownSize};
    auto flattenedInputType =
        RankedTensorType::get(flattenShape, resultType.getElementType());

    Value inputFlattened = rewriter.create<tensor::CollapseShapeOp>(
        loc, flattenedInputType, input, flattenShapeIndices);

    // The storage offset becomes the base offset of the slice, so it must
    // be known at compile time.
    int64_t storageOffset;
    if (!matchPattern(op.storage_offset(),
                      m_TorchConstantInt(&storageOffset)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: storage_offset must be a constant int");

    SmallVector<Value> offsets;
    offsets.push_back(rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI64IntegerAttr(storageOffset)));

    SmallVector<Value> sizes;
    if (!getListConstructElements(op.size(), sizes))
      return rewriter.notifyMatchFailure(op,
                                         "unimplemented: the target size is "
                                         "not constructed from ListConstruct");
    sizes = getTypeConvertedValues(rewriter, loc, typeConverter, sizes);

    SmallVector<Value> strides;
    if (!getListConstructElements(op.stride(), strides))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the target stride is "
              "not constructed from ListConstruct");
    strides = getTypeConvertedValues(rewriter, loc, typeConverter, strides);

    // tensor.extract_slice expects index-typed offsets, sizes, and strides.
    SmallVector<Value> indexOffsets =
        castIntVectorToIndexVector(rewriter, loc, offsets);
    SmallVector<Value> indexSizes =
        castIntVectorToIndexVector(rewriter, loc, sizes);
    SmallVector<Value> indexStrides =
        castIntVectorToIndexVector(rewriter, loc, strides);

    // Carve the strided window out of the flattened storage, then cast to
    // the type expected by the type converter.
    Value result = rewriter.create<tensor::ExtractSliceOp>(
        loc, flattenedInputType, inputFlattened, indexOffsets, indexSizes,
        indexStrides);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
public:
@@ -1359,6 +1465,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
  patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
  target.addIllegalOp<AtenCatOp>();
  patterns.add<ConvertAtenCatOp>(typeConverter, context);
  target.addIllegalOp<AtenAsStridedOp>();
  patterns.add<ConvertAtenAsStridedOp>(typeConverter, context);
  target.addIllegalOp<AtenBroadcastToOp>();
  patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
  target.addIllegalOp<AtenContiguousOp>();
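The conversion above follows a flatten-then-slice scheme: collapse the input into its rank-1 underlying storage, then extract the window that starts at storage_offset. Note that a single tensor.extract_slice of a rank-1 buffer most naturally expresses rank-1 strided views. The Python reference model below, a sketch assuming contiguous input storage (as_strided_ref is an illustrative name, not part of the patch), spells out the general indexing rule the lowering has to satisfy:

import itertools
import torch

def as_strided_ref(x, size, stride, storage_offset=0):
    # Gather each output element from the flattened storage at
    # storage_offset + sum(index[d] * stride[d]).
    flat = x.contiguous().reshape(-1)
    out = torch.empty(size, dtype=x.dtype)
    for idx in itertools.product(*(range(s) for s in size)):
        linear = storage_offset + sum(i * s for i, s in zip(idx, stride))
        out[idx] = flat[linear]
    return out

x = torch.rand(25, 1, 1)
assert torch.equal(as_strided_ref(x, (2, 2), (3, 3), 1),
                   torch.as_strided(x, (2, 2), (3, 3), 1))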
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -697,7 +697,7 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp,
AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenAsStridedOp,
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6045,6 +6045,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.expand\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
@@ -161,7 +161,7 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenToDeviceOp>(op);
AtenNarrowOp, AtenToDeviceOp, AtenAsStridedOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
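Adding AtenAsStridedOp to isViewLikeOp records that as_strided aliases its input's storage rather than copying it, which the passes that reason about value semantics must respect. A quick eager-mode demonstration of the aliasing:

import torch

x = torch.arange(25.0)
v = torch.as_strided(x, (2, 2), (3, 3), 1)  # v[0, 1] reads storage[4]
x[4] = 100.0                                # mutate the base tensor
print(v[0, 1])                              # tensor(100.), v aliases x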
@@ -663,6 +663,9 @@ def aten〇repeat(self: List[int], repeats: List[int]) -> List[int]:
def aten〇roll(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇as_strided(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
    return size

def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
    return upstream_shape_functions.expand(self, size)

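The result shape of as_strided is exactly the requested size list, independent of the input's shape, which is why the shape function returns size rather than upstream_shape_functions.unary(self). A quick eager-mode check of that rule:

import torch

x = torch.zeros(25, 1, 1)
y = torch.as_strided(x, (2, 2), (3, 3), 1)
assert list(y.shape) == [2, 2]  # shape follows size, not the input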
@@ -510,6 +510,7 @@ def emit_with_mutating_variants(key, **kwargs):

# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
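The signature string passed to emit() must match the operator's schema as registered with TorchScript; the ODS generator derives the Torch_AtenAsStridedOp definition shown earlier from it. One way to double-check the string, assuming the private helper torch._C._jit_get_schemas_for_operator is available in your PyTorch build:

import torch

# Print the registered schema(s) for aten::as_strided; the emit() string
# above mirrors this signature.
for schema in torch._C._jit_get_schemas_for_operator("aten::as_strided"):
    print(schema)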
23 changes: 23 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2410,6 +2410,29 @@ def forward(self, x, y):
def CopyWithDifferentSizesModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 1))

# ==============================================================================

class AsStridedModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # View the 25-element input as a (2, 2) window with stride 3 in
        # each result dimension, starting at storage offset 1.
        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)


@register_test_case(module_factory=lambda: AsStridedModule())
def AsStridedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(25, 1, 1))

# ==============================================================================

class CopyWithDifferentDTypesModule(torch.nn.Module):
