Skip to content
31 changes: 21 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7327,9 +7327,18 @@ def aten_repeat_interleave_self_int(
self_rank = len(self.shape)
pos_dim = (dim + self_rank) % self_rank
unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
tiles = [1] * (self_rank + 1)
tiles[pos_dim + 1] = repeats
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
if isinstance(repeats, int):
tiles = [1] * (self_rank + 1)
tiles[pos_dim + 1] = repeats
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
else:
# repeats is a symbolic tensor
tile_repeat = op.Concat(
op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)),
op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))),
op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)),
axis=0,
)
tiled = op.Tile(unsqueezed, tile_repeat)
if self_rank == 1:
return op.Identity(tiled)
Expand Down Expand Up @@ -7375,20 +7384,22 @@ def aten_repeat_interleave_Tensor(
if dim is None:
# flatten
self = op.Reshape(self, [-1])
rk = 1
rank = 1
else:
rk = len(self.shape)
rank = len(self.shape)

if rk > 2:
if rank > 2:
shape_x0 = op.Shape(self, start=0, end=1)
shape_x = op.Shape(self, start=1)
self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
elif rk == 1:
elif rank == 1:
shape_x = None
self = op.Reshape(self, [-1, 1])
else:
if rk != 2:
raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave")
if rank != 2:
raise NotImplementedError(
f"rank(self)={rank} not implemented for repeat_interleave"
)
shape_x = None

ci = op.CumSum(repeats, [0])
Expand All @@ -7401,7 +7412,7 @@ def aten_repeat_interleave_Tensor(
)
indices = op.Reshape(srows, [-1])
values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
if rk == 2:
if rank == 2:
return values
# shape_x is None at this stage.
assert shape_x is None # for mypy
Expand Down
27 changes: 27 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,33 @@ def forward(self, x, ind):
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_symbolic_tensor(self):
    """repeat_interleave where `repeats` is a symbolic size taken from another input's shape."""

    class Model(torch.nn.Module):
        def forward(self, x, y):
            # Each repeat count is the other tensor's dim-1 size, so it is
            # symbolic (a graph tensor, not a Python int) under dynamo export.
            left = torch.repeat_interleave(x, y.shape[1], dim=1)
            right = torch.repeat_interleave(y, x.shape[1], dim=1)
            return left * right

    inputs = (
        torch.arange(4, dtype=torch.float32).reshape((2, 2)),
        torch.arange(6, dtype=torch.float32).reshape((2, 3)),
    )

    # Smoke-check the un-optimized dynamo export path; result intentionally unused.
    onnx_program = torch.onnx.export(
        Model(),
        inputs,
        dynamo=True,
        optimize=False,
    )
    # Re-export with named I/O and a pinned opset, then verify numerically.
    onnx_program = torch.onnx.export(
        Model(),
        inputs,
        input_names=["x", "y"],
        output_names=["output"],
        opset_version=18,
        dynamo=True,
    )
    _testing.assert_onnx_program(onnx_program)

def test_sdpa_with_bool_attn_mask(self):
class ScaledDotProductAttention(torch.nn.Module):
def forward(self, query, key, value, attn_mask):
Expand Down