From 5baf374fc0fc4904cd1241948549980c5d9ead14 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 29 Sep 2025 17:57:14 -0700 Subject: [PATCH 1/3] Extract util for checking 0D or 1D value Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_ir_utils.py | 19 +++++++++++++------ .../rules/fusion/_rotary_embedding.py | 18 ++++-------------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 6af84dd1d8..96868af636 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -78,23 +78,30 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: return None -def get_singleton_value(val: ir.Value | None, rank: int | None = None): +def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = None): """Returns element of a single element tensor constant value, and None otherwise. - If rank is specified, it checks that the value has the given rank. + If an int rank is specified, it checks that the value has the given rank. + If the rank is a sequence of ints, it checks that the value has one of the given ranks. + + Thus, `rank=0` checks for a scalar, `rank=1` checks for a 1D tensor, and + `rank=(0,1)` checks for either a scalar or a 1D tensor. """ np_val = get_numpy_value(val) if np_val is not None and np_val.size == 1: - if rank is None or (np_val.ndim == rank): - return np_val.item() + value = np_val.item() + if (rank is None) or (isinstance(rank, int) and (np_val.ndim == rank)): + return value + if isinstance(rank, Sequence) and (np_val.ndim in rank): + return value return None def is_singleton_value( - val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None + val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None, rank: int | Sequence[int] | None = None ) -> bool: """Returns True if the value is a single element tensor with given value, and False otherwise.""" - scalar = get_singleton_value(val) + scalar = get_singleton_value(val, rank=rank) if scalar is None: return False if callable(expected): diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py index 524b6f4806..6e461a5ac0 100644 --- a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py @@ -43,20 +43,10 @@ def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2): def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() - def is_one(val): - """Check if val is a 0/1 dimensional tensor with a single element equal to 1.""" - np_val = _ir_utils.get_numpy_value(val) - return ( - np_val is not None - and np_val.size == 1 - and np_val.ndim <= 1 - and np_val.item() == 1 - ) - - if not is_one(one1): - return check_result.fail("Unsqueeze axes is not [1]", one1) - if not is_one(one2): - return check_result.fail("Unsqueeze axes is not [1]", one2) + if not _ir_utils.is_singleton_value(one1, 1, rank=(0, 1)): + return check_result.fail("Unsqueeze axes is not [1] or 1", one1) + if not _ir_utils.is_singleton_value(one2, 1, rank=(0, 1)): + return check_result.fail("Unsqueeze axes is not [1] or 1", one2) # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: From 53869c7efc9c99d6fd1a787d246b726c3f8988a7 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 29 Sep 2025 18:01:20 -0700 Subject: [PATCH 2/3] Run lint Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_ir_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 96868af636..91c3308bc2 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -98,7 +98,11 @@ def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = def is_singleton_value( - val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None, rank: int | Sequence[int] | None = None + val: ir.Value | None, + expected: float | int | Callable, + *, + rtol: float | None = None, + rank: int | Sequence[int] | None = None, ) -> bool: """Returns True if the value is a single element tensor with given value, and False otherwise.""" scalar = get_singleton_value(val, rank=rank) From 051b30bdbe69f0298b4b0a730a02092ce6063358 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 30 Sep 2025 09:06:40 -0700 Subject: [PATCH 3/3] Simplify pattern Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/rules/fusion/_rotary_embedding.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py index 6e461a5ac0..b659afdbc0 100644 --- a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py @@ -43,10 +43,10 @@ def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2): def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() - if not _ir_utils.is_singleton_value(one1, 1, rank=(0, 1)): - return check_result.fail("Unsqueeze axes is not [1] or 1", one1) - if not _ir_utils.is_singleton_value(one2, 1, rank=(0, 1)): - return check_result.fail("Unsqueeze axes is not [1] or 1", one2) + if not _ir_utils.is_singleton_value(one1, 1): + return check_result.fail("Unsqueeze axes is not [1]", one1) + if not _ir_utils.is_singleton_value(one2, 1): + return check_result.fail("Unsqueeze axes is not [1]", one2) # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: