Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions onnxscript/rewriter/redundant_scatter_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@
from onnxscript.rewriter import pattern as orp


def fail(*args):
return onnxscript.rewriter.MatchResult().fail(*args)


class ScatterAllDynamic(orp.RewriteRuleClassBase):
def pattern(self, op, data, axis, transposed_data, updates):
# Construct update-indices spanning an entire axis:
Expand All @@ -41,24 +37,24 @@ def pattern(self, op, data, axis, transposed_data, updates):
def check(self, context, data, axis, transposed_data, **_):
# Check that updated-indices represent the full range of the first dimension of the transposed data.
# That is: check that the data.shape[axis] matches transposed_data.shape[0].
result = onnxscript.rewriter.MatchResult()
axis_value = ir_utils.get_singleton_value(axis)
if not isinstance(axis_value, int):
return fail("Axis value must be a constant integer.", axis)
return result.fail("Axis value must be a constant integer.", axis)
shape: ir.Shape | None = data.shape
if shape is None:
return fail("Data shape is not statically known.", data)
return result.fail("Data shape is not statically known.", data)
updated_dim_value = shape[axis_value]
transposed_data_shape: ir.Shape | None = transposed_data.shape
if transposed_data_shape is None:
return fail("Transposed data shape is not statically known.", transposed_data)
return result.fail("Transposed data shape is not statically known.", transposed_data)
actual_dim_value = transposed_data_shape[0]
if updated_dim_value != actual_dim_value:
# The first dimension of the transposed data does not match the updated dimension,
# so we cannot apply this rule.
return fail(
return result.fail(
"The first dimension of the transposed data does not match the updated dimension.",
data,
transposed_data,
[data, transposed_data],
)
return True

Expand All @@ -81,20 +77,21 @@ def check(self, context, data, indices, updates, **_):
"""Check if the ScatterND is redundant due to static indices covering entire tensor."""
# To validate data can be replaced directly by updates, we need to check the following:
# 1. they have the same shape
result = onnxscript.rewriter.MatchResult()
if data.shape is None:
return fail("The value 'data' shape is not statically known.", data)
return result.fail("The value 'data' shape is not statically known.", data)
if updates.shape is None:
return fail("The value 'updates' shape is not statically known.", updates)
return result.fail("The value 'updates' shape is not statically known.", updates)
if data.shape != updates.shape:
return fail("The shape of 'data' and 'updates' are different.", data, updates)
return result.fail("The shape of 'data' and 'updates' are different.", [data, updates])

# 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
if indices.const_value is None:
return fail("The value 'indices' is not statically known.", indices)
return result.fail("The value 'indices' is not statically known.", indices)
expected_indices = [[i] for i in range(data.shape[0])]
actual_indices = indices.const_value.numpy().tolist()
if actual_indices != expected_indices:
return fail("The 'indices' is not referring to the whole data.", indices)
return result.fail("The 'indices' is not referring to the whole data.", indices)

return True

Expand Down
Loading