
(TorchToLinalg) Support for lowering torch.aten.as_strided #4191

Open
@monorimet

Description


This was previously tracked in pull request #1683, which was closed. That PR may be used as a reference to some degree, but it's quite outdated, so it's not a reliable starting point.

Torch versions 2.6.0+ frequently emit this op in model exports through dynamo, with no option to decompose it during FX graph construction. The best way to fix forward seems to be implementing the lowering for the op explicitly, despite the difficulty of doing so.
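
For reference, the op's semantics are pure index arithmetic over the underlying storage: out[i_0, ..., i_k] = storage[offset + sum_d(i_d * stride_d)]. Below is a minimal Python sketch of those value semantics (the helper name as_strided_ref is hypothetical, and it assumes a contiguous input so that flatten() matches the storage layout); a gather-based lowering would have to compute essentially this:

import torch

def as_strided_ref(x: torch.Tensor, size, stride, storage_offset=0):
    # out[i0, ..., ik] = flat(x)[storage_offset + sum_d(i_d * stride_d)]
    flat = x.contiguous().flatten()
    # Build the linear index of every output element by broadcasting one
    # arange per output dimension, scaled by that dimension's stride.
    idx = torch.full(tuple(size), storage_offset, dtype=torch.long)
    for d, (sz, st) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[d] = sz
        idx = idx + torch.arange(sz).reshape(shape) * st
    # Gather from the flat buffer; the result already has the requested size.
    return flat[idx]

x = torch.randn(25)
assert torch.equal(
    as_strided_ref(x, (2, 2), (3, 3), 1),
    torch.as_strided(x, (2, 2), (3, 3), 1),
)

Note that this reproduces only the value semantics: in eager mode as_strided returns a view that can alias (even overlap) the input's storage, which a tensor-level lowering cannot express.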

The previous PR mentioned above supplies a test case for this op:

# ==============================================================================

import torch

# Imports from the torch-mlir e2e test framework:
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class AsStridedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),  # rank-3 to match the input below
    ])
    def forward(self, x):
        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)


@register_test_case(module_factory=lambda: AsStridedModule())
def AsStridedModule_basic(module, tu: TestUtils):
    x = torch.randn(25, 1, 1)
    print(x)
    print(torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1))
    module.forward(x)

# ==============================================================================

Running the following script with torch >= 2.7.0 shows how relatively trivial ops like chunk end up relying on as_strided in the Torch dialect IR when exported via torch-mlir's fx.export_and_import:

import torch

from torch_mlir import fx

N_CHUNK = 6


class ChunkModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.chunk(N_CHUNK, dim=1)

    def sample_inputs(self):
        return torch.rand([1, N_CHUNK, 2048])  # dims here are arbitrary


module = ChunkModule().eval()
export_output = fx.export_and_import(module, module.sample_inputs(), output_type="torch")
export_output.dump()

Resulting Torch dialect IR:

module {
  func.func @main(%arg0: !torch.vtensor<[1,6,2048],f32>) -> (!torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>) {
    %int10240 = torch.constant.int 10240
    %int8192 = torch.constant.int 8192
    %int6144 = torch.constant.int 6144
    %int4096 = torch.constant.int 4096
    %int0 = torch.constant.int 0
    %int12288 = torch.constant.int 12288
    %int1 = torch.constant.int 1
    %int2048 = torch.constant.int 2048
    %0 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %3 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %4 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %5 = torch.aten.as_strided %arg0, %3, %4, %int2048 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %6 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %7 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %8 = torch.aten.as_strided %arg0, %6, %7, %int4096 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %9 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %10 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %11 = torch.aten.as_strided %arg0, %9, %10, %int6144 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %12 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %13 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %14 = torch.aten.as_strided %arg0, %12, %13, %int8192 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %15 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %16 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %17 = torch.aten.as_strided %arg0, %15, %16, %int10240 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    return %2, %5, %8, %11, %14, %17 : !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>
  }
}
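
The constants above line up with eager-mode PyTorch: a contiguous [1, 6, 2048] tensor has strides (12288, 2048, 1), and the i-th chunk along dim 1 is a view at storage offset i * 2048. A quick check of this, using the same shapes as the script above:

import torch

N_CHUNK = 6
x = torch.rand(1, N_CHUNK, 2048)
print(x.stride())  # (12288, 2048, 1): the stride list in every as_strided above
for c in x.chunk(N_CHUNK, dim=1):
    # Each chunk is a [1, 1, 2048] view whose storage offset matches the last
    # operand of the corresponding torch.aten.as_strided (0, 2048, ..., 10240).
    print(c.shape, c.stride(), c.storage_offset())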
