Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,27 @@ def get_dim(value: ir.Value | None, dim: int) -> ir.SymbolicDim | int | None:
if dim < 0 or dim >= shape.rank():
return None
return shape[dim]


def same_shape(shape1: ir.Shape | None, shape2: ir.Shape | None) -> bool:
"""Check if two shapes are semantically the same."""
if shape1 is None or shape2 is None:
return False

# If any dim is unknown, the shapes are not the same
if shape1.has_unknown_dim() or shape2.has_unknown_dim():
return False

return shape1 == shape2


def same_dim(dim1: ir.SymbolicDim | int, dim2: ir.SymbolicDim | int) -> bool:
"""Check if two dimensions are semantically the same."""
if type(dim1) is not type(dim2):
return False
if isinstance(dim1, int) and isinstance(dim2, int):
return dim1 == dim2
assert isinstance(dim1, ir.SymbolicDim) and isinstance(dim2, ir.SymbolicDim)
if dim1.value is None or dim2.value is None:
return False
return dim1.value == dim2.value
10 changes: 3 additions & 7 deletions onnxscript/rewriter/rules/common/_collapse_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging

from onnxscript import ir
from onnxscript.rewriter._ir_utils import is_singleton_value
from onnxscript.rewriter import _ir_utils
from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,14 +82,10 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_
if data.shape is None or slice_output.shape is None:
return False

if not is_singleton_value(steps, 1):
if not _ir_utils.is_singleton_value(steps, 1):
return False

# If any dim is unknown, the shapes are not the same
if data.shape.has_unknown_dim() or slice_output.shape.has_unknown_dim():
return False

return data.shape == slice_output.shape
return _ir_utils.same_shape(data.shape, slice_output.shape)


# Register the rewrite rules
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/rewriter/rules/common/_redundant_scatter_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import onnx_ir as ir

import onnxscript.rewriter
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter import _ir_utils
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


Expand All @@ -41,7 +41,7 @@ def check(self, context, data, axis, transposed_data, **_):
# Check that updated-indices represent the full range of the first dimension of the transposed data.
# That is: check that the data.shape[axis] matches transposed_data.shape[0].
result = onnxscript.rewriter.MatchResult()
axis_value = ir_utils.get_singleton_value(axis)
axis_value = _ir_utils.get_singleton_value(axis)
if not isinstance(axis_value, int):
return result.fail("Axis value must be a constant integer.", axis)
shape: ir.Shape | None = data.shape
Expand All @@ -54,7 +54,7 @@ def check(self, context, data, axis, transposed_data, **_):
"Transposed data shape is not statically known.", transposed_data
)
actual_dim_value = transposed_data_shape[0]
if updated_dim_value != actual_dim_value:
if not _ir_utils.same_dim(updated_dim_value, actual_dim_value):
# The first dimension of the transposed data does not match the updated dimension,
# so we cannot apply this rule.
return result.fail(
Expand Down Expand Up @@ -87,7 +87,7 @@ def check(self, context, data, indices, updates, **_):
return result.fail("The value 'data' shape is not statically known.", data)
if updates.shape is None:
return result.fail("The value 'updates' shape is not statically known.", updates)
if data.shape != updates.shape:
if not _ir_utils.same_shape(data.shape, updates.shape):
return result.fail(
"The shape of 'data' and 'updates' are different.", [data, updates]
)
Expand Down
Loading