Open
Description
This was previously tracked in pull request #1683, which was closed. That PR may serve as a partial reference, but it is quite outdated and is unlikely to be a reliable starting point.
PyTorch 2.6.0 and later frequently emit this op in model exports through Dynamo, with no option to decompose it during FX graph construction. The best fix-forward option therefore seems to be implementing an explicit lowering for the op, despite the difficulty of doing so.
The previous PR mentioned above supplies a test case for this op:
# ==============================================================================
class AsStridedModule(torch.nn.Module):
    """Minimal module that exercises ``aten.as_strided`` on a rank-2 float32 input.

    ``forward`` returns a 2x2 view with storage offset 1 and stride 3 along
    both output dims. Element ``(i, j)`` is therefore read from position
    ``1 + 3*i + 3*j`` of the input's underlying storage: positions 1, 4, 4
    and 7. The view overlaps itself at position 4.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,  # no annotation for `self`
            # rank-2 float32 input with dynamic (-1) dims; the trailing
            # flag is passed through to annotate_args unchanged
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        # Call the ATen op directly so this test targets aten.as_strided
        # itself; keyword names follow the op schema.
        return torch.ops.aten.as_strided(
            x, size=(2, 2), stride=(3, 3), storage_offset=1
        )
@register_test_case(module_factory=lambda: AsStridedModule())
def AsStridedModule_basic(module, tu: TestUtils):
    """Run ``AsStridedModule`` on a 5x5 input (25 elements).

    The input must be rank 2 to match the ``[-1, -1]`` annotation on
    ``AsStridedModule.forward``; the previous rank-3 ``(25, 1, 1)`` input
    contradicted it. 25 elements comfortably cover the largest storage
    position (7) read by the 2x2 / stride-3 / offset-1 view.
    """
    # Use the harness-provided `tu.rand` helper (previously unused) instead
    # of torch.randn, and drop the debug prints so the e2e run stays quiet.
    module.forward(tu.rand(5, 5))
# ==============================================================================
Running the following script with torch >= 2.7.0 shows how even relatively trivial ops such as `chunk`
end up relying on `as_strided` in the torch-dialect IR when exported via torch-mlir's `fx.export_and_import`.
import torch
from torch_mlir import fx
N_CHUNK = 6  # how many pieces ChunkModule.forward splits dimension 1 into


class ChunkModule(torch.nn.Module):
    """Split the input into ``N_CHUNK`` pieces along dimension 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Tensor.chunk returns a tuple of views, one per piece.
        return x.chunk(N_CHUNK, dim=1)

    def sample_inputs(self):
        """Return a single random ``(1, N_CHUNK, 2048)`` float tensor for export.

        The sizes are arbitrary. Because dim 1 equals ``N_CHUNK``, every
        chunk produced by ``forward`` has width 1 along that dim.
        """
        example_shape = (1, N_CHUNK, 2048)
        return torch.rand(example_shape)
# Put the module in eval mode, export it through torch-mlir's FX importer to
# the `torch` dialect, and dump the resulting IR.
module = ChunkModule().eval()
example_input = module.sample_inputs()
export_output = fx.export_and_import(
    module,
    example_input,
    output_type="torch",
)
export_output.dump()
Resulting torch-dialect IR:
(show/hide)
// Reviewer note: each torch.aten.as_strided below reads a [1,1,2048] window
// of %arg0 ([1,6,2048]) using strides [12288, 2048, 1], which are exactly the
// contiguous strides of a [1,6,2048] tensor, and a storage offset of k*2048
// for k = 0..5 (%int0, %int2048, %int4096, %int6144, %int8192, %int10240).
// Assuming %arg0 is laid out contiguously, result k therefore equals slicing
// dim 1 to [k, k+1), i.e. the k-th piece of x.chunk(6, dim=1); the six
// results are returned in k order.
module {
func.func @main(%arg0: !torch.vtensor<[1,6,2048],f32>) -> (!torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>) {
// Constants for the size/stride lists and the six storage offsets
// (multiples of 2048).
%int10240 = torch.constant.int 10240
%int8192 = torch.constant.int 8192
%int6144 = torch.constant.int 6144
%int4096 = torch.constant.int 4096
%int0 = torch.constant.int 0
%int12288 = torch.constant.int 12288
%int1 = torch.constant.int 1
%int2048 = torch.constant.int 2048
// Chunk 0: storage offset 0.
%0 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
// Chunk 1: storage offset 2048.
%3 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%4 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%5 = torch.aten.as_strided %arg0, %3, %4, %int2048 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
// Chunk 2: storage offset 4096.
%6 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%7 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%8 = torch.aten.as_strided %arg0, %6, %7, %int4096 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
// Chunk 3: storage offset 6144.
%9 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%10 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%11 = torch.aten.as_strided %arg0, %9, %10, %int6144 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
// Chunk 4: storage offset 8192.
%12 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%13 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.as_strided %arg0, %12, %13, %int8192 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
// Chunk 5: storage offset 10240.
%15 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%16 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%17 = torch.aten.as_strided %arg0, %15, %16, %int10240 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
return %2, %5, %8, %11, %14, %17 : !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>
}
}
Metadata
Metadata
Assignees
Labels
No labels