diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py
index 6af84dd1d8..91c3308bc2 100644
--- a/onnxscript/rewriter/_ir_utils.py
+++ b/onnxscript/rewriter/_ir_utils.py
@@ -78,23 +78,34 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
     return None
 
 
-def get_singleton_value(val: ir.Value | None, rank: int | None = None):
+def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = None):
     """Returns element of a single element tensor constant value, and None otherwise.
 
-    If rank is specified, it checks that the value has the given rank.
+    If an int rank is specified, it checks that the value has the given rank.
+    If the rank is a sequence of ints, it checks that the value has one of the given ranks.
+
+    Thus, `rank=0` checks for a scalar, `rank=1` checks for a 1D tensor, and
+    `rank=(0,1)` checks for either a scalar or a 1D tensor.
     """
     np_val = get_numpy_value(val)
     if np_val is not None and np_val.size == 1:
-        if rank is None or (np_val.ndim == rank):
-            return np_val.item()
+        value = np_val.item()
+        if (rank is None) or (isinstance(rank, int) and (np_val.ndim == rank)):
+            return value
+        if isinstance(rank, Sequence) and (np_val.ndim in rank):
+            return value
     return None
 
 
 def is_singleton_value(
-    val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
+    val: ir.Value | None,
+    expected: float | int | Callable,
+    *,
+    rtol: float | None = None,
+    rank: int | Sequence[int] | None = None,
 ) -> bool:
     """Returns True if the value is a single element tensor with given value, and False otherwise."""
-    scalar = get_singleton_value(val)
+    scalar = get_singleton_value(val, rank=rank)
     if scalar is None:
         return False
     if callable(expected):
diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py
index 524b6f4806..b659afdbc0 100644
--- a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py
+++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py
@@ -43,19 +43,9 @@ def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2):
     def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult:  # type: ignore[name-defined]
         check_result = pattern.MatchResult()
 
-        def is_one(val):
-            """Check if val is a 0/1 dimensional tensor with a single element equal to 1."""
-            np_val = _ir_utils.get_numpy_value(val)
-            return (
-                np_val is not None
-                and np_val.size == 1
-                and np_val.ndim <= 1
-                and np_val.item() == 1
-            )
-
-        if not is_one(one1):
+        if not _ir_utils.is_singleton_value(one1, 1, rank=(0, 1)):
             return check_result.fail("Unsqueeze axes is not [1]", one1)
-        if not is_one(one2):
+        if not _ir_utils.is_singleton_value(one2, 1, rank=(0, 1)):
             return check_result.fail("Unsqueeze axes is not [1]", one2)
 
         # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)