Skip to content
15 changes: 12 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7327,9 +7327,18 @@ def aten_repeat_interleave_self_int(
self_rank = len(self.shape)
pos_dim = (dim + self_rank) % self_rank
unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
tiles = [1] * (self_rank + 1)
tiles[pos_dim + 1] = repeats
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
if isinstance(repeats, int):
tiles = [1] * (self_rank + 1)
tiles[pos_dim + 1] = repeats
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
else:
# repeats is a symbolic tensor
tile_repeat = op.Concat(
op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)),
op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))),
op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)),
axis=0,
)
tiled = op.Tile(unsqueezed, tile_repeat)
if self_rank == 1:
return op.Identity(tiled)
Expand Down
Loading