Description
Hi, I'm working on implementing support for aten::as_strided.
%9 = "tensor.collapse_shape"(%0) {reassociation = [[0, 1]]} : (tensor<?x?xf32>) -> tensor<?xf32>
%21 = "tensor.extract_slice"(%9, %15, %16, %17, %18, %19, %20) {operand_segment_sizes = array<i32: 1, 2, 2, 2>, static_offsets = [-9223372036854775808, -9223372036854775808], static_sizes = [-1, -1], static_strides = [-9223372036854775808, -9223372036854775808]} : (tensor<?xf32>, index, index, index, index, index, index) -> tensor<?xf32>
%22 = "tensor.cast"(%21) : (tensor<?xf32>) -> tensor<?xf32>
My initial strategy was to use tensor.extract_slice; however, aten::as_strided and tensor.extract_slice implement striding differently. aten::as_strided treats a matrix as if it were 1-dimensional (its strides index into the flattened storage), whereas tensor.extract_slice applies each stride along its own dimension.
My strategy then became to flatten the input matrix into a 1-dimensional vector to replicate the behaviour of aten::as_strided, since I confirmed that the test passed when a [25,1] matrix was given to tensor.extract_slice.
However, when passing in a version of the input collapsed with collapse_shape, I realized this would not work, as tensor.extract_slice expects its offset, size, and stride arguments to match the rank of the source tensor.
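To make the difference in striding concrete, here is a minimal pure-Python sketch (no torch or MLIR involved). The helper names `as_strided_2d` and `strided_slice_2d` are made up for illustration, and the 5x5 layout and the slice offsets (0, 1) are assumptions chosen so that both calls start at the same element; the size, stride, and storage offset match the test further down.

```python
def as_strided_2d(flat, size, stride, offset=0):
    """Mimic as_strided for a 2-D result: strides index into flat storage."""
    rows, cols = size
    s0, s1 = stride
    return [[flat[offset + i * s0 + j * s1] for j in range(cols)]
            for i in range(rows)]


def strided_slice_2d(matrix, offsets, sizes, strides):
    """Mimic a per-dimension strided slice: each stride steps along its own axis."""
    (o0, o1), (n0, n1), (s0, s1) = offsets, sizes, strides
    return [[matrix[o0 + i * s0][o1 + j * s1] for j in range(n1)]
            for i in range(n0)]


# 25 elements of storage, viewed either flat or as an assumed 5x5 matrix.
flat = list(range(25))
matrix = [flat[r * 5:(r + 1) * 5] for r in range(5)]

# Flat indices visited: 1, 4, 4, 7 (element 4 is read twice).
strided_view = as_strided_2d(flat, (2, 2), (3, 3), 1)

# Matrix positions visited: (0, 1), (0, 4), (3, 1), (3, 4).
sliced = strided_slice_2d(matrix, (0, 1), (2, 2), (3, 3))

print(strided_view)  # [[1, 4], [4, 7]]
print(sliced)        # [[1, 4], [16, 19]]
```

With identical sizes and strides the two results differ in the second row, which is why the strides cannot be passed straight through to a rank-2 slice.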
My most up-to-date code is in #1656
Here is the error I am currently getting:
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed with the following diagnostics:
error: expected 1 offset values
note: see current operation: %21 = "tensor.extract_slice"(%9, %15, %16, %17, %18, %19, %20) {operand_segment_sizes = array<i32: 1, 2, 2, 2>, static_offsets = [-9223372036854775808, -9223372036854775808], static_sizes = [-1, -1], static_strides = [-9223372036854775808, -9223372036854775808]} : (tensor<?xf32>, index, index, index, index, index, index) -> tensor<?xf32>
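My reading of this diagnostic is that the op carries two offsets, two sizes, and two strides while the collapsed source is rank 1. The sketch below only mimics that rank check in plain Python; `check_slice_operands` is a hypothetical helper and not MLIR's actual verifier, and the operand values are placeholders.

```python
def check_slice_operands(source_rank, offsets, sizes, strides):
    """Hypothetical check: one offset, size, and stride per source dimension."""
    for name, values in (("offset", offsets), ("size", sizes), ("stride", strides)):
        if len(values) != source_rank:
            raise ValueError(f"expected {source_rank} {name} values")


# Rank-1 source with rank-2 style operands, as in the IR above.
try:
    check_slice_operands(1, [1, 0], [2, 2], [3, 3])
    message = None
except ValueError as err:
    message = str(err)

print(message)  # expected 1 offset values

# One triple per dimension is accepted for a rank-1 source.
check_slice_operands(1, [1], [4], [3])
```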
And the test:
class AsStridedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)


@register_test_case(module_factory=lambda: AsStridedModule())
def AsStridedModule_basic(module, tu: TestUtils):
    x = torch.randn(25)
    print(x)
    print(torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1))
    module.forward(x)