
[RFC] Support for aten::as_strided #1683

Open
@JakopinA

Description


Hi, I'm working on implementing support for aten::as_strided. Here is the IR my current lowering attempt produces:

  %9 = "tensor.collapse_shape"(%0) {reassociation = [[0, 1]]} : (tensor<?x?xf32>) -> tensor<?xf32>
  %21 = "tensor.extract_slice"(%9, %15, %16, %17, %18, %19, %20) {operand_segment_sizes = array<i32: 1, 2, 2, 2>, static_offsets = [-9223372036854775808, -9223372036854775808], static_sizes = [-1, -1], static_strides = [-9223372036854775808, -9223372036854775808]} : (tensor<?xf32>, index, index, index, index, index, index) -> tensor<?xf32>
  %22 = "tensor.cast"(%21) : (tensor<?xf32>) -> tensor<?xf32>

My initial strategy was to use tensor.extract_slice; however, aten::as_strided and tensor.extract_slice implement striding differently. aten::as_strided treats the tensor as a flat 1-dimensional buffer (its strides index directly into the underlying storage), whereas tensor.extract_slice applies its strides per dimension.
My strategy then became to flatten the input to a 1-dimensional tensor to replicate the behaviour of aten::as_strided, since I had confirmed that a [25, 1] matrix passed when given to tensor.extract_slice.
However, when I passed in a collapsed version of the input via tensor.collapse_shape, I realized this would not work either: tensor.extract_slice expects the number of offset/size/stride arguments to match the rank of its source, which is now 1.
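
For reference, this is the indexing behaviour the lowering has to reproduce: output element (i0, i1, ...) reads the flat storage at storage_offset plus the dot product of the index with the strides. Here is a minimal eager-mode sketch of that flatten-and-gather view (assuming a contiguous input; as_strided_ref is only an illustrative helper, not something from the patch):

import torch

def as_strided_ref(x, size, stride, storage_offset=0):
    # aten::as_strided on a contiguous input: output[i0, i1, ...] =
    # storage[storage_offset + i0*stride[0] + i1*stride[1] + ...]
    flat = x.contiguous().view(-1)  # the 1-D storage view
    idx = torch.full(size, storage_offset, dtype=torch.int64)
    for d, (sz, st) in enumerate(zip(size, stride)):
        shape = [sz if i == d else 1 for i in range(len(size))]
        idx = idx + torch.arange(sz).reshape(shape) * st  # broadcasted per-dim term
    return flat[idx.reshape(-1)].reshape(size)

x = torch.randn(25)
assert torch.equal(as_strided_ref(x, (2, 2), (3, 3), 1),
                   torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1))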

My most up-to-date code is in #1656

Here is the error I am currently getting:

torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed with the following diagnostics:
        error: expected 1 offset values
        note: see current operation: %21 = "tensor.extract_slice"(%9, %15, %16, %17, %18, %19, %20) {operand_segment_sizes = array<i32: 1, 2, 2, 2>, static_offsets = [-9223372036854775808, -9223372036854775808], static_sizes = [-1, -1], static_strides = [-9223372036854775808, -9223372036854775808]} : (tensor<?xf32>, index, index, index, index, index, index) -> tensor<?xf32>
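
If I am reading the diagnostic right, it confirms the point above: the source has been collapsed to rank 1 (tensor<?xf32>), but the extract_slice is still constructed with two offsets, two sizes and two strides, so the verifier expects exactly one of each. There is also a second issue with expressing this as a single rank-1 slice at all: for the test case below the flat indices overlap, so they cannot come from one offset/size/stride triple (a quick arithmetic check):

# Flat storage indices read by as_strided(x, (2, 2), (3, 3), 1):
# storage_offset + i*stride[0] + j*stride[1] for i, j in {0, 1}
print([1 + 3 * i + 3 * j for i in range(2) for j in range(2)])  # [1, 4, 4, 7]
# Index 4 is read twice, so this is not a single strided 1-D slice.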

And the test:

class AsStridedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)

@register_test_case(module_factory=lambda: AsStridedModule())
def AsStridedModule_basic(module, tu: TestUtils):
    x = torch.randn(25)
    print(x)
    print(torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1))
    module.forward(x)
