Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,34 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
return None


def get_singleton_value(val: ir.Value | None, rank: int | None = None):
def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = None):
"""Returns element of a single element tensor constant value, and None otherwise.

If rank is specified, it checks that the value has the given rank.
If an int rank is specified, it checks that the value has the given rank.
If the rank is a sequence of ints, it checks that the value has one of the given ranks.

Thus, `rank=0` checks for a scalar, `rank=1` checks for a 1D tensor, and
`rank=(0,1)` checks for either a scalar or a 1D tensor.
"""
np_val = get_numpy_value(val)
if np_val is not None and np_val.size == 1:
if rank is None or (np_val.ndim == rank):
return np_val.item()
value = np_val.item()
if (rank is None) or (isinstance(rank, int) and (np_val.ndim == rank)):
return value
if isinstance(rank, Sequence) and (np_val.ndim in rank):
return value
return None


def is_singleton_value(
val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
val: ir.Value | None,
expected: float | int | Callable,
*,
rtol: float | None = None,
rank: int | Sequence[int] | None = None,
) -> bool:
"""Returns True if the value is a single element tensor with given value, and False otherwise."""
scalar = get_singleton_value(val)
scalar = get_singleton_value(val, rank=rank)
if scalar is None:
return False
if callable(expected):
Expand Down
14 changes: 2 additions & 12 deletions onnxscript/rewriter/rules/fusion/_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,9 @@ def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2):
def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined]
check_result = pattern.MatchResult()

def is_one(val):
"""Check if val is a 0/1 dimensional tensor with a single element equal to 1."""
np_val = _ir_utils.get_numpy_value(val)
return (
np_val is not None
and np_val.size == 1
and np_val.ndim <= 1
and np_val.item() == 1
)

if not is_one(one1):
if not _ir_utils.is_singleton_value(one1, 1):
return check_result.fail("Unsqueeze axes is not [1]", one1)
if not is_one(one2):
if not _ir_utils.is_singleton_value(one2, 1):
return check_result.fail("Unsqueeze axes is not [1]", one2)

# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
Expand Down
Loading