Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7744,6 +7744,10 @@ def aten_scatter_src(
src: TTensor,
) -> TTensor:
"""scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
if len(index.shape) == 0:
index = op.Unsqueeze(index, [0])
if len(src.shape) == 0:
src = op.Unsqueeze(src, [0])
return op.ScatterElements(self, index, src, axis=dim)


Expand All @@ -7756,6 +7760,8 @@ def aten_scatter_value(
) -> TTensor:
"""scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
# Ensure value is a scalar tensor and expand it to match index shape
if len(index.shape) == 0:
index = op.Unsqueeze(index, [0])
scalar_tensor = ir.tensor([value], dtype=self.dtype)
src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor)
return op.ScatterElements(self, index, src, axis=dim)
Expand Down
44 changes: 44 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,35 @@ def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs):
src_tensor = make_arg(src_shape)
yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor))

# Additional test cases for scalar and single-element tensor combinations with dim=0
# Test case: scalar index, scalar src (dim_size=5)
dim_size = 5
data_1d = make_arg((dim_size,))
valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long)
scalar_src = make_arg(())
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, scalar_src))

# Test case: single-element tensor index, scalar src (dim_size=7)
dim_size = 7
data_1d = make_arg((dim_size,))
valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long)
scalar_src = make_arg(())
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, scalar_src))

# Test case: scalar index, single-element tensor src (dim_size=3)
dim_size = 3
data_1d = make_arg((dim_size,))
valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long)
src_1d = make_arg((1,))
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, src_1d))

# Test case: single-element tensor index, single-element tensor src (dim_size=10)
dim_size = 10
data_1d = make_arg((dim_size,))
valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long)
src_1d = make_arg((1,))
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, src_1d))


def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs):
del op_info
Expand Down Expand Up @@ -1423,6 +1452,21 @@ def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs)
]
yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value))

# Additional test cases for scalar and single-element tensor combinations with dim=0
# Test case: scalar index with scalar value (dim_size=6, value_type=torch.long)
dim_size = 6
data_1d = make_arg((dim_size,))
valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long)
random_value = torch.randint(0, 10, (), device=device, dtype=torch.long).item()
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, random_value))

# Test case: single-element tensor index with scalar value (dim_size=8, value_type=torch.float)
dim_size = 8
data_1d = make_arg((dim_size,))
valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long)
random_value = torch.rand((), device=device, dtype=torch.float).item()
yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, random_value))


def sample_inputs__scaled_dot_product_flash_attention(
op_info, device, dtype, requires_grad, **kwargs
Expand Down
Loading