Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7738,26 +7738,26 @@ def aten_scalar_tensor_sym_number(

@torch_op("aten::scatter.src", trace_only=True)
def aten_scatter_src(
self: TReal,
self: TTensor,
dim: int, # we have to use int here because ScatterElements() will use this attribute
index: TInt,
src: TReal,
) -> TReal:
src: TTensor,
) -> TTensor:
"""scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
return op.ScatterElements(self, index, src, axis=dim)


@torch_op("aten::scatter.value", trace_only=True)
def aten_scatter_value(
self: TReal,
self: TTensor,
dim: int, # we have to use int here because ScatterElements() will use this attribute
index: TInt,
value: TReal,
) -> TReal:
value: float,
) -> TTensor:
"""scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
# Ensure value is a scalar tensor and expand it to match index shape
scalar_tensor = op.CastLike(value, self)
src = op.Expand(scalar_tensor, op.Shape(index))
scalar_tensor = ir.tensor([value], dtype=self.dtype)
src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor)
return op.ScatterElements(self, index, src, axis=dim)


Expand Down
4 changes: 2 additions & 2 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,9 +1407,9 @@ def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs)
# (self_shape, index_shape, dim, value)
((5, 5), (2, 3), 0, 1.0), # 2D scatter on dim=0 with scalar value
((5, 5), (3, 2), 1, -2.5), # 2D scatter on dim=1 with scalar value
((3, 4, 5), (2, 2, 3), 0, 0.0), # 3D scatter on dim=0 with scalar value
((3, 4, 5), (2, 2, 3), 0, False), # 3D scatter on dim=0 with scalar value
((3, 4, 5), (2, 2, 3), 1, 3.14), # 3D scatter on dim=1 with scalar value
((3, 4, 5), (2, 2, 3), 2, -1.0), # 3D scatter on dim=2 with scalar value
((3, 4, 5), (2, 2, 3), 2, -1), # 3D scatter on dim=2 with scalar value
((10,), (3,), 0, 5.0), # 1D scatter with scalar value
]

Expand Down
Loading