Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions onnxscript/rewriter/collapse_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@

def _potential_redundant_slice(op, data, starts, ends, axes, steps):
"""To identify a slice op"""
return op.Slice(data, starts, ends, axes, steps)
return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"])


def _same_shape(op, data: ir.Value, slice_output: ir.Value, **_):
"""Check if the shape of the slice output is the same as the data."""
if data.shape is None or slice_output.shape is None:
return False

Check warning on line 82 in onnxscript/rewriter/collapse_slices.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/collapse_slices.py#L82

Added line #L82 was not covered by tests
return data.shape == slice_output.shape


# Register the rewrite rules
Expand All @@ -83,5 +90,13 @@
_check_if_redundant_slice,
)

# NOTE: The order of the rules is important. Larger pattern should be checked first.
rules = RewriteRuleSet([remove_redundant_slice])
remove_redundant_slice2 = RewriteRule(
_potential_redundant_slice,
_identity_to_itself,
_same_shape,
)

# NOTE: The second rule subsumes the first one. So, we may be able to remove the first one,
# provided shape-inference is run before the rewriter and computes the shape of the slice output.

rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2])
Loading