diff --git a/onnxscript/rewriter/_basics.py b/onnxscript/rewriter/_basics.py
index 6529bea627..8ea8a24bb3 100644
--- a/onnxscript/rewriter/_basics.py
+++ b/onnxscript/rewriter/_basics.py
@@ -16,6 +16,42 @@
 import onnxscript.rewriter._rewrite_rule as _rewrite_rule
 
 
+class MatchFailureInfo:
+    """Encapsulates information about a pattern match failure."""
+
+    def __init__(
+        self,
+        reason: str = "",
+        *failure_source: ir.Node | ir.Value,
+    ):
+        self.reason = reason
+        self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source
+        assert all(isinstance(item, (ir.Node, ir.Value)) for item in failure_source), (
+            f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}"
+        )
+
+    def __str__(self):
+        return f"MatchFailureInfo(reason={self.reason!r}, failure_sources={self.failure_sources!r})"
+
+
+class MatchFailureError(MatchFailureInfo, Exception):
+    """Exception raised when a pattern match fails.
+
+    This makes it easier to handle match failures in a compositional way,
+    for example, during the condition-checking phase of a pattern match.
+    It allows us to define utility functions without having to check for
+    and propagate match failures explicitly.
+    """
+
+    def __init__(
+        self,
+        reason: str = "",
+        *failure_source: ir.Node | ir.Value,
+    ):
+        MatchFailureInfo.__init__(self, reason, *failure_source)
+        Exception.__init__(self, reason)
+
+
 class MatchResult:
     """The state object used by the pattern-matching algorithm.
 
diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py
index c8051f8199..b3f298a0f3 100644
--- a/onnxscript/rewriter/_fusion_utils.py
+++ b/onnxscript/rewriter/_fusion_utils.py
@@ -7,6 +7,7 @@
 import onnxscript.ir as ir
 import onnxscript.ir.passes.common as common_passes
 from onnxscript.rewriter import pattern
+from onnxscript.rewriter._basics import MatchFailureError
 
 Dim = Union[int, ir.SymbolicDim]
 
@@ -24,6 +25,24 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
     return True
 
 
+def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]):
+    if val.shape is None:
+        raise MatchFailureError(f"The shape of {val} is unknown.", val)
+    if val.shape.rank() != len(shape):
+        raise MatchFailureError(
+            f"The rank of {val} ({val.shape.rank()}) does not match the expected rank {len(shape)}.",
+            val,
+        )
+    for i, (actual, expected) in enumerate(zip(val.shape, shape)):
+        if expected not in bindings:
+            bindings[expected] = actual  # type: ignore[assignment]
+        elif actual != bindings[expected]:
+            raise MatchFailureError(
+                f"Dimension {i} of {val} ({actual}) does not have the expected size ({bindings[expected]}).",
+                val,
+            )
+
+
 def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
     """
     Apply the given fusion rules to the model and return the number of fusions applied.
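Note: `check_shape` and `MatchFailureError` are designed to compose inside condition-checking functions: a failed check raises, and the rule driver (see the `_rewrite_rule.py` hunk below) converts the exception into a failed `MatchResult`. A minimal sketch of the intended usage, with hypothetical rule inputs that are not part of this diff:

```python
from onnxscript.rewriter import _fusion_utils

# Hypothetical condition function for a rewrite rule; the tensor names are
# illustrative. Each check_shape call either binds the symbolic dims
# ("B", "H", ...) or raises MatchFailureError, so no explicit
# failure-propagation logic is needed here.
def _check(context, query, value, **_):
    bindings = {}
    _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
    # Reuses the bindings from the first call; a dim mismatch raises.
    _fusion_utils.check_shape(bindings, value, ["B", "H", "S", "Dv"])
    return True
```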
diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py
index 90ab74d062..33f2aee8a5 100644
--- a/onnxscript/rewriter/_rewrite_rule.py
+++ b/onnxscript/rewriter/_rewrite_rule.py
@@ -174,7 +174,11 @@ def try_rewrite(
             if var.name is not None:
                 if var.name not in match.bindings:
                     match.bind(var.name, None)
-        check_match_result = self._condition_function(context, **match.bindings)
+        try:
+            check_match_result = self._condition_function(context, **match.bindings)
+        except _basics.MatchFailureError as e:
+            check_match_result = _basics.MatchResult()
+            check_match_result.fail(e.reason, list(e.failure_sources))
         if not check_match_result:
             # If check function was provided, but it failed, return the reason for failure to the tracer.
             if isinstance(check_match_result, _basics.MatchResult):
diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py
index a918616161..18d79d24d0 100644
--- a/onnxscript/rewriter/ort_fusions/gqa_test.py
+++ b/onnxscript/rewriter/ort_fusions/gqa_test.py
@@ -44,6 +44,7 @@ def __init__(self, *args, **kwargs):
             "num_heads must be divisible by kv_num_heads"
         )
         self.num_groups = self.num_heads // self.kv_num_heads
+        self.total_seqlen = self.seqlen + self.past_seqlen
 
         # Abbreviations
         B = self.batchsize
@@ -311,12 +312,24 @@ def test_fusion(self):
             onnx.TensorProto.FLOAT,
             ["B", self.seqlen, self.kv_num_heads, self.head_size],
         )
+        key_transposed_value_info = onnx.helper.make_tensor_value_info(
+            "key_transposed",
+            onnx.TensorProto.FLOAT,
+            ["B", self.num_heads, self.head_size, self.total_seqlen],
+        )
+        value_BHSDh_value_info = onnx.helper.make_tensor_value_info(
+            "value_BHSDh",
+            onnx.TensorProto.FLOAT,
+            ["B", self.num_heads, self.total_seqlen, self.head_size],
+        )
         source_model.graph.value_info.extend(
             [
                 query_BHSDh_rope_value_info,
                 key_BHkvSDh_rope_value_info,
                 query_BSHDh_value_info,
                 key_BSHkvDh_value_info,
+                key_transposed_value_info,
+                value_BHSDh_value_info,
             ]
         )
diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py
index fa827e79aa..1ca4c3b1ff 100644
--- a/onnxscript/rewriter/ort_fusions/sdpa.py
+++ b/onnxscript/rewriter/ort_fusions/sdpa.py
@@ -3,33 +3,18 @@
 from __future__ import annotations
 
 import math
+from typing import Union
+
+import onnx_ir as ir
 
 from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
+from onnxscript.rewriter._basics import MatchFailureError
+
+Dim = Union[int, ir.SymbolicDim]
 
 
 class SDPA(pattern.RewriteRuleClassBase):
-    def __init__(
-        self,
-        name: str,
-        *,
-        use_mask: bool,
-        pre_scale: bool,
-        pre_scale_q: bool,
-        use_mul: bool,
-        has_3d_query: bool,
-    ):
-        super().__init__(name=name)
-        self._use_mask = use_mask
-        self._pre_scale = pre_scale
-        # There are some patterns where only the query is scaled before the dot product
-        # and essentially (query * qk_scale) * key is equivalent to (query * key) * qk_scale
-        # TODO: Capture patterns where only the key is scaled before the dot product
-        self._pre_scale_q = pre_scale_q
-        self._use_mul = use_mul
-        # Capture patterns where the query is reshaped from 3D to 4D
-        # after scaling has been applied to query.
-        self._has_3d_query = has_3d_query
-        self._scale: float | None = None
+    _scale: float | None
 
     def pattern(
         self,
         op,
         query,
         key_transposed,
         value,
         mask,
         query_scale,
         key_scale,
         qk_scale,
-        # Shape used for reshaping the query in patterns where query is reshaped
-        # from 3D to 4D and scaling is applied before the reshaping.
-        query_reshape,
     ):
-        if self._pre_scale:
-            # Some implementations scale the query and key before computing the dot product
-            if self._use_mul:
-                if self._pre_scale_q:
-                    query = op.Mul(query, qk_scale)
-                else:
-                    query = op.Mul(query, query_scale)
-                    key_transposed = op.Mul(key_transposed, key_scale)
-            else:
-                if self._pre_scale_q:
-                    query = op.Div(query, qk_scale)
-                else:
-                    query = op.Div(query, query_scale)
-                    key_transposed = op.Div(key_transposed, key_scale)
-
-            # There might be patterns where the reshape and transpose are done
-            # after the pre-scaling. If the inputs are 3D, we need to reshape them to 4D
-            # and apply the appropriate transposes to query.
-            if self._has_3d_query and self._pre_scale_q:
-                # Reshape and transpose 3D input of shape (B, S, D)
-                # to 4D input of shape (B, N, S, H)
-                queryBNSH = op.Reshape(query, query_reshape)
-                query = op.Transpose(queryBNSH, perm=[0, 2, 1, 3])
+        # Some implementations scale the query and key before computing the dot product
+        query = pattern.OrValue(
+            [
+                op.Mul(query, query_scale),
+                op.Div(query, query_scale),
+                query,
+            ],
+            tag_var="query_scaling",
+            tag_values=["Mul", "Div", "None"],
+        )
+        key_transposed = pattern.OrValue(
+            [
+                op.Mul(key_transposed, key_scale),
+                op.Div(key_transposed, key_scale),
+                key_transposed,
+            ],
+            tag_var="key_scaling",
+            tag_values=["Mul", "Div", "None"],
+        )
+
         attn_score = op.MatMul(query, key_transposed)
-        if not self._pre_scale:
-            # Some implementations scale the dot product.
-            if self._use_mul:
-                attn_score = op.Mul(attn_score, qk_scale)
-            else:
-                attn_score = op.Div(attn_score, qk_scale)
-        if self._use_mask:
-            # Some implementations add a mask to the dot product.
-            attn_score = op.Add(attn_score, mask)
+
+        # Some implementations scale the dot product.
+        attn_score = pattern.OrValue(
+            [
+                op.Mul(attn_score, qk_scale),
+                op.Div(attn_score, qk_scale),
+                attn_score,
+            ],
+            tag_var="qk_scaling",
+            tag_values=["Mul", "Div", "None"],
+        )
+
+        # Some implementations add a mask to the dot product.
+        masked_attn_score = op.Add(attn_score, mask)
+        attn_score = pattern.OrValue(
+            [masked_attn_score, attn_score], tag_var="has_mask", tag_values=[True, False]
+        )
+
         attn_weight = op.Softmax(attn_score, axis=-1)
         attn_output = op.MatMul(attn_weight, value)
         return attn_output
 
     def check(
-        self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale, **_
+        self,
+        context,
+        query: ir.Value | None,
+        key_transposed: ir.Value | None,
+        value: ir.Value | None,
+        mask: ir.Value | None,
+        **match_bindings,
     ):
         check_result = pattern.MatchResult()
-        # Check that the scaling factors match what SDPA implements:
-
-        # We need to know the hidden size to check the scaling factors.
-        if query is None or query.shape is None or len(query.shape) < 2:
-            return check_result.fail(
-                "Query shape is not known or has less than 2 dimensions.", query
-            )
-        hidden_size = query.shape[-1]
-        if not isinstance(hidden_size, int):
-            return check_result.fail("Hidden size is not an integer.")
-
-        expected_scaling_factor = math.sqrt(hidden_size)
-        if self._use_mul:
-            expected_scaling_factor = 1.0 / expected_scaling_factor
-
-        if self._pre_scale and not self._pre_scale_q:
-            # Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
-            # If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
-            sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
-            # Calculate the scaling factor for query
-            if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None:
-                return check_result.fail(
-                    "Query scale is not a scalar.",
-                    query_scale,
-                )
-            # Ensure the scaling factor for key is the same as for query
-            if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None:
-                return check_result.fail(
-                    "Key scale is not a scalar.",
-                    key_scale,
-                )
-            if not math.isclose(query_scale_value, key_scale_value, rel_tol=1e-3):
-                return check_result.fail(
-                    "Query and key scales are not equal.",
-                    query_scale,
-                )
-            if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
-                self._scale = query_scale_value * query_scale_value
-            else:
-                # Pass no scaling factor to SDPA, SDPA will use the default scaling factor
-                self._scale = None
-        else:
-            # Check if qk_scale is a scalar == expected_scaling_factor
-            # If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
-            if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None:
-                return check_result.fail(
-                    "QK scale is not a scalar.",
-                    qk_scale,
-                )
-            if not math.isclose(qk_scale_value, expected_scaling_factor, rel_tol=1e-3):
-                self._scale = qk_scale_value
+
+        bindings: dict[str, Dim] = {}
+
+        # Check that query/key/value have the expected shapes:
+        # They should all have the same batch size (B) and number of heads (H). Conceptually, the
+        # number of heads differs for Q and K/V, but certain op implementations require them to be
+        # the same, which is usually achieved by tiling/expanding the K/V num-heads to match Q.
+        # Query and Key should have the same head size (Dh), while Value can have a different head size (Dv).
+        # Key and Value should have the same sequence length (Skv), while Query can have a different sequence length (S).
+        _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
+        _fusion_utils.check_shape(bindings, key_transposed, ["B", "H", "Dh", "Skv"])
+        _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])
+
+        def get_scale_value(tag_name: str, scale_name: str) -> float:
+            scaling_type = match_bindings.get(tag_name, "None")
+            if scaling_type == "None":
+                return 1.0
             else:
-                # Pass no scaling factor to SDPA, SDPA will use the default scaling factor
-                self._scale = None
+                scale = match_bindings.get(scale_name)
+                scale_value = _ir_utils.get_singleton_value(scale)
+                if scale_value is None:
+                    raise MatchFailureError(f"{scale_name} is not a scalar.", scale)
+                if scaling_type == "Mul":
+                    return scale_value
+                else:
+                    assert scaling_type == "Div", f"Unexpected {scale_name} scaling operation"
+                    return 1.0 / scale_value
+
+        query_scale_value = get_scale_value("query_scaling", "query_scale")
+        key_scale_value = get_scale_value("key_scaling", "key_scale")
+        qk_scale_value = get_scale_value("qk_scaling", "qk_scale")
+
+        self._scale = query_scale_value * key_scale_value * qk_scale_value
+
+        # If the scaling factor is the default one, we can skip passing it to SDPA.
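Aside: folding the three factors into a single `_scale` is valid because scaling commutes with MatMul: (Q·s1)(Kᵀ·s2)·s3 = (QKᵀ)·(s1·s2·s3), with `Div` factors entering as reciprocals via `get_scale_value`. A quick numerical sanity check of that identity (illustration only, using numpy; not part of this diff):

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 3, 8)).astype(np.float32)   # query (B, H, S, Dh)
kt = rng.standard_normal((2, 4, 8, 5)).astype(np.float32)  # key, transposed (B, H, Dh, Skv)
s1, s2, s3 = 0.5, 0.25, 2.0  # query, key, and post-MatMul scales

# Pre-scaling the operands is equivalent to scaling the product once,
# which is why check() can fold all three factors into one SDPA scale.
pre = ((q * s1) @ (kt * s2)) * s3
post = (q @ kt) * (s1 * s2 * s3)
np.testing.assert_allclose(pre, post, rtol=1e-5)
```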
+
+        head_size = bindings["Dh"]
+        if not isinstance(head_size, int):
+            return check_result
 
-        # check ranks/shapes
+        default_scaling_factor = 1.0 / math.sqrt(head_size)
+
+        if math.isclose(self._scale, default_scaling_factor, rel_tol=1e-5, abs_tol=1e-8):
+            # Pass no scaling factor to SDPA, SDPA will use the default scaling factor
+            self._scale = None
 
         return check_result
 
     def rewrite(
         self,
         op,
-        query,
-        key_transposed,
-        value,
-        mask,
-        query_scale,
-        key_scale,
-        qk_scale,
-        query_reshape=None,
+        query: ir.Value | None,
+        key_transposed: ir.Value | None,
+        value: ir.Value | None,
+        mask: ir.Value | None,
         **_,
     ):
-        if self._pre_scale and self._pre_scale_q:
-            if self._use_mul:
-                query_mul = op.Mul(query, qk_scale)
-            else:
-                query_mul = op.Div(query, qk_scale)
-            # Reshape and transpose 3D input of shape (B, S, D)
-            # to 4D input of shape (B, N, S, H)
-            if self._has_3d_query:
-                queryBNSH = op.Reshape(query_mul, query_reshape)
-                query = op.Transpose(queryBNSH, perm=[0, 2, 1, 3])
-            else:
-                query = query_mul
-
         sdpa_args = [query, key_transposed, value]
-        if self._use_mask:
+        if mask is not None:
             sdpa_args.append(mask)
+        # If the scale is None, SDPA will use the default scaling factor, which is 1/sqrt(head_size).
         return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion")
 
 
-parameter_combinations = [
-    {
-        "name": f"sdpa_{'masked_' if use_mask else 'unmasked_'}{'pre_' if pre_scale else 'post_'}{'only_q_' if pre_scale_q else ''}{'mul' if use_mul else 'div'}{'_3d_query' if has_3d_query else ''}",
-        "use_mask": use_mask,
-        "pre_scale": pre_scale,
-        "pre_scale_q": pre_scale_q,
-        "use_mul": use_mul,
-        "has_3d_query": has_3d_query,
-    }
-    for use_mask in [False, True]
-    for pre_scale in [False, True]
-    for pre_scale_q in [False, True]
-    for use_mul in [False, True]
-    for has_3d_query in [False, True]
-]
-
 # Dynamically create the rules
-sdpa_rules = pattern.RewriteRuleSet(
-    [
-        SDPA.rule(
-            params["name"],
-            use_mask=params["use_mask"],
-            pre_scale=params["pre_scale"],
-            pre_scale_q=params["pre_scale_q"],
-            use_mul=params["use_mul"],
-            has_3d_query=params["has_3d_query"],
-        )
-        for params in parameter_combinations
-    ]
-)
+sdpa_rules = pattern.RewriteRuleSet([SDPA.rule()])
 
 fuse_sdpa = _fusion_utils.apply_fusion_rules(sdpa_rules)
diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py
index 74c718147f..88eec4fe5d 100644
--- a/onnxscript/rewriter/ort_fusions/sdpa_test.py
+++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py
@@ -26,7 +26,12 @@
 MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
 SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
 SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
-CUSTOM_SCALE_FACTOR = 2.0
+# Custom scale factors for testing
+CUSTOM_SCALE_FACTOR = 1.0 / math.sqrt(80)
+CUSTOM_MUL_SCALE_FACTOR = CUSTOM_SCALE_FACTOR
+CUSTOM_DIV_SCALE_FACTOR = 1.0 / CUSTOM_SCALE_FACTOR
+SQRT_CUSTOM_MUL_SCALE_FACTOR = math.sqrt(CUSTOM_MUL_SCALE_FACTOR)
+SQRT_CUSTOM_DIV_SCALE_FACTOR = math.sqrt(CUSTOM_DIV_SCALE_FACTOR)
 
 
 @script()
@@ -78,7 +83,7 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
 @script()
 def _custom_scale_pre_div_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    divisor = op.Constant(value_float=SQRT_CUSTOM_DIV_SCALE_FACTOR)
     scaled_query = op.Div(query, divisor)
     scaled_key = op.Div(key_transposed, divisor)
     attn_score = op.MatMul(scaled_query, scaled_key)
@@ -90,7 +95,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value):
 @script()
 def _custom_scale_pre_mul_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    multiplier = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR)
     scaled_query = op.Mul(query, multiplier)
     scaled_key = op.Mul(key_transposed, multiplier)
     attn_score = op.MatMul(scaled_query, scaled_key)
@@ -102,8 +107,8 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
 @script()
 def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    multiplier_q = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
-    multiplier_k = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    multiplier_q = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR)
+    multiplier_k = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR)
     scaled_query = op.Mul(query, multiplier_q)
     scaled_key = op.Mul(key_transposed, multiplier_k)
     attn_score = op.MatMul(scaled_query, scaled_key)
@@ -115,7 +120,7 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
 @script()
 def _custom_scale_post_div_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    divisor = op.Constant(value_float=CUSTOM_DIV_SCALE_FACTOR)
     attn_score = op.MatMul(query, key_transposed)
     scaled_attn_score = op.Div(attn_score, divisor)
     attn_weight = op.Softmax(scaled_attn_score, axis=-1)
@@ -126,7 +131,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value):
 @script()
 def _custom_scale_post_mul_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR)
     attn_score = op.MatMul(query, key_transposed)
     scaled_attn_score = op.Mul(attn_score, multiplier)
     attn_weight = op.Softmax(scaled_attn_score, axis=-1)
@@ -187,7 +192,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
 @script()
 def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    divisor = op.Constant(value_float=SQRT_CUSTOM_DIV_SCALE_FACTOR)
     scaled_query = op.Div(query, divisor)
     scaled_key = op.Div(key_transposed, divisor)
     attn_score = op.MatMul(scaled_query, scaled_key)
@@ -200,7 +205,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
 @script()
 def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    multiplier = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR)
     scaled_query = op.Mul(query, multiplier)
     scaled_key = op.Mul(key_transposed, multiplier)
     attn_score = op.MatMul(scaled_query, scaled_key)
@@ -213,7 +218,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
 @script()
 def _custom_scale_post_div_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    divisor = op.Constant(value_float=CUSTOM_DIV_SCALE_FACTOR)
     attn_score = op.MatMul(query, key_transposed)
     scaled_attn_score = op.Div(attn_score, divisor)
     masked_attn_score = op.Add(scaled_attn_score, mask)
@@ -225,7 +230,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
 @script()
 def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
-    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR)
     attn_score = op.MatMul(query, key_transposed)
     scaled_attn_score = op.Mul(attn_score, multiplier)
     masked_attn_score = op.Add(scaled_attn_score, mask)
@@ -260,6 +265,34 @@ def get_ort_inputs(self):
         return self._ort_inputs
 
 
+class InvalidSDPATestCase:
+    def __init__(self, script_func):
+        self.script_func = script_func
+
+    def get_onnx_model(self):
+        if not hasattr(self, "_onnx_model"):
+            qk_type = FLOAT[B, N, S, H]
+            # We broadcast value in the batch dimension, which is not supported by SDPA fusion
+            v_type = FLOAT[1, N, S, H]
+            mask_type = FLOAT[B, N, S, S]
+            model_proto = self.script_func.to_model_proto(
+                input_types=[qk_type, qk_type, v_type, mask_type], output_types=[qk_type]
+            )
+            self._onnx_model = ir.serde.deserialize_model(model_proto)
+        return self._onnx_model
+
+    def get_ort_inputs(self):
+        if not hasattr(self, "_ort_inputs"):
+            inputs = {
+                "query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
+                "key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
+                "value": numpy.random.rand(1, N, S, H).astype(numpy.float32),
+                "mask": numpy.random.rand(B, N, S, S).astype(numpy.float32),
+            }
+            self._ort_inputs = inputs
+        return self._ort_inputs
+
+
 class TestSDPAFusion(unittest.TestCase):
     @parameterized.parameterized.expand(
         [
@@ -307,11 +340,7 @@ def test_sdpa_fusion(self, name, script_func):
         if "custom" in name:
             self.assertIsNotNone(sdpa_node.attributes.get("scale"))
             scale_factor = sdpa_node.attributes["scale"].value
-            self.assertIsNotNone(scale_factor)
-            if "pre" in name:
-                self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR)
-            elif "post" in name:
-                self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR)
+            self.assertAlmostEqual(scale_factor, CUSTOM_SCALE_FACTOR, delta=1e-8)
         else:
             # These tests are for the default scaling factors, no scale factor is passed to SDPA
             # pattern rewriting check functions should be sufficient to check if expected value
@@ -321,6 +350,13 @@ def test_sdpa_fusion(self, name, script_func):
         # new_outputs = ort_run("optimized", model, inputs)
         # assert_allclose(new_outputs, original_outputs)
 
+    def test_invalid_sdpa_fusion_value_batch_dim(self):
+        test_case = InvalidSDPATestCase(_masked_pre_mul_sdpa_script)
+        model = test_case.get_onnx_model()
+        onnxscript.optimizer.optimize(model)
+        count = fuse_sdpa(model)
+        self.assertEqual(count, 0)
+
 
 if __name__ == "__main__":
     unittest.main()
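For reference, the new negative test reports zero fusions because the broadcast batch dimension of `value` fails the shared "B" binding inside `check_shape`. A minimal repro sketch (it assumes the `ir.Value`/`ir.Shape` keyword constructors shown here behave as written; this is an illustration, not code from this diff):

```python
import onnx_ir as ir

from onnxscript.rewriter import _fusion_utils
from onnxscript.rewriter._basics import MatchFailureError

query = ir.Value(name="query", shape=ir.Shape(["B", "N", "S", "H"]))
value = ir.Value(name="value", shape=ir.Shape([1, "N", "S", "H"]))  # broadcast batch dim

bindings = {}
_fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])  # binds B, H, S, Dh
try:
    _fusion_utils.check_shape(bindings, value, ["B", "H", "S", "Dv"])
except MatchFailureError as e:
    # Dimension 0 is the literal 1, while "B" is already bound to the symbolic dim B.
    print(e.reason)
```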