diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ab992e0580..88c353e777 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8616,12 +8616,11 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: return op.CastLike(self, other) -@torch_op("aten::unbind.int") +@torch_op("aten::unbind.int", trace_only=True) def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" - split_sizes = op.Constant(value_int=1) - return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False) + return op.SplitToSequence(self, axis=dim, keepdims=False) @torch_op("aten::unflatten.int", trace_only=True)