Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions examples/pattern_matching_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Example demonstrating the new pattern matching functionality."""

import onnx.parser
from onnxscript import ir
from onnxscript.rewriter import pattern


def example_standalone_pattern_matching():
"""Example showing how to use PatternImpl for standalone pattern matching."""

print("=== Standalone Pattern Matching Example ===")

# Define a pattern that matches Identity nodes
def identity_pattern(op, x):
return op.Identity(x)

# Create a PatternImpl for standalone pattern matching (no replacement)
pattern_matcher = pattern.PatternImpl(identity_pattern, name="IdentityMatcher")

# Create a model with an Identity node
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z)
{
z = Identity(x)
}
"""
)
model = ir.serde.deserialize_model(model_proto)

# Find nodes to test pattern matching against
for node in model.graph:
print(f"Testing pattern against {node.op_type} node...")
match_result = pattern_matcher.match(model, model.graph, node)

if match_result is not None:
print(f" ✓ Pattern matched! Found {len(match_result.nodes)} nodes in match.")
print(f" Matched node: {match_result.nodes[0].op_type}")
else:
print(f" ✗ Pattern did not match {node.op_type} node.")


def example_class_based_pattern():
"""Example showing how to use PatternBase for class-based pattern definition."""

print("\n=== Class-Based Pattern Example ===")

class IdentityPatternClass(pattern.PatternBase):
"""A class-based pattern that matches Identity nodes."""

def pattern(self, op, x):
return op.Identity(x)

def check(self, context, x):
"""Custom condition - always succeeds for this example."""
print(f" Checking condition for input: {x}")
return pattern.MatchResult() # Always succeeds

# Create an instance of the pattern class
identity_pattern_class = IdentityPatternClass(name="ClassBasedIdentity")

# Create a PatternImpl from the class
pattern_impl = identity_pattern_class.create_pattern_impl()

print(f"Created pattern matcher: {pattern_impl.name}")

# Use it like any other PatternImpl
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z)
{
z = Identity(x)
}
"""
)
model = ir.serde.deserialize_model(model_proto)

for node in model.graph:
if node.op_type == "Identity":
print(f"Testing class-based pattern against {node.op_type} node...")
match_result = pattern_impl.match(model, model.graph, node)

if match_result is not None:
print(f" ✓ Class-based pattern matched!")
else:
print(f" ✗ Class-based pattern did not match.")


def example_rewrite_rule_still_works():
"""Example showing that existing RewriteRule functionality is preserved."""

print("\n=== Existing RewriteRule Still Works ===")

def identity_pattern(op, x):
return op.Identity(x)

def identity_replacement(op, x):
return op.Identity(x) # No-op replacement

# Create a RewriteRule (which now inherits from PatternImpl)
rule = pattern.RewriteRule(identity_pattern, identity_replacement, name="IdentityRule")

print(f"Created rewrite rule: {rule.name}")
print(f"Rule is also a PatternImpl: {isinstance(rule, pattern.PatternImpl)}")

# The rule can be used both for pattern matching and rewriting
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z)
{
z = Identity(x)
}
"""
)
model = ir.serde.deserialize_model(model_proto)

# Use it for just pattern matching (inherited from PatternImpl)
for node in model.graph:
if node.op_type == "Identity":
print(f"Using RewriteRule for pattern matching on {node.op_type}...")
match_result = rule.match(model, model.graph, node)

if match_result is not None:
print(f" ✓ RewriteRule matched as a pattern matcher!")

# Use it for rewriting (original functionality)
print(f"Using RewriteRule for rewriting...")
count = rule.apply_to_model(model)
print(f" Applied rule {count} times")


if __name__ == "__main__":
example_standalone_pattern_matching()
example_class_based_pattern()
example_rewrite_rule_still_works()
print("\n=== All Examples Completed ===")
Loading