From 88adde4d329b043772199d2651296c6857774bd5 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 13 Mar 2025 12:26:28 -0700 Subject: [PATCH 01/28] A couple of MHA extensions --- onnxscript/rewriter/ort_fusions/mha.py | 52 +++++++++++++++------ onnxscript/rewriter/ort_fusions/mha_test.py | 2 +- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index aa3d801a08..9db0118a05 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -45,8 +45,9 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) class MultiHeadAttention(pattern.RewriteRuleClassBase): - def __init__(self): - super().__init__("MHA") + def __init__(self, name, *, transpose_4d: bool): + super().__init__(name) + self._transpose_4d = transpose_4d def pattern( self, @@ -93,11 +94,25 @@ def pattern( # Transpose from (B, S, H, D/H) to (B, H, S, D/H) value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3]) + # This is workaround for examples where there is a duplication of Unsqueeze op + # to generate a 2D positions-ids from a 1D position-ids. This can be eliminated + # if we have CSE-optimization to eliminate the duplicate Unsqueeze ops. + # For now, same flag (transpose_4d) controls this variation. A different flag + # can be added if we see instances that mix the two. + if self._transpose_4d: + position_ids_q = op.Unsqueeze(position_ids, [0]) + position_ids_k = op.Unsqueeze(position_ids, [0]) + else: + + position_ids_q = position_ids + position_ids_k = position_ids + query_BHSDh_rope = op.RotaryEmbedding( - query_BHSDh, position_ids, cos, sin, _domain="com.microsoft" + query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft" ) + key_BHSDh_rope = op.RotaryEmbedding( - key_BHSDh, position_ids, cos, sin, _domain="com.microsoft" + key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft" ) # Concatenate past_key cache and current key, and transpose to enable @@ -105,13 +120,17 @@ def pattern( key_seq = op.Concat(past_key, key_BHSDh_rope, axis=-2) # Transpose last two axes of key_seq to compute dot-product via matmul. 
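For reference, the two branches selected by transpose_4d below describe two graph layouts that produce the same transposed key cache for the Q.K^T matmul. A standalone NumPy sketch (not part of the patch; shapes are illustrative) of why the reshape-to-3D route equals the direct 4D transpose:

import numpy as np

B, H, Skv, Dh = 2, 4, 6, 8
key_seq = np.random.rand(B, H, Skv, Dh).astype(np.float32)

# Path 1 (transpose_4d=True): transpose the last two axes directly in 4D.
key_4d = np.transpose(key_seq, (0, 1, 3, 2))        # (B, H, Dh, Skv)

# Path 2 (transpose_4d=False): flatten to 3D, transpose, reshape back.
key_3d = key_seq.reshape(B * H, Skv, Dh)            # (B*H, Skv, Dh)
key_3d_t = np.transpose(key_3d, (0, 2, 1))          # (B*H, Dh, Skv)
key_back = key_3d_t.reshape(B, H, Dh, Skv)          # (B, H, Dh, Skv)

assert np.array_equal(key_4d, key_back)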
- key_seq_BH_Skv_Dh = op.Reshape( - key_seq, _allow_other_inputs=True, _outputs=["key_seq_BH_Skv_Dh"] - ) - key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1]) - key_seq_B_H_Dh_Skv = op.Reshape( - key_seq_BH_Dh_Skv, _allow_other_inputs=True, _outputs=["key_seq_B_H_Dh_Skv"] - ) + if self._transpose_4d: + key_seq_B_H_Dh_Skv = op.Transpose(key_seq, perm=[0, 1, 3, 2]) + else: + # Transpose after converting to 3D + key_seq_BH_Skv_Dh = op.Reshape( + key_seq, _allow_other_inputs=True, _outputs=["key_seq_BH_Skv_Dh"] + ) + key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1]) + key_seq_B_H_Dh_Skv = op.Reshape( + key_seq_BH_Dh_Skv, _allow_other_inputs=True, _outputs=["key_seq_B_H_Dh_Skv"] + ) # Concatenate past_value cache and current value value_seq = op.Concat(past_value, value_BHSDh, axis=-2) @@ -198,11 +217,13 @@ def rewrite( # Switch to 3D RotaryEmbedding # TODO: forward other attributes + zero_1d = op.Constant(value_ints=[0]) + position_ids_2d = op.Unsqueeze(position_ids, zero_1d) query_BSD_rope = op.RotaryEmbedding( - query_BSD, position_ids, cos, sin, _domain="com.microsoft" + query_BSD, position_ids_2d, cos, sin, _domain="com.microsoft" ) key_BSD_rope = op.RotaryEmbedding( - key_BSD, position_ids, cos, sin, _domain="com.microsoft" + key_BSD, position_ids_2d, cos, sin, _domain="com.microsoft" ) return op.MultiHeadAttention( @@ -220,9 +241,10 @@ def rewrite( ) -_rule1 = MultiHeadAttention.rule() +_mha_4d_transpose = MultiHeadAttention.rule("MHA_3D_Transpose", transpose_4d=True) +_mha_3d_transpose = MultiHeadAttention.rule("MHA_3D_Transpose", transpose_4d=False) -mha_rules = pattern.RewriteRuleSet([_rule1]) +mha_rules = pattern.RewriteRuleSet([_mha_4d_transpose, _mha_3d_transpose]) def fuse_mha(model: ir.Model, *, debug: bool = False) -> int: diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py index eeefa187ca..41e473fba1 100644 --- a/onnxscript/rewriter/ort_fusions/mha_test.py +++ b/onnxscript/rewriter/ort_fusions/mha_test.py @@ -32,7 +32,7 @@ def test_smollm(self): # Fuse SDPA and MHA sdpa_count = xformers.fuse_sdpa(model) self.assertGreater(sdpa_count, 0) - mha_count = xformers.fuse_mha(model) + mha_count = xformers.fuse_mha(model, debug=True) self.assertGreater(mha_count, 0) if test_with_ort: From c4c1f71b0967312f62e595e6b778a29bea6fd21a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 13 Mar 2025 12:26:53 -0700 Subject: [PATCH 02/28] Run lint --- onnxscript/rewriter/ort_fusions/mha.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index 9db0118a05..95d831d750 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -103,14 +103,13 @@ def pattern( position_ids_q = op.Unsqueeze(position_ids, [0]) position_ids_k = op.Unsqueeze(position_ids, [0]) else: - position_ids_q = position_ids position_ids_k = position_ids query_BHSDh_rope = op.RotaryEmbedding( query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft" ) - + key_BHSDh_rope = op.RotaryEmbedding( key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft" ) From 0425174daa71f3fcb1fe0bed45e73d2a8c1662e9 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 13 Mar 2025 12:46:43 -0700 Subject: [PATCH 03/28] Minor fixes --- onnxscript/rewriter/ort_fusions/mha.py | 7 ++++--- onnxscript/rewriter/ort_fusions/mha_test.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git 
a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index 95d831d750..c3c8f50bf9 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -217,12 +217,13 @@ def rewrite( # Switch to 3D RotaryEmbedding # TODO: forward other attributes zero_1d = op.Constant(value_ints=[0]) - position_ids_2d = op.Unsqueeze(position_ids, zero_1d) + if self._transpose_4d: + position_ids = op.Unsqueeze(position_ids, zero_1d) query_BSD_rope = op.RotaryEmbedding( - query_BSD, position_ids_2d, cos, sin, _domain="com.microsoft" + query_BSD, position_ids, cos, sin, _domain="com.microsoft" ) key_BSD_rope = op.RotaryEmbedding( - key_BSD, position_ids_2d, cos, sin, _domain="com.microsoft" + key_BSD, position_ids, cos, sin, _domain="com.microsoft" ) return op.MultiHeadAttention( diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py index 41e473fba1..eeefa187ca 100644 --- a/onnxscript/rewriter/ort_fusions/mha_test.py +++ b/onnxscript/rewriter/ort_fusions/mha_test.py @@ -32,7 +32,7 @@ def test_smollm(self): # Fuse SDPA and MHA sdpa_count = xformers.fuse_sdpa(model) self.assertGreater(sdpa_count, 0) - mha_count = xformers.fuse_mha(model, debug=True) + mha_count = xformers.fuse_mha(model) self.assertGreater(mha_count, 0) if test_with_ort: From 4b1a68da30bb55ebff83fdd1964f03b506f791b0 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 13 Mar 2025 17:56:38 -0700 Subject: [PATCH 04/28] Update GQA --- onnxscript/rewriter/ort_fusions/gqa2.py | 239 ++++++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 onnxscript/rewriter/ort_fusions/gqa2.py diff --git a/onnxscript/rewriter/ort_fusions/gqa2.py b/onnxscript/rewriter/ort_fusions/gqa2.py new file mode 100644 index 0000000000..4749ade5f1 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa2.py @@ -0,0 +1,239 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Sequence, Union + +import onnxscript.ir as ir +from onnxscript.rewriter import _ir_utils, pattern + +""" +The MultiHeadAttention pattern: generate an instance + MHA (query, key, value, None, None, mask, past_key, past_value) +where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv). +The next two inputs bias and key_padding_mask are None in this pattern. The mask (attention_bias) +must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh). + +We use the following abbreviations for the dimensions: +B: Batch size +S: Sequence length +D: input embedding dimension +Dv: value hidden size (usually, Dv = D) +H: number of heads +Dh: head size or embedding dimension per head (usually, D = H * Dh) +Skv: key/value sequence length +St: total sequence length + +In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh). +The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh). 
+""" + +Dim = Union[int, ir.SymbolicDim] + + +def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: + if val.shape is None: + return False + if val.shape.rank() != len(shape): + return False + for actual, expected in zip(val.shape, shape): + if expected not in bindings: + bindings[expected] = actual # type: ignore[assignment] + elif actual != bindings[expected]: + return False + return True + + +class GroupQueryAttention(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("MHA") + + def pattern( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + mask, + past_key, + past_value, + position_ids, + cos, + sin, + ): + # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H) + + # Reshape from (B, S, D) to (B, S, H, D/H) + query_BSHDh = op.Reshape( + query_BSD, + _allow_other_inputs=True, + _allow_other_attributes=True, + _outputs=["query_BSHDh"], + ) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + # Reshape from (B, S, D) to (B, S, H, D/H) + key_BSHkvDh = op.Reshape( + key_BSDkv, + _allow_other_inputs=True, + _allow_other_attributes=True, + _outputs=["key_BSHkvDh"], + ) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + + # Reshape from (B, S, D) to (B, S, H, D/H) + value_BSHkvDh = op.Reshape( + value_BSDkv, + _allow_other_inputs=True, + _allow_other_attributes=True, + _outputs=["value_BSHkvDh"], + ) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + query_BHSDh_rope = op.RotaryEmbedding( + query_BHSDh, position_ids, cos, sin, _domain="com.microsoft" + ) + key_BHkvSDh_rope = op.RotaryEmbedding( + key_BHkvSDh, position_ids, cos, sin, _domain="com.microsoft" + ) + + # Concatenate past_key cache and current key, and transpose to enable + # dot-product attention computation. 
+ + key_seq_BHkvSDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + key_seq_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSDh, 2) + key_seq_BHkvGSDh = op.Expand(key_seq_BHkv1SDh, _allow_other_inputs=True) + key_seq_BHSkvDh = op.Reshape( + key_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["key_seq_BHSkvDh"]) + key_seq_BHDhSkv = op.Transpose( + key_seq_BHSkvDh, _allow_other_inputs=True, _outputs=["key_seq_BHDhSkv"] + ) + + # Concatenate past_value cache and current value + value_seq_BHkvSDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + value_seq_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSDh, 2) + value_seq_BHkvGSDh = op.Expand(value_seq_BHkv1SDh, _allow_other_inputs=True) + key_seq_BHSkvDh = op.Reshape( + value_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["key_seq_BHSkvDh"]) + + attention_BHSDh = op.SDPA( + query_BHSDh_rope, + key_seq_BHDhSkv, + key_seq_BHSkvDh, + mask, + _domain="ai.onnxruntime.fusion", + ) + + # Transpose attention back to (B, S, H, D/H) + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_BSD = op.Reshape( + attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"] + ) + return attention_BSD, key_seq_BHkvSDh, value_seq_BHkvSDh + + def check( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + mask, + past_key, + past_value, + # query_BSHDh, + # key_BSHkvDh, + # value_BSHkvDh, + **_, + ): + # bindings: dict[str, Dim] = {} + + # def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + # return not _check_shape(bindings, val, dims) + + # if no_match(query_BSD, ["B", "S", "D"]): + # return False + # if no_match(key_BSDkv, ["B", "Skv", "D"]): + # return False + # if no_match(value_BSDkv, ["B", "Skv", "D"]): + # return False + + # if no_match(past_key, ["B", "H", "Spast", "Dh"]): + # return False + # if no_match(past_value, ["B", "H", "Spast", "Dv"]): + # return False + # if no_match(query_BSHDh, ["B", "S", "H", "Dh"]): + # return False + # if no_match(key_BSHkvDh, ["B", "S", "H", "Dh"]): + # return False + # if no_match(value_BSHkvDh, ["B", "S", "H", "Dh"]): + # return False + + # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St) + # But this also, unforunately, depends on ORT version. 
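As a reminder of how the shape validation above is meant to work once re-enabled, here is a small standalone restatement of the bindings-dict idea used by _check_shape (plain Python, illustrative values only; the dimensions match the test configuration used later, 20 heads and 10 KV heads of size 16): a named dimension binds on first use and every later occurrence must agree.

def check_shape(bindings: dict, shape: tuple, names: tuple) -> bool:
    if len(shape) != len(names):
        return False
    for actual, name in zip(shape, names):
        if name not in bindings:
            bindings[name] = actual          # first occurrence binds the name
        elif bindings[name] != actual:       # later occurrences must agree
            return False
    return True

bindings = {}
assert check_shape(bindings, (2, 8, 320), ("B", "S", "D"))      # binds B=2, S=8, D=320
assert check_shape(bindings, (2, 8, 160), ("B", "S", "Dkv"))    # B, S agree; binds Dkv=160
assert not check_shape(bindings, (4, 8, 320), ("B", "S", "D"))  # B mismatch: 4 != 2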
+ + # TODO: verify Reshapes: + # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: + # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: + # or check Reshape's shape-input value + + return True + + def rewrite( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + mask, + past_key, + past_value, + # key_BSHkvDh, + # position_ids, + # cos, + # sin, + **_, + ): + # num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) + # if not isinstance(num_heads, int): + # return None + + # # Switch to 3D RotaryEmbedding + # # TODO: forward other attributes + # query_BSD_rope = op.RotaryEmbedding( + # query_BSD, position_ids, cos, sin, _domain="com.microsoft" + # ) + # key_BSD_rope = op.RotaryEmbedding( + # key_BSDkv, position_ids, cos, sin, _domain="com.microsoft" + # ) + + return op.DummyGQA( + query_BSD, + key_BSDkv, + value_BSDkv, + None, # bias + None, # key padding mask + mask, # attention mask/bias + past_key, + past_value, + # num_heads=num_heads, + _domain="com.microsoft", + _outputs=3, + ) + + +_rule1 = GroupQueryAttention.rule() + +# _rule1 = GroupQueryAttention.rule("GQA", use_2d_matmul=False) + +gqa_rules = pattern.RewriteRuleSet([_rule1]) + + +def fuse_gqa(model: ir.Model) -> int: + count = gqa_rules.apply_to_model(model) + print(f"GQA count: {count}") + # remove_unused_nodes(model) + return count From 4dbb44c327cb72389df12a076aae0d8962651dca Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 13 Mar 2025 18:15:26 -0700 Subject: [PATCH 05/28] Minor fixes --- onnxscript/rewriter/ort_fusions/gqa2.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa2.py b/onnxscript/rewriter/ort_fusions/gqa2.py index 4749ade5f1..23cc1dcf03 100644 --- a/onnxscript/rewriter/ort_fusions/gqa2.py +++ b/onnxscript/rewriter/ort_fusions/gqa2.py @@ -93,11 +93,14 @@ def pattern( # Transpose from (B, S, H, D/H) to (B, H, S, D/H) value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + position_ids_q = op.Unsqueeze(position_ids, [0]) + position_ids_k = op.Unsqueeze(position_ids, [0]) + query_BHSDh_rope = op.RotaryEmbedding( - query_BHSDh, position_ids, cos, sin, _domain="com.microsoft" + query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft" ) key_BHkvSDh_rope = op.RotaryEmbedding( - key_BHkvSDh, position_ids, cos, sin, _domain="com.microsoft" + key_BHkvSDh, position_ids_k, cos, sin, _domain="com.microsoft" ) # Concatenate past_key cache and current key, and transpose to enable @@ -116,13 +119,13 @@ def pattern( value_seq_BHkvSDh = op.Concat(past_value, value_BHkvSDh, axis=-2) value_seq_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSDh, 2) value_seq_BHkvGSDh = op.Expand(value_seq_BHkv1SDh, _allow_other_inputs=True) - key_seq_BHSkvDh = op.Reshape( - value_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["key_seq_BHSkvDh"]) + value_seq_BHSkvDh = op.Reshape( + value_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["value_seq_BHSkvDh"]) attention_BHSDh = op.SDPA( query_BHSDh_rope, key_seq_BHDhSkv, - key_seq_BHSkvDh, + value_seq_BHSkvDh, mask, _domain="ai.onnxruntime.fusion", ) From 4d73c6e588151002da901529f394540e9dc307cb Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Mar 2025 17:20:09 -0700 Subject: [PATCH 06/28] Switch to new GQA --- onnxscript/rewriter/ort_fusions/gqa.py | 271 ++++++++++++++++-------- onnxscript/rewriter/ort_fusions/gqa2.py | 242 --------------------- 2 files changed, 177 insertions(+), 336 deletions(-) delete mode 100644 onnxscript/rewriter/ort_fusions/gqa2.py diff --git 
a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 4bad28c789..827fd739f9 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -2,149 +2,232 @@ # Licensed under the MIT License. from __future__ import annotations +from typing import Sequence, Union + import onnxscript.ir as ir -from onnxscript.optimizer import remove_unused_nodes from onnxscript.rewriter import pattern +""" +GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different +for query and key/value. + +We use the following abbreviations for the dimensions: +B: Batch size +S: Sequence length (for current query/key/value) +D: input embedding dimension +Dkv: key/value hidden size +H: number of heads (must be an integral multiple of Hkv) +Hkv: number of heads for key/value +Dh: head size or embedding dimension per head (usually, D = H * Dh) +Skv: key/value sequence length +St: total sequence length + +In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh). +The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh). +""" + +Dim = Union[int, ir.SymbolicDim] + + +def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: + if val.shape is None: + return False + if val.shape.rank() != len(shape): + return False + for actual, expected in zip(val.shape, shape): + if expected not in bindings: + bindings[expected] = actual # type: ignore[assignment] + elif actual != bindings[expected]: + return False + return True + class GroupQueryAttention(pattern.RewriteRuleClassBase): - def __init__(self, name: str, *, use_2d_matmul: bool): - super().__init__(name, remove_nodes=False) - self._use_2d_matmul = use_2d_matmul - - def _compute_packed_QKV(self, op, input, weight): - if self._use_2d_matmul: - # Convert batched input of shape (B, S, D) to 2D input (B*S, D) - input = op.Reshape(input, _allow_other_inputs=True) - projected = op.MatMul(input, weight) - if self._use_2d_matmul: - # Convert 2D output back to batched output of shape (B, S, D) - projected = op.Reshape(projected, _allow_other_inputs=True) - # Split combined QKV into Q, K, and V - query_3d = op.Slice(projected, _allow_other_inputs=True) - key_3d = op.Slice(projected, _allow_other_inputs=True) - value_3d = op.Slice(projected, _allow_other_inputs=True) - # Reshape from (B, S, D) to (B, S, H, D/H) - query_4d = op.Reshape( - query_3d, + def __init__(self): + super().__init__("GQA") + + def pattern( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + mask, + past_key, + past_value, + position_ids, + cos, + sin, + ): + # Reshape query from (B, S, D) to (B, S, H, D/H) + query_BSHDh = op.Reshape( + query_BSD, _allow_other_inputs=True, _allow_other_attributes=True, - _outputs=["query_mm_reshaped"], + _outputs=["query_BSHDh"], ) # Transpose from (B, S, H, D/H) to (B, H, S, D/H) - query = op.Transpose(query_4d, perm=[0, 2, 1, 3]) - key_4d = op.Reshape( - key_3d, + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) + key_BSHkvDh = op.Reshape( + key_BSDkv, _allow_other_inputs=True, _allow_other_attributes=True, - _outputs=["key_mm_reshaped"], + _outputs=["key_BSHkvDh"], ) - key = op.Transpose(key_4d, perm=[0, 2, 1, 3]) - value_4d = op.Reshape( - value_3d, + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + + # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) + 
value_BSHkvDh = op.Reshape( + value_BSDkv, _allow_other_inputs=True, _allow_other_attributes=True, - _outputs=["value_mm_reshaped"], + _outputs=["value_BSHkvDh"], ) - value = op.Transpose(value_4d, perm=[0, 2, 1, 3]) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) - return query, key, value + position_ids_q = op.Unsqueeze(position_ids, [0]) + position_ids_k = op.Unsqueeze(position_ids, [0]) - def pattern( - self, - op, - input, - qkv_weight, - mask, - cos, - sin, - past_key, - past_value, - position_ids, - ): - query, key, value = self._compute_packed_QKV(op, input, qkv_weight) + query_BHSDh_rope = op.RotaryEmbedding( + query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft" + ) + key_BHkvSDh_rope = op.RotaryEmbedding( + key_BHkvSDh, position_ids_k, cos, sin, _domain="com.microsoft" + ) - query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft") + # Concatenate past_key cache and current key, expand across heads + # that share key/value and transpose to enable dot-product attention computation. - key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft") - present_key = op.Concat(past_key, key_rope, axis=-2) - # Transpose last two axes of present_key to compute dot-product via matmul. - present_key = op.Transpose(present_key, perm=[0, 1, 3, 2]) + key_seq_BHkvSDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + key_seq_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSDh, 2) + key_seq_BHkvGSDh = op.Expand(key_seq_BHkv1SDh, _allow_other_inputs=True) + key_seq_BHSkvDh = op.Reshape( + key_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["key_seq_BHSkvDh"] + ) + key_seq_BHDhSkv = op.Transpose( + key_seq_BHSkvDh, _allow_other_inputs=True, _outputs=["key_seq_BHDhSkv"] + ) - present_value = op.Concat(past_value, value, axis=-2) + # Concatenate past_value cache and current value, expand across heads + # that share key/value. 
+ value_seq_BHkvSDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + value_seq_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSDh, 2) + value_seq_BHkvGSDh = op.Expand(value_seq_BHkv1SDh, _allow_other_inputs=True) + value_seq_BHSkvDh = op.Reshape( + value_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["value_seq_BHSkvDh"] + ) - attention = op.SDPA( - query_rope, present_key, present_value, mask, _domain="ai.onnxruntime.fusion" + attention_BHSDh = op.SDPA( + query_BHSDh_rope, + key_seq_BHDhSkv, + value_seq_BHSkvDh, + mask, + _domain="ai.onnxruntime.fusion", ) - # Transpose back to (B, S, H, D/H) - attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) + + # Transpose attention back to (B, S, H, D/H) + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) # Reshape back to (B, S, D) - attention_reshaped = op.Reshape( - attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"] + attention_BSD = op.Reshape( + attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"] ) - return attention_reshaped, present_key, present_value + return attention_BSD, key_seq_BHkvSDh, value_seq_BHkvSDh def check( self, op, - # query_mm_reshaped, - # key_mm_reshaped, - # value_mm_reshaped, - # key_reshaped, - # key_transposed, - # attention_reshaped, + query_BSD, + key_BSDkv, + value_BSDkv, + mask, + past_key, + past_value, + # query_BSHDh, + # key_BSHkvDh, + # value_BSHkvDh, **_, ): - # bindings: dict[str, int] = {} - # status = ( - # _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) - # and _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"]) - # and _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"]) - # and _check_shape(bindings, key_reshaped, ["B*H", "KVS", "d_h"]) - # and _check_shape(bindings, key_transposed, ["B", "H", "d_h", "KVS"]) - # and _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"]) - # ) - # if not status: + # bindings: dict[str, Dim] = {} + + # def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + # return not _check_shape(bindings, val, dims) + + # if no_match(query_BSD, ["B", "S", "D"]): + # return False + # if no_match(key_BSDkv, ["B", "Skv", "D"]): + # return False + # if no_match(value_BSDkv, ["B", "Skv", "D"]): # return False - # if bindings["B"] * bindings["H"] != bindings["B*H"]: + + # if no_match(past_key, ["B", "H", "Spast", "Dh"]): + # return False + # if no_match(past_value, ["B", "H", "Spast", "Dv"]): + # return False + # if no_match(query_BSHDh, ["B", "S", "H", "Dh"]): + # return False + # if no_match(key_BSHkvDh, ["B", "S", "H", "Dh"]): # return False - # if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]: + # if no_match(value_BSHkvDh, ["B", "S", "H", "Dh"]): # return False + + # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St) + # But this also, unforunately, depends on ORT version. 
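The mask-shape TODO above refers to the usual attention-bias broadcasting rule: the bias may use 1 in the batch and head dimensions and still apply to every (B, H) pair when added to the (B, H, S, St) scores. A minimal NumPy illustration (sizes are made up):

import numpy as np

B, H, S, St = 2, 4, 3, 5
scores = np.random.rand(B, H, S, St).astype(np.float32)
mask = np.zeros((1, 1, S, St), dtype=np.float32)   # broadcasts over batch and heads

assert (scores + mask).shape == (B, H, S, St)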
+ + # TODO: verify Reshapes: + # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: + # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: + # or check Reshape's shape-input value + return True def rewrite( self, op, - input, - qkv_weight, + query_BSD, + key_BSDkv, + value_BSDkv, mask, - cos, - sin, past_key, past_value, - position_ids, - query_mm_reshaped, + # key_BSHkvDh, + # position_ids, + # cos, + # sin, **_, ): - num_heads = query_mm_reshaped.shape[2] - qkv = op.MatMul(input, qkv_weight) - return op.GroupQueryAttention( - qkv, - None, # key - None, # value + # num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) + # if not isinstance(num_heads, int): + # return None + + # # Switch to 3D RotaryEmbedding + # # TODO: forward other attributes + # query_BSD_rope = op.RotaryEmbedding( + # query_BSD, position_ids, cos, sin, _domain="com.microsoft" + # ) + # key_BSD_rope = op.RotaryEmbedding( + # key_BSDkv, position_ids, cos, sin, _domain="com.microsoft" + # ) + + return op.DummyGQA( + query_BSD, + key_BSDkv, + value_BSDkv, + None, # bias + None, # key padding mask + mask, # attention mask/bias past_key, past_value, - # seqlens_k, - # total_sequence_length, - cos, - sin, - num_heads=num_heads, + # num_heads=num_heads, _domain="com.microsoft", _outputs=3, ) -_rule1 = GroupQueryAttention.rule("MHA_2dmm", use_2d_matmul=False) +_rule1 = GroupQueryAttention.rule() gqa_rules = pattern.RewriteRuleSet([_rule1]) @@ -152,5 +235,5 @@ def rewrite( def fuse_gqa(model: ir.Model) -> int: count = gqa_rules.apply_to_model(model) print(f"GQA count: {count}") - remove_unused_nodes(model) + # remove_unused_nodes(model) return count diff --git a/onnxscript/rewriter/ort_fusions/gqa2.py b/onnxscript/rewriter/ort_fusions/gqa2.py deleted file mode 100644 index 23cc1dcf03..0000000000 --- a/onnxscript/rewriter/ort_fusions/gqa2.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -from typing import Sequence, Union - -import onnxscript.ir as ir -from onnxscript.rewriter import _ir_utils, pattern - -""" -The MultiHeadAttention pattern: generate an instance - MHA (query, key, value, None, None, mask, past_key, past_value) -where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv). -The next two inputs bias and key_padding_mask are None in this pattern. The mask (attention_bias) -must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh). - -We use the following abbreviations for the dimensions: -B: Batch size -S: Sequence length -D: input embedding dimension -Dv: value hidden size (usually, Dv = D) -H: number of heads -Dh: head size or embedding dimension per head (usually, D = H * Dh) -Skv: key/value sequence length -St: total sequence length - -In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh). -The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh). 
-""" - -Dim = Union[int, ir.SymbolicDim] - - -def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: - if val.shape is None: - return False - if val.shape.rank() != len(shape): - return False - for actual, expected in zip(val.shape, shape): - if expected not in bindings: - bindings[expected] = actual # type: ignore[assignment] - elif actual != bindings[expected]: - return False - return True - - -class GroupQueryAttention(pattern.RewriteRuleClassBase): - def __init__(self): - super().__init__("MHA") - - def pattern( - self, - op, - query_BSD, - key_BSDkv, - value_BSDkv, - mask, - past_key, - past_value, - position_ids, - cos, - sin, - ): - # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H) - - # Reshape from (B, S, D) to (B, S, H, D/H) - query_BSHDh = op.Reshape( - query_BSD, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["query_BSHDh"], - ) - # Transpose from (B, S, H, D/H) to (B, H, S, D/H) - query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) - - # Reshape from (B, S, D) to (B, S, H, D/H) - key_BSHkvDh = op.Reshape( - key_BSDkv, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["key_BSHkvDh"], - ) - # Transpose from (B, S, H, D/H) to (B, H, S, D/H) - key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) - - # Reshape from (B, S, D) to (B, S, H, D/H) - value_BSHkvDh = op.Reshape( - value_BSDkv, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["value_BSHkvDh"], - ) - # Transpose from (B, S, H, D/H) to (B, H, S, D/H) - value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) - - position_ids_q = op.Unsqueeze(position_ids, [0]) - position_ids_k = op.Unsqueeze(position_ids, [0]) - - query_BHSDh_rope = op.RotaryEmbedding( - query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft" - ) - key_BHkvSDh_rope = op.RotaryEmbedding( - key_BHkvSDh, position_ids_k, cos, sin, _domain="com.microsoft" - ) - - # Concatenate past_key cache and current key, and transpose to enable - # dot-product attention computation. 
- - key_seq_BHkvSDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) - key_seq_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSDh, 2) - key_seq_BHkvGSDh = op.Expand(key_seq_BHkv1SDh, _allow_other_inputs=True) - key_seq_BHSkvDh = op.Reshape( - key_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["key_seq_BHSkvDh"]) - key_seq_BHDhSkv = op.Transpose( - key_seq_BHSkvDh, _allow_other_inputs=True, _outputs=["key_seq_BHDhSkv"] - ) - - # Concatenate past_value cache and current value - value_seq_BHkvSDh = op.Concat(past_value, value_BHkvSDh, axis=-2) - value_seq_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSDh, 2) - value_seq_BHkvGSDh = op.Expand(value_seq_BHkv1SDh, _allow_other_inputs=True) - value_seq_BHSkvDh = op.Reshape( - value_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["value_seq_BHSkvDh"]) - - attention_BHSDh = op.SDPA( - query_BHSDh_rope, - key_seq_BHDhSkv, - value_seq_BHSkvDh, - mask, - _domain="ai.onnxruntime.fusion", - ) - - # Transpose attention back to (B, S, H, D/H) - attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) - # Reshape back to (B, S, D) - attention_BSD = op.Reshape( - attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"] - ) - return attention_BSD, key_seq_BHkvSDh, value_seq_BHkvSDh - - def check( - self, - op, - query_BSD, - key_BSDkv, - value_BSDkv, - mask, - past_key, - past_value, - # query_BSHDh, - # key_BSHkvDh, - # value_BSHkvDh, - **_, - ): - # bindings: dict[str, Dim] = {} - - # def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - # return not _check_shape(bindings, val, dims) - - # if no_match(query_BSD, ["B", "S", "D"]): - # return False - # if no_match(key_BSDkv, ["B", "Skv", "D"]): - # return False - # if no_match(value_BSDkv, ["B", "Skv", "D"]): - # return False - - # if no_match(past_key, ["B", "H", "Spast", "Dh"]): - # return False - # if no_match(past_value, ["B", "H", "Spast", "Dv"]): - # return False - # if no_match(query_BSHDh, ["B", "S", "H", "Dh"]): - # return False - # if no_match(key_BSHkvDh, ["B", "S", "H", "Dh"]): - # return False - # if no_match(value_BSHkvDh, ["B", "S", "H", "Dh"]): - # return False - - # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St) - # But this also, unforunately, depends on ORT version. 
- - # TODO: verify Reshapes: - # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: - # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: - # or check Reshape's shape-input value - - return True - - def rewrite( - self, - op, - query_BSD, - key_BSDkv, - value_BSDkv, - mask, - past_key, - past_value, - # key_BSHkvDh, - # position_ids, - # cos, - # sin, - **_, - ): - # num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) - # if not isinstance(num_heads, int): - # return None - - # # Switch to 3D RotaryEmbedding - # # TODO: forward other attributes - # query_BSD_rope = op.RotaryEmbedding( - # query_BSD, position_ids, cos, sin, _domain="com.microsoft" - # ) - # key_BSD_rope = op.RotaryEmbedding( - # key_BSDkv, position_ids, cos, sin, _domain="com.microsoft" - # ) - - return op.DummyGQA( - query_BSD, - key_BSDkv, - value_BSDkv, - None, # bias - None, # key padding mask - mask, # attention mask/bias - past_key, - past_value, - # num_heads=num_heads, - _domain="com.microsoft", - _outputs=3, - ) - - -_rule1 = GroupQueryAttention.rule() - -# _rule1 = GroupQueryAttention.rule("GQA", use_2d_matmul=False) - -gqa_rules = pattern.RewriteRuleSet([_rule1]) - - -def fuse_gqa(model: ir.Model) -> int: - count = gqa_rules.apply_to_model(model) - print(f"GQA count: {count}") - # remove_unused_nodes(model) - return count From 9c79b98ade750c8ead9a7b9e80edcafb7565e390 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Mar 2025 22:10:07 -0700 Subject: [PATCH 07/28] Fix variable naming --- onnxscript/rewriter/ort_fusions/gqa.py | 33 ++++++++++++++------------ 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 827fd739f9..75886d220b 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -14,13 +14,16 @@ We use the following abbreviations for the dimensions: B: Batch size S: Sequence length (for current query/key/value) -D: input embedding dimension -Dkv: key/value hidden size -H: number of heads (must be an integral multiple of Hkv) + Hkv: number of heads for key/value -Dh: head size or embedding dimension per head (usually, D = H * Dh) -Skv: key/value sequence length -St: total sequence length +G = number of groups +H: number of heads = G * Hkv + +Dh: head size or embedding dimension per head +D: input embedding dimension (hidden size) = H * Dh +Dkv: key/value hidden size = Hkv * Dh + +Skv: key/value sequence length (after concatenation of past and current key/value) In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh). The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh). @@ -102,11 +105,11 @@ def pattern( # Concatenate past_key cache and current key, expand across heads # that share key/value and transpose to enable dot-product attention computation. 
- key_seq_BHkvSDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) - key_seq_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSDh, 2) - key_seq_BHkvGSDh = op.Expand(key_seq_BHkv1SDh, _allow_other_inputs=True) + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + key_seq_BHkv1SkvDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_seq_BHkvGSkvDh = op.Expand(key_seq_BHkv1SkvDh, _allow_other_inputs=True) key_seq_BHSkvDh = op.Reshape( - key_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["key_seq_BHSkvDh"] + key_seq_BHkvGSkvDh, _allow_other_inputs=True, _outputs=["key_seq_BHSkvDh"] ) key_seq_BHDhSkv = op.Transpose( key_seq_BHSkvDh, _allow_other_inputs=True, _outputs=["key_seq_BHDhSkv"] @@ -114,11 +117,11 @@ def pattern( # Concatenate past_value cache and current value, expand across heads # that share key/value. - value_seq_BHkvSDh = op.Concat(past_value, value_BHkvSDh, axis=-2) - value_seq_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSDh, 2) - value_seq_BHkvGSDh = op.Expand(value_seq_BHkv1SDh, _allow_other_inputs=True) + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + value_seq_BHkv1SkvDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_seq_BHkvGSkvDh = op.Expand(value_seq_BHkv1SkvDh, _allow_other_inputs=True) value_seq_BHSkvDh = op.Reshape( - value_seq_BHkvGSDh, _allow_other_inputs=True, _outputs=["value_seq_BHSkvDh"] + value_seq_BHkvGSkvDh, _allow_other_inputs=True, _outputs=["value_seq_BHSkvDh"] ) attention_BHSDh = op.SDPA( @@ -135,7 +138,7 @@ def pattern( attention_BSD = op.Reshape( attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"] ) - return attention_BSD, key_seq_BHkvSDh, value_seq_BHkvSDh + return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh def check( self, From 0d0e8aed050c8dc24ceb1e148c5f8dc7d38a797b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Mar 2025 14:29:49 -0700 Subject: [PATCH 08/28] Add num heads attributes --- onnxscript/rewriter/ort_fusions/gqa.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 75886d220b..db6bc33414 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -5,7 +5,7 @@ from typing import Sequence, Union import onnxscript.ir as ir -from onnxscript.rewriter import pattern +from onnxscript.rewriter import _ir_utils, pattern """ GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different @@ -196,15 +196,18 @@ def rewrite( mask, past_key, past_value, + query_BSHDh, + key_BSHkvDh, # key_BSHkvDh, # position_ids, # cos, # sin, **_, ): - # num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) - # if not isinstance(num_heads, int): - # return None + num_heads = _ir_utils.get_dim(query_BSHDh, 2) + kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) + if not isinstance(num_heads, int) or not isinstance(kv_num_heads, int): + return None # # Switch to 3D RotaryEmbedding # # TODO: forward other attributes @@ -215,16 +218,16 @@ def rewrite( # key_BSDkv, position_ids, cos, sin, _domain="com.microsoft" # ) - return op.DummyGQA( + return op.GroupQueryAttention( query_BSD, key_BSDkv, value_BSDkv, - None, # bias - None, # key padding mask - mask, # attention mask/bias past_key, past_value, - # num_heads=num_heads, + # skipped optional inputs: seqlens_k, total_sequence_length, cos_cache, sin_cache + num_heads=num_heads, + kv_num_heads=kv_num_heads, + # skipped optional attributes: do_rotary, local_window_size, rotary_interleaved, 
scale, smooth_softmax, softcap _domain="com.microsoft", _outputs=3, ) From b5885829e958bffcacaffc02675e02ea251ac7a8 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Mar 2025 12:51:17 -0700 Subject: [PATCH 09/28] Use seqlens and totalseqlen --- onnxscript/rewriter/ort_fusions/gqa.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index db6bc33414..d777cd1195 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -48,6 +48,7 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) class GroupQueryAttention(pattern.RewriteRuleClassBase): def __init__(self): super().__init__("GQA") + self.remove_nodes = False def pattern( self, @@ -58,7 +59,9 @@ def pattern( mask, past_key, past_value, - position_ids, + # position_ids, + past_seq_length, + total_seq_length, cos, sin, ): @@ -92,6 +95,7 @@ def pattern( # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + position_ids = op.Range(past_seq_length, total_seq_length, 1) position_ids_q = op.Unsqueeze(position_ids, [0]) position_ids_k = op.Unsqueeze(position_ids, [0]) @@ -198,6 +202,8 @@ def rewrite( past_value, query_BSHDh, key_BSHkvDh, + past_seq_length, + total_seq_length, # key_BSHkvDh, # position_ids, # cos, @@ -217,6 +223,11 @@ def rewrite( # key_BSD_rope = op.RotaryEmbedding( # key_BSDkv, position_ids, cos, sin, _domain="com.microsoft" # ) + total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32) + one_0D = op.Constant(value_int=1, dtype=ir.DataType.INT32) + seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D) + zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1]) + seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D) return op.GroupQueryAttention( query_BSD, @@ -224,6 +235,9 @@ def rewrite( value_BSDkv, past_key, past_value, + seqlens_k, + total_seq_length_int32, + # mask, # TODO: this is not a valid input for GQA # skipped optional inputs: seqlens_k, total_sequence_length, cos_cache, sin_cache num_heads=num_heads, kv_num_heads=kv_num_heads, From 3febda2fce5e72db512f75c0fc88c608b8715c09 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Mar 2025 13:41:56 -0700 Subject: [PATCH 10/28] Add cos and sin cache --- onnxscript/rewriter/ort_fusions/gqa.py | 48 ++++++++++++++++---------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index d777cd1195..3f5efab368 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -100,10 +100,20 @@ def pattern( position_ids_k = op.Unsqueeze(position_ids, [0]) query_BHSDh_rope = op.RotaryEmbedding( - query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft" + query_BHSDh, + position_ids_q, + cos, + sin, + _domain="com.microsoft", + _outputs=["query_BHSDh_rope"], ) key_BHkvSDh_rope = op.RotaryEmbedding( - key_BHkvSDh, position_ids_k, cos, sin, _domain="com.microsoft" + key_BHkvSDh, + position_ids_k, + cos, + sin, + _domain="com.microsoft", + _outputs=["key_BHkvSDh_rope"], ) # Concatenate past_key cache and current key, expand across heads @@ -153,6 +163,8 @@ def check( mask, past_key, past_value, + query_BHSDh_rope, + key_BHkvSDh_rope, # query_BSHDh, # key_BSHkvDh, # value_BSHkvDh, @@ -189,6 +201,15 @@ def check( # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: # 
or check Reshape's shape-input value + # Rotary embedding attributes + query_rotary_attributes = query_BHSDh_rope.producer().attributes + key_rotary_attributes = key_BHkvSDh_rope.producer().attributes + query_interleaved = query_rotary_attributes.get("interleaved", 0) + key_interleaved = key_rotary_attributes.get("interleaved", 0) + if query_interleaved != key_interleaved: + return False + self._interleaved = query_interleaved + return True def rewrite( @@ -197,17 +218,13 @@ def rewrite( query_BSD, key_BSDkv, value_BSDkv, - mask, past_key, past_value, query_BSHDh, key_BSHkvDh, - past_seq_length, total_seq_length, - # key_BSHkvDh, - # position_ids, - # cos, - # sin, + cos, + sin, **_, ): num_heads = _ir_utils.get_dim(query_BSHDh, 2) @@ -215,14 +232,6 @@ def rewrite( if not isinstance(num_heads, int) or not isinstance(kv_num_heads, int): return None - # # Switch to 3D RotaryEmbedding - # # TODO: forward other attributes - # query_BSD_rope = op.RotaryEmbedding( - # query_BSD, position_ids, cos, sin, _domain="com.microsoft" - # ) - # key_BSD_rope = op.RotaryEmbedding( - # key_BSDkv, position_ids, cos, sin, _domain="com.microsoft" - # ) total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32) one_0D = op.Constant(value_int=1, dtype=ir.DataType.INT32) seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D) @@ -237,11 +246,14 @@ def rewrite( past_value, seqlens_k, total_seq_length_int32, + cos, + sin, # mask, # TODO: this is not a valid input for GQA - # skipped optional inputs: seqlens_k, total_sequence_length, cos_cache, sin_cache num_heads=num_heads, kv_num_heads=kv_num_heads, - # skipped optional attributes: do_rotary, local_window_size, rotary_interleaved, scale, smooth_softmax, softcap + do_rotary=1, + rotary_interleaved=self._interleaved, + # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap _domain="com.microsoft", _outputs=3, ) From afcf0a71d6a9e47a830836679c68563345dd5e88 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Mar 2025 18:01:14 -0700 Subject: [PATCH 11/28] Fix int32 type --- onnxscript/rewriter/ort_fusions/gqa.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 3f5efab368..6ab13ec213 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -233,8 +233,9 @@ def rewrite( return None total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32) - one_0D = op.Constant(value_int=1, dtype=ir.DataType.INT32) - seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D) + one_0D = op.Constant(value_int=1) + one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32) + seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32) zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1]) seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D) @@ -252,7 +253,7 @@ def rewrite( num_heads=num_heads, kv_num_heads=kv_num_heads, do_rotary=1, - rotary_interleaved=self._interleaved, + rotary_interleaved=self._interleaved.value, # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap _domain="com.microsoft", _outputs=3, From 03c08c7edc0180d01cd31fe40091f681c39b50aa Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 20 Mar 2025 08:23:41 -0700 Subject: [PATCH 12/28] GQA fusion --- onnxscript/rewriter/ort_fusions/gqa.py | 2 ++ onnxscript/rewriter/ort_fusions/gqa_test.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 
onnxscript/rewriter/ort_fusions/gqa_test.py diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index db6bc33414..ade315b55a 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -179,6 +179,8 @@ def check( # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St) # But this also, unforunately, depends on ORT version. + # TODO: check that mask is causal. Latest ORT is adding support for + # non-causal masks, but not yet for all EPs. # TODO: verify Reshapes: # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py new file mode 100644 index 0000000000..c751452901 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Testing GQA fusion.""" + +import numpy + +import onnxscript.ir as ir +from onnxscript import script +from onnxscript.onnx_opset import opset18 as op +from onnxscript.onnx_types import FLOAT, INT64 + +@script() +def _gqa_prompt_script(query, key, value): + pass \ No newline at end of file From 0bc603f754d430f09c060dada375a1bc1cc8f8fb Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 24 Mar 2025 17:44:21 -0700 Subject: [PATCH 13/28] Basic GQA test --- .../rewriter/ort_fusions/gqa_basic_test.py | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 onnxscript/rewriter/ort_fusions/gqa_basic_test.py diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py new file mode 100644 index 0000000000..a2fb23a819 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
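Two auxiliary GQA inputs recur in the test below: total_sequence_length is the scalar int32 total KV length, and seqlens_k holds, per batch row, the index of the last valid token, which is total length minus 1 when there is no padding. This mirrors the Cast/Sub/Tile logic used elsewhere in the patch; a small NumPy sketch with illustrative values:

import numpy as np

batch, past_len, cur_len = 2, 0, 8                  # no KV cache in this test, so past_len = 0
total_seq_len = np.int32(past_len + cur_len)        # scalar int32 input
seqlens_k = np.tile(np.array([total_seq_len - 1], np.int32), batch)  # one entry per batch row

assert seqlens_k.tolist() == [7, 7] and int(total_seq_len) == 8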
+from __future__ import annotations + +import unittest + +import packaging.version + +import math +import onnx +import onnxruntime as ort +import numpy as np +import onnxscript +import torch +from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run + +# Simple GQA: no rotary embedding, no past key/value, no cos/sin cache, no seqlens/total_seqlen + +class GQA1(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.batchsize = 2 + self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.headsize = 16 + self.num_heads = 20 + self.kv_num_heads = 10 + self.hidden_size = self.headsize * self.num_heads + self.kv_hidden_size = self.headsize * self.kv_num_heads + self.num_groups = self.num_heads // self.kv_num_heads + + def random_inputs(self): + B = self.batchsize + S = self.seqlen + D = self.hidden_size + Dkv = self.kv_hidden_size + query = np.random.rand(B, S, D).astype(np.float32) + key = np.random.rand(B, S, Dkv).astype(np.float32) + value = np.random.rand(B, S, Dkv).astype(np.float32) + return { + "query": query, + "key": key, + "value": value, + } + + def ort_gqa_model(self): + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.headsize + H = self.num_heads + Hkv = self.kv_num_heads + return onnx.parser.parse_model( + f""" + + GQA (float[B, S, {D}] query, float[B, S, {Dkv}] key, float[B, S, {Dkv}] value) + => (float[B, S, {D}] attn, + float[B, {Hkv}, S, {Dh}] past_key, + float[B, {Hkv}, S, {Dh}] past_value) + {{ + total_seqlen = Shape (query) + total_seqlen_int32 = Cast (total_seqlen) + one = Constant () + total_seqlen_int32_minus_1 = Sub (total_seqlen_int32, one) + batchsize = Shape (query) + seqlens_k = Tile (total_seqlen_int32_minus_1, batchsize) + attn, past_key, past_value = com.microsoft.GroupQueryAttention + (query, key, value, , , seqlens_k, total_seqlen_int32) + }} + """ + ) + + def onnx_gqa_script(self): + op = onnxscript.opset18 + scale_factor = math.sqrt(math.sqrt(self.headsize)) + minval = torch.finfo(torch.float32).min + minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) + H = [self.num_heads] + Hkv = [self.kv_num_heads] + Dh = [self.headsize] + G = [self.num_groups] + minus_1 = [-1] # inferred dimension in Reshape op + + @onnxscript.script() + def gqa(query_BSD, key_BSD, value_BSD): + # Shapes used for Reshape ops. Note that we have a few different options on how shapes are + # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate + # existing dimension and one inferred dimension respectively). The following shapes are + # based on what is observed in Phi models generated by the exporter. + B = op.Shape(query_BSD, start=0, end=1) + S = op.Shape(query_BSD, start=1, end=2) + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSD = op.Concat(B, S, minus_1, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, S, Dh, axis=0) + shape_BHSDh = op.Concat(B, H, S, Dh, axis=0) + shape_SS = op.Concat(S, S, axis=0) + + # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. + # D is different for Q and K/V (not reflected in the names, unfortunately). + # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only + # one sequence length (S) for all Q, K, and V (with no cache). 
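The reshapes that follow rely on ONNX Reshape's inferred dimension (-1), as noted in the comment above: the head count is left as -1 so the same shape expression works for query (H heads) and key/value (Hkv heads). A NumPy analogue with illustrative sizes:

import numpy as np

B, S, Dh, H, Hkv = 2, 8, 16, 20, 10
query = np.random.rand(B, S, H * Dh).astype(np.float32)
key = np.random.rand(B, S, Hkv * Dh).astype(np.float32)

# Shape (B, S, -1, Dh): the inferred -1 dimension resolves to H or Hkv as appropriate.
assert query.reshape(B, S, -1, Dh).shape == (B, S, H, Dh)
assert key.reshape(B, S, -1, Dh).shape == (B, S, Hkv, Dh)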
+ query_BSHDh = op.Reshape(query_BSD, shape_BSHDh) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + key_BSHkvDh = op.Reshape(key_BSD, shape_BSHkvDh) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + key_BHkv1SDh = op.Unsqueeze(key_BHkvSDh, 2) + key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) + key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) + + value_BSHkvDh = op.Reshape(value_BSD, shape_BSHkvDh) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + value_BHkv1SDh = op.Unsqueeze(value_BHkvSDh, 2) + value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) + value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) + + # Generate a causal mask where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] + all_min = op.ConstantOfShape(shape_SS, value=minval_tp) + one = op.Constant(value_int=1) + mask = op.Trilu(all_min, one, upper=1) + + # Now, compute attention: + key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=scale_factor) + scaled_query = op.Div(query_BHSDh, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) + + # Reshape back to original shape: + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + return attention_BSD, key_BHkvSDh, value_BHkvSDh + return gqa + + def test_equivalence(self): + inputs = self.random_inputs() + + ort_gqa_model = self.ort_gqa_model() + session = ort.InferenceSession( + ort_gqa_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs1 = session.run(None, inputs) + + gqa_model = self.onnx_gqa_script() + outputs2 = gqa_model(inputs["query"], inputs["key"], inputs["value"]) + + self.assertEqual(len(outputs1), len(outputs2)) + assert_allclose(outputs1, outputs2) + + +# past_seqlen = 0 +# total_seqlen = past_seqlen + S +# seqlens_k = np.array([total_seqlen-1], dtype=np.int32) +# total_seqlen_input = np.array(total_seqlen, dtype=np.int32) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 794f0dd8ac537ec4173acd3bb87289cfb5b09dd4 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 27 Mar 2025 13:45:47 -0700 Subject: [PATCH 14/28] Minor refactoring --- .../rewriter/ort_fusions/gqa_basic_test.py | 71 +++++++++++++------ 1 file changed, 48 insertions(+), 23 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py index a2fb23a819..d9969a0890 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py @@ -2,23 +2,27 @@ # Licensed under the MIT License. from __future__ import annotations +import math import unittest -import packaging.version - -import math +import numpy as np import onnx import onnxruntime as ort -import numpy as np -import onnxscript import torch -from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run + +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose + +# This is a basic test that verifies that a proposed expanded computation is equivalent to +# ORT's GQA (for the specific configuration considered). 
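One detail of the expanded script above that is easy to miss: both the query and the transposed key are divided by scale_factor = sqrt(sqrt(headsize)), i.e. Dh**0.25, which is the same as scaling the attention logits once by 1/sqrt(Dh). A quick NumPy check with illustrative sizes:

import math

import numpy as np

Dh = 16
q = np.random.rand(4, Dh).astype(np.float32)
k_t = np.random.rand(Dh, 4).astype(np.float32)

scale = math.sqrt(math.sqrt(Dh))            # Dh ** 0.25
scores_split = (q / scale) @ (k_t / scale)  # scale query and key separately
scores_direct = (q @ k_t) / math.sqrt(Dh)   # scale the logits once

assert np.allclose(scores_split, scores_direct, atol=1e-5)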
# Simple GQA: no rotary embedding, no past key/value, no cos/sin cache, no seqlens/total_seqlen + class GQA1(unittest.TestCase): def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.batchsize = 2 self.seqlen = 8 self.kv_seqlen = self.seqlen @@ -43,7 +47,7 @@ def random_inputs(self): "value": value, } - def ort_gqa_model(self): + def fused_model(self): D = self.hidden_size Dkv = self.kv_hidden_size Dh = self.headsize @@ -57,20 +61,23 @@ def ort_gqa_model(self): float[B, {Hkv}, S, {Dh}] past_key, float[B, {Hkv}, S, {Dh}] past_value) {{ + # Generate seqlens_k and total_seqlen inputs for GQA: + # In this test case, all batch elements have same sequence length. + total_seqlen = Shape (query) total_seqlen_int32 = Cast (total_seqlen) one = Constant () total_seqlen_int32_minus_1 = Sub (total_seqlen_int32, one) batchsize = Shape (query) seqlens_k = Tile (total_seqlen_int32_minus_1, batchsize) + attn, past_key, past_value = com.microsoft.GroupQueryAttention (query, key, value, , , seqlens_k, total_seqlen_int32) }} """ ) - def onnx_gqa_script(self): - op = onnxscript.opset18 + def expanded_model_script(self): scale_factor = math.sqrt(math.sqrt(self.headsize)) minval = torch.finfo(torch.float32).min minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) @@ -78,16 +85,16 @@ def onnx_gqa_script(self): Hkv = [self.kv_num_heads] Dh = [self.headsize] G = [self.num_groups] - minus_1 = [-1] # inferred dimension in Reshape op + minus_1 = [-1] # inferred dimension in Reshape op - @onnxscript.script() - def gqa(query_BSD, key_BSD, value_BSD): + @script() + def gqa(query, key, value): # Shapes used for Reshape ops. Note that we have a few different options on how shapes are # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate # existing dimension and one inferred dimension respectively). The following shapes are # based on what is observed in Phi models generated by the exporter. - B = op.Shape(query_BSD, start=0, end=1) - S = op.Shape(query_BSD, start=1, end=2) + B = op.Shape(query, start=0, end=1) + S = op.Shape(query, start=1, end=2) shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) shape_BSD = op.Concat(B, S, minus_1, axis=0) @@ -99,16 +106,16 @@ def gqa(query_BSD, key_BSD, value_BSD): # D is different for Q and K/V (not reflected in the names, unfortunately). # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only # one sequence length (S) for all Q, K, and V (with no cache). 
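For the causal mask that the expanded test script builds with ConstantOfShape and Trilu(upper=1, k=1), the result for a small S is worth spelling out: every entry strictly above the diagonal is the large negative value and everything else is 0, so softmax suppresses attention to future positions. A NumPy equivalent (S is illustrative):

import numpy as np

S = 4
minval = np.finfo(np.float32).min

# Keep the strict upper triangle, zero the rest (rows attend to positions 0..i).
mask = np.triu(np.full((S, S), minval, dtype=np.float32), k=1)

assert mask[0, 0] == 0 and mask[0, 1] == minval and mask[3, 3] == 0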
- query_BSHDh = op.Reshape(query_BSD, shape_BSHDh) + query_BSHDh = op.Reshape(query, shape_BSHDh) query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) - key_BSHkvDh = op.Reshape(key_BSD, shape_BSHkvDh) + key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) key_BHkv1SDh = op.Unsqueeze(key_BHkvSDh, 2) key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) - value_BSHkvDh = op.Reshape(value_BSD, shape_BSHkvDh) + value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) value_BHkv1SDh = op.Unsqueeze(value_BHkvSDh, 2) value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) @@ -133,19 +140,37 @@ def gqa(query_BSD, key_BSD, value_BSD): attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) return attention_BSD, key_BHkvSDh, value_BHkvSDh + return gqa + def expanded_model(self): + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.headsize + Hkv = self.kv_num_heads + return self.expanded_model_script().to_model_proto( + input_types=(FLOAT["B", "S", D], FLOAT["B", "S", Dkv], FLOAT["B", "S", Dkv]), + output_types=( + FLOAT["B", "S", D], + FLOAT["B", Hkv, "S", Dh], + FLOAT["B", Hkv, "S", Dh], + ), + ) + def test_equivalence(self): inputs = self.random_inputs() - ort_gqa_model = self.ort_gqa_model() + fused_model = self.fused_model() session = ort.InferenceSession( - ort_gqa_model.SerializeToString(), providers=("CPUExecutionProvider",) + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) ) outputs1 = session.run(None, inputs) - gqa_model = self.onnx_gqa_script() - outputs2 = gqa_model(inputs["query"], inputs["key"], inputs["value"]) + expanded_model = self.expanded_model() + session = ort.InferenceSession( + expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs2 = session.run(None, inputs) self.assertEqual(len(outputs1), len(outputs2)) assert_allclose(outputs1, outputs2) @@ -157,4 +182,4 @@ def test_equivalence(self): # total_seqlen_input = np.array(total_seqlen, dtype=np.int32) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From a7ba01b68bea44a35599a7761a13ad102e5673b8 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 27 Mar 2025 14:12:00 -0700 Subject: [PATCH 15/28] Switch to script --- .../rewriter/ort_fusions/gqa_basic_test.py | 69 +++++++++---------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py index d9969a0890..95716c695d 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py @@ -10,10 +10,13 @@ import onnxruntime as ort import torch +import onnxscript from onnxscript import FLOAT, script from onnxscript import opset18 as op from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose +msft_op = onnxscript.values.Opset("com.microsoft", 1) + # This is a basic test that verifies that a proposed expanded computation is equivalent to # ORT's GQA (for the specific configuration considered). 
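The expanded script scales both operands of the score MatMul by sqrt(sqrt(head_size)) rather than dividing the scores once by sqrt(head_size); the two are algebraically the same, as this small numpy check illustrates (shapes are arbitrary):

import math
import numpy as np

Dh = 16
q = np.random.rand(4, Dh).astype(np.float32)
k_t = np.random.rand(Dh, 4).astype(np.float32)

r = math.sqrt(math.sqrt(Dh))             # Dh ** 0.25, the divisor used in the test
scores_split = (q / r) @ (k_t / r)       # divisor applied to query and transposed key
scores_direct = (q @ k_t) / math.sqrt(Dh)

np.testing.assert_allclose(scores_split, scores_direct, rtol=1e-5)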
@@ -47,35 +50,34 @@ def random_inputs(self): "value": value, } - def fused_model(self): - D = self.hidden_size - Dkv = self.kv_hidden_size - Dh = self.headsize + def fused_model_script(self): H = self.num_heads Hkv = self.kv_num_heads - return onnx.parser.parse_model( - f""" - - GQA (float[B, S, {D}] query, float[B, S, {Dkv}] key, float[B, S, {Dkv}] value) - => (float[B, S, {D}] attn, - float[B, {Hkv}, S, {Dh}] past_key, - float[B, {Hkv}, S, {Dh}] past_value) - {{ - # Generate seqlens_k and total_seqlen inputs for GQA: - # In this test case, all batch elements have same sequence length. - - total_seqlen = Shape (query) - total_seqlen_int32 = Cast (total_seqlen) - one = Constant () - total_seqlen_int32_minus_1 = Sub (total_seqlen_int32, one) - batchsize = Shape (query) - seqlens_k = Tile (total_seqlen_int32_minus_1, batchsize) - - attn, past_key, past_value = com.microsoft.GroupQueryAttention - (query, key, value, , , seqlens_k, total_seqlen_int32) - }} - """ - ) + + @script() + def gqa(query, key, value): + # Generate seqlens_k and total_seqlen inputs for GQA: + # In this test case, all batch elements have same sequence length. + + total_seqlen = op.Shape(query, start=1, end=2) + total_seqlen_int32 = op.Cast(total_seqlen, to=6) + total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) + batchsize = op.Shape(query, start=0, end=1) + seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) + attn, past_key, past_value = msft_op.GroupQueryAttention( + query, + key, + value, + None, + None, + seqlens_k, + total_seqlen_int32, + num_heads=H, + kv_num_heads=Hkv, + ) + return attn, past_key, past_value + + return gqa def expanded_model_script(self): scale_factor = math.sqrt(math.sqrt(self.headsize)) @@ -143,12 +145,12 @@ def gqa(query, key, value): return gqa - def expanded_model(self): + def to_proto(self, model_script): D = self.hidden_size Dkv = self.kv_hidden_size Dh = self.headsize Hkv = self.kv_num_heads - return self.expanded_model_script().to_model_proto( + return model_script.to_model_proto( input_types=(FLOAT["B", "S", D], FLOAT["B", "S", Dkv], FLOAT["B", "S", Dkv]), output_types=( FLOAT["B", "S", D], @@ -160,13 +162,13 @@ def expanded_model(self): def test_equivalence(self): inputs = self.random_inputs() - fused_model = self.fused_model() + fused_model = self.to_proto(self.fused_model_script()) # self.fused_model() session = ort.InferenceSession( fused_model.SerializeToString(), providers=("CPUExecutionProvider",) ) outputs1 = session.run(None, inputs) - expanded_model = self.expanded_model() + expanded_model = self.to_proto(self.expanded_model_script()) # self.expanded_model() session = ort.InferenceSession( expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) ) @@ -176,10 +178,5 @@ def test_equivalence(self): assert_allclose(outputs1, outputs2) -# past_seqlen = 0 -# total_seqlen = past_seqlen + S -# seqlens_k = np.array([total_seqlen-1], dtype=np.int32) -# total_seqlen_input = np.array(total_seqlen, dtype=np.int32) - if __name__ == "__main__": unittest.main() From d68b0e7b3ced6a2fdb57efc1000fec3fd2e7c263 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 27 Mar 2025 14:14:46 -0700 Subject: [PATCH 16/28] Add blank line --- onnxscript/rewriter/ort_fusions/gqa_basic_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py index 95716c695d..895f7e54a5 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py +++ 
b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py @@ -64,6 +64,7 @@ def gqa(query, key, value): total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) batchsize = op.Shape(query, start=0, end=1) seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) + attn, past_key, past_value = msft_op.GroupQueryAttention( query, key, From df5c69d490a0ab4b7b10469195a2b7c04dfe00eb Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 28 Mar 2025 10:24:33 -0700 Subject: [PATCH 17/28] Add test case with past and rotary --- .../rewriter/ort_fusions/gqa_basic_test.py | 210 ++++++++++++++++++ 1 file changed, 210 insertions(+) diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py index 895f7e54a5..eb2c21d259 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py @@ -179,5 +179,215 @@ def test_equivalence(self): assert_allclose(outputs1, outputs2) +class GQA2(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.batchsize = 2 + self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.past_seqlen = 16 + self.headsize = 16 + self.num_heads = 20 + self.kv_num_heads = 10 + self.hidden_size = self.headsize * self.num_heads + self.kv_hidden_size = self.headsize * self.kv_num_heads + self.num_groups = self.num_heads // self.kv_num_heads + + def random_inputs(self): + B = self.batchsize + S = self.seqlen + Skvp = self.past_seqlen + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.headsize + Hkv = self.kv_num_heads + total_seqlen = S + Skvp + max_seqlen = total_seqlen + query = np.random.rand(B, S, D).astype(np.float32) + key = np.random.rand(B, S, Dkv).astype(np.float32) + value = np.random.rand(B, S, Dkv).astype(np.float32) + past_key = np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32) + past_value = np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32) + cos = np.random.rand(max_seqlen, Dh // 2).astype(np.float32) + sin = np.random.rand(max_seqlen, Dh // 2).astype(np.float32) + + return { + "query": query, + "key": key, + "value": value, + "past_key": past_key, + "past_value": past_value, + "cos": cos, + "sin": sin, + } + + def fused_model_script(self): + H = self.num_heads + Hkv = self.kv_num_heads + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin): + # Generate seqlens_k and total_seqlen inputs for GQA: + # In this test case, all batch elements have same sequence length. + + total_seqlen = op.Shape(query, start=1, end=2) + total_seqlen_int32 = op.Cast(total_seqlen, to=6) + total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) + batchsize = op.Shape(query, start=0, end=1) + seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) + + attn, past_key, past_value = msft_op.GroupQueryAttention( + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_seqlen_int32, + cos, + sin, + num_heads=H, + kv_num_heads=Hkv, + ) + return attn, past_key, past_value + + return gqa + + def expanded_model_script(self): + scale_factor = math.sqrt(math.sqrt(self.headsize)) + minval = torch.finfo(torch.float32).min + minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) + H = [self.num_heads] + Hkv = [self.kv_num_heads] + Dh = [self.headsize] + G = [self.num_groups] + minus_1 = [-1] # inferred dimension in Reshape op + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin): + # Shapes used for Reshape ops. 
Note that we have a few different options on how shapes are + # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate + # existing dimension and one inferred dimension respectively). The following shapes are + # based on what is observed in Phi models generated by the exporter. + B = op.Shape(query, start=0, end=1) + S = op.Shape(query, start=1, end=2) + past_seq_length_1D = op.Shape(past_key, start=2, end=3) + past_seq_length = op.Squeeze(past_seq_length_1D, 0) + S_0D = op.Squeeze(S, 0) + total_seq_length = op.Add(past_seq_length, S_0D) + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSD = op.Concat(B, S, minus_1, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, S, Dh, axis=0) + shape_BHSDh = op.Concat(B, H, S, Dh, axis=0) + shape_SS = op.Concat(S, S, axis=0) + + # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. + # D is different for Q and K/V (not reflected in the names, unfortunately). + # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only + # one sequence length (S) for all Q, K, and V (with no cache). + query_BSHDh = op.Reshape(query, shape_BSHDh) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + + value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + # Concat past and do rotary embedding + position_ids = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids, [0]) + position_ids_k = op.Unsqueeze(position_ids, [0]) + + query_BHSDh_rope = msft_op.RotaryEmbedding( + query_BHSDh, + position_ids_q, + cos, + sin, + ) + key_BHkvSDh_rope = msft_op.RotaryEmbedding( + key_BHkvSDh, + position_ids_k, + cos, + sin, + ) + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + + # Now, expand from shared heads to all heads + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) + key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) + + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) + value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) + + # Generate a causal mask where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] 
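The comment above describes the causal mask the script builds next: ConstantOfShape fills a matrix with the float32 minimum, and Trilu with upper=1 and a diagonal offset of 1 keeps only the strictly upper triangle, so row i is zero up to and including position i and minval beyond it. A numpy sketch of the square variant (the with-past version later offsets the diagonal by past_length + 1 instead):

import numpy as np

S = 4
minval = np.finfo(np.float32).min

all_min = np.full((S, S), minval, dtype=np.float32)  # ConstantOfShape
mask = np.triu(all_min, k=1)                         # Trilu(k=1, upper=1)

print(mask)
# row 0: [0, min, min, min]
# row 3: [0, 0,   0,   0  ]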
+ all_min = op.ConstantOfShape(shape_SS, value=minval_tp) + one = op.Constant(value_int=1) + mask = op.Trilu(all_min, one, upper=1) + + # Now, compute attention: + key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=scale_factor) + scaled_query = op.Div(query_BHSDh_rope, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) + + # Reshape back to original shape: + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + return attention_BSD, key_BHkvSDh, value_BHkvSDh + + return gqa + + def to_proto(self, model_script): + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.headsize + Hkv = self.kv_num_heads + + return model_script.to_model_proto( + input_types=( + FLOAT["B", "S", D], + FLOAT["B", "S", Dkv], + FLOAT["B", "S", Dkv], + FLOAT["B", Hkv, "Skvp", Dh], + FLOAT["B", Hkv, "Skvp", Dh], + FLOAT["max_seqlen", Dh // 2], + FLOAT["max_seqlen", Dh // 2], + ), + output_types=( + FLOAT["B", "S", D], + FLOAT["B", Hkv, "S", Dh], + FLOAT["B", Hkv, "S", Dh], + ), + ) + + def test_equivalence(self): + inputs = self.random_inputs() + + fused_model = self.to_proto(self.fused_model_script()) # self.fused_model() + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs1 = session.run(None, inputs) + + expanded_model = self.to_proto(self.expanded_model_script()) # self.expanded_model() + session = ort.InferenceSession( + expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs2 = session.run(None, inputs) + + self.assertEqual(len(outputs1), len(outputs2)) + assert_allclose(outputs1, outputs2) + + if __name__ == "__main__": unittest.main() From 045fc6f2deac304c01be83f0a1984554b2457bb9 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 28 Mar 2025 17:44:42 -0700 Subject: [PATCH 18/28] Add new test --- .../rewriter/ort_fusions/_test_utils.py | 7 ++- .../rewriter/ort_fusions/gqa_basic_test.py | 56 +++++++++++-------- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index 12bdcf2d4d..e76edd99c4 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -34,10 +34,15 @@ def ort_run(model_name: str, model, inputs): def assert_allclose(outputs, expected_outputs, rtol=1e-2, atol=1e-2): - for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)): + print("Comparing outputs...") + for i, (baseline_output, optimized_output) in reversed(list(enumerate(zip(expected_outputs, outputs)))) : try: + print(f"output {i} shapes: {baseline_output.shape}, {optimized_output.shape}") np.testing.assert_equal(baseline_output.shape, optimized_output.shape) np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol) except AssertionError as e: + diff_mask = ~np.isclose(baseline_output, optimized_output, rtol=rtol, atol=atol) + diff = np.where(diff_mask, 'X', ' ') + print(diff) print(f"Failed for output {i} with rtol={rtol} and atol={atol}\n{e}") raise diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py index 
eb2c21d259..ccc6f1c880 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py @@ -182,7 +182,7 @@ def test_equivalence(self): class GQA2(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.batchsize = 2 + self.batchsize = 1 self.seqlen = 8 self.kv_seqlen = self.seqlen self.past_seqlen = 16 @@ -229,9 +229,10 @@ def fused_model_script(self): def gqa(query, key, value, past_key, past_value, cos, sin): # Generate seqlens_k and total_seqlen inputs for GQA: # In this test case, all batch elements have same sequence length. - - total_seqlen = op.Shape(query, start=1, end=2) - total_seqlen_int32 = op.Cast(total_seqlen, to=6) + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + total_seqlen_int32 = op.Cast(total_seq_length, to=6) total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) batchsize = op.Shape(query, start=0, end=1) seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) @@ -248,6 +249,7 @@ def gqa(query, key, value, past_key, past_value, cos, sin): sin, num_heads=H, kv_num_heads=Hkv, + do_rotary=1, ) return attn, past_key, past_value @@ -262,6 +264,7 @@ def expanded_model_script(self): Dh = [self.headsize] G = [self.num_groups] minus_1 = [-1] # inferred dimension in Reshape op + plus_1 = [1] @script() def gqa(query, key, value, past_key, past_value, cos, sin): @@ -271,16 +274,18 @@ def gqa(query, key, value, past_key, past_value, cos, sin): # based on what is observed in Phi models generated by the exporter. B = op.Shape(query, start=0, end=1) S = op.Shape(query, start=1, end=2) - past_seq_length_1D = op.Shape(past_key, start=2, end=3) - past_seq_length = op.Squeeze(past_seq_length_1D, 0) - S_0D = op.Squeeze(S, 0) - total_seq_length = op.Add(past_seq_length, S_0D) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + # past_seq_length = op.Squeeze(past_seq_length_1D, [0]) + # S_0D = op.Squeeze(S,[0]) + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) shape_BSD = op.Concat(B, S, minus_1, axis=0) - shape_BHkvGSDh = op.Concat(B, Hkv, G, S, Dh, axis=0) - shape_BHSDh = op.Concat(B, H, S, Dh, axis=0) - shape_SS = op.Concat(S, S, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0) + + shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0) + shape_ST = op.Concat(S, total_seq_length, axis=0) # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. # D is different for Q and K/V (not reflected in the names, unfortunately). 
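The assert_allclose change in this patch also prints a character map of mismatching elements when the comparison fails, which makes it easier to see whether a failure is localized. The idea, reduced to a standalone numpy snippet with made-up values:

import numpy as np

expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
actual = np.array([[1.0, 2.5], [3.0, 4.0]], dtype=np.float32)

diff_mask = ~np.isclose(expected, actual, rtol=1e-2, atol=1e-2)
print(np.where(diff_mask, "X", " "))
# [[' ' 'X']
#  [' ' ' ']]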
@@ -296,19 +301,22 @@ def gqa(query, key, value, past_key, past_value, cos, sin): value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) # Concat past and do rotary embedding - position_ids = op.Range(past_seq_length, total_seq_length, 1) - position_ids_q = op.Unsqueeze(position_ids, [0]) - position_ids_k = op.Unsqueeze(position_ids, [0]) + position_ids_1d = op.Range(past_seq_length, total_seq_length, 1) + position_ids_2d = op.Unsqueeze(position_ids_1d, [0]) + tile_B_1 = op.Concat(B, plus_1, axis=0) + position_ids = op.Tile(position_ids_2d, tile_B_1) + # position_ids_q = op.Tile(position_ids, B) # op.Unsqueeze(position_ids, [0]) + # position_ids_k = op.Tile(position_ids, B) # op.Unsqueeze(position_ids, [0]) query_BHSDh_rope = msft_op.RotaryEmbedding( query_BHSDh, - position_ids_q, + position_ids, cos, sin, ) key_BHkvSDh_rope = msft_op.RotaryEmbedding( key_BHkvSDh, - position_ids_k, + position_ids, cos, sin, ) @@ -326,9 +334,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin): value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) # Generate a causal mask where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] - all_min = op.ConstantOfShape(shape_SS, value=minval_tp) + all_min = op.ConstantOfShape(shape_ST, value=minval_tp) + past_0D = op.Squeeze(past_seq_length) one = op.Constant(value_int=1) - mask = op.Trilu(all_min, one, upper=1) + past_plus_1 = op.Add(past_0D, one) + mask = op.Trilu(all_min, past_plus_1, upper=1) # Now, compute attention: key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) @@ -343,7 +353,7 @@ def gqa(query, key, value, past_key, past_value, cos, sin): # Reshape back to original shape: attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) - return attention_BSD, key_BHkvSDh, value_BHkvSDh + return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh return gqa @@ -365,8 +375,8 @@ def to_proto(self, model_script): ), output_types=( FLOAT["B", "S", D], - FLOAT["B", Hkv, "S", Dh], - FLOAT["B", Hkv, "S", Dh], + FLOAT["B", Hkv, "ST", Dh], + FLOAT["B", Hkv, "ST", Dh], ), ) @@ -385,8 +395,8 @@ def test_equivalence(self): ) outputs2 = session.run(None, inputs) - self.assertEqual(len(outputs1), len(outputs2)) - assert_allclose(outputs1, outputs2) + self.assertEqual(len(outputs2), len(outputs1)) + assert_allclose(outputs2, outputs1) if __name__ == "__main__": From edf289f7e7a93d9facc9cbe20a64f4b1eaaa4c8e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 29 Mar 2025 11:09:04 -0700 Subject: [PATCH 19/28] Cleanup test case --- .../rewriter/ort_fusions/_test_utils.py | 6 +- .../rewriter/ort_fusions/gqa_basic_test.py | 73 +++++++++++-------- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index e76edd99c4..314618123c 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -35,14 +35,16 @@ def ort_run(model_name: str, model, inputs): def assert_allclose(outputs, expected_outputs, rtol=1e-2, atol=1e-2): print("Comparing outputs...") - for i, (baseline_output, optimized_output) in reversed(list(enumerate(zip(expected_outputs, outputs)))) : + for i, (baseline_output, optimized_output) in reversed( + list(enumerate(zip(expected_outputs, outputs))) + ): try: print(f"output {i} shapes: {baseline_output.shape}, {optimized_output.shape}") np.testing.assert_equal(baseline_output.shape, 
optimized_output.shape) np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol) except AssertionError as e: diff_mask = ~np.isclose(baseline_output, optimized_output, rtol=rtol, atol=atol) - diff = np.where(diff_mask, 'X', ' ') + diff = np.where(diff_mask, "X", " ") print(diff) print(f"Failed for output {i} with rtol={rtol} and atol={atol}\n{e}") raise diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py index ccc6f1c880..88dbd0d32e 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py @@ -139,7 +139,7 @@ def gqa(query, key, value): attn_weight = op.Softmax(masked_attn_score, axis=-1) attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) - # Reshape back to original shape: + # Reshape back to BSD format attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) return attention_BSD, key_BHkvSDh, value_BHkvSDh @@ -163,13 +163,13 @@ def to_proto(self, model_script): def test_equivalence(self): inputs = self.random_inputs() - fused_model = self.to_proto(self.fused_model_script()) # self.fused_model() + fused_model = self.to_proto(self.fused_model_script()) session = ort.InferenceSession( fused_model.SerializeToString(), providers=("CPUExecutionProvider",) ) outputs1 = session.run(None, inputs) - expanded_model = self.to_proto(self.expanded_model_script()) # self.expanded_model() + expanded_model = self.to_proto(self.expanded_model_script()) session = ort.InferenceSession( expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) ) @@ -182,17 +182,43 @@ def test_equivalence(self): class GQA2(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.batchsize = 1 + # Config parameters + self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? self.seqlen = 8 self.kv_seqlen = self.seqlen self.past_seqlen = 16 self.headsize = 16 self.num_heads = 20 self.kv_num_heads = 10 + + # Computed config parameters self.hidden_size = self.headsize * self.num_heads self.kv_hidden_size = self.headsize * self.kv_num_heads self.num_groups = self.num_heads // self.kv_num_heads + # Abbreviations + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.headsize + Hkv = self.kv_num_heads + + # Input/output types have some parameters as dynamic (even though the + # test case instance has specific values above). + self.input_types = ( + FLOAT["B", "S", D], # query + FLOAT["B", "S", Dkv], # key + FLOAT["B", "S", Dkv], # value + FLOAT["B", Hkv, "Skvp", Dh], # past_key + FLOAT["B", Hkv, "Skvp", Dh], # past_value + FLOAT["max_seqlen", Dh // 2], # cos + FLOAT["max_seqlen", Dh // 2], # sin + ) + self.output_types = ( + FLOAT["B", "S", D], # attention + FLOAT["B", Hkv, "ST", Dh], # present_key + FLOAT["B", Hkv, "ST", Dh], # present_value + ) + def random_inputs(self): B = self.batchsize S = self.seqlen @@ -333,7 +359,8 @@ def gqa(query, key, value, past_key, past_value, cos, sin): value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) - # Generate a causal mask where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] + # Generate causal mask: + # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] 
all_min = op.ConstantOfShape(shape_ST, value=minval_tp) past_0D = op.Squeeze(past_seq_length) one = op.Constant(value_int=1) @@ -350,46 +377,30 @@ def gqa(query, key, value, past_key, past_value, cos, sin): attn_weight = op.Softmax(masked_attn_score, axis=-1) attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) - # Reshape back to original shape: + # Reshape back to BSD format attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh return gqa - def to_proto(self, model_script): - D = self.hidden_size - Dkv = self.kv_hidden_size - Dh = self.headsize - Hkv = self.kv_num_heads - - return model_script.to_model_proto( - input_types=( - FLOAT["B", "S", D], - FLOAT["B", "S", Dkv], - FLOAT["B", "S", Dkv], - FLOAT["B", Hkv, "Skvp", Dh], - FLOAT["B", Hkv, "Skvp", Dh], - FLOAT["max_seqlen", Dh // 2], - FLOAT["max_seqlen", Dh // 2], - ), - output_types=( - FLOAT["B", "S", D], - FLOAT["B", Hkv, "ST", Dh], - FLOAT["B", Hkv, "ST", Dh], - ), - ) - def test_equivalence(self): inputs = self.random_inputs() - fused_model = self.to_proto(self.fused_model_script()) # self.fused_model() + fused_model = self.fused_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) session = ort.InferenceSession( fused_model.SerializeToString(), providers=("CPUExecutionProvider",) ) outputs1 = session.run(None, inputs) - expanded_model = self.to_proto(self.expanded_model_script()) # self.expanded_model() + expanded_model = self.expanded_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) session = ort.InferenceSession( expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) ) From 1efdb269b4832419ae84b9d626ea3fbc2aef3fb2 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 29 Mar 2025 11:37:14 -0700 Subject: [PATCH 20/28] Remove debug print --- onnxscript/rewriter/ort_fusions/_test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index cd168e2b3d..e1a6be338d 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -36,7 +36,6 @@ def ort_run(model_name: str, model, inputs): def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4): for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)): try: - print(f"output {i} shapes: {baseline_output.shape}, {optimized_output.shape}") np.testing.assert_equal(baseline_output.shape, optimized_output.shape) np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol) except AssertionError as e: From 13a71c038412461258f744f17ab43b67d39aa3af Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 30 Mar 2025 13:40:14 -0700 Subject: [PATCH 21/28] Minor cleanup --- .../rewriter/ort_fusions/gqa_basic_test.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py index 88dbd0d32e..c53f4a6880 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py @@ -36,6 +36,22 @@ def __init__(self, *args, **kwargs): self.kv_hidden_size = self.headsize * self.kv_num_heads self.num_groups = self.num_heads // self.kv_num_heads + # Abbreviations + 
D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.headsize + Hkv = self.kv_num_heads + + query_type = FLOAT["B", "S", D] + key_type = FLOAT["B", "S", Dkv] + value_type = FLOAT["B", "S", Dkv] + self.input_types = (query_type, key_type, value_type) + + attention_type = FLOAT["B", "S", D] + past_key_type = FLOAT["B", Hkv, "S", Dh] + past_value_type = FLOAT["B", Hkv, "S", Dh] + self.output_types = (attention_type, past_key_type, past_value_type) + def random_inputs(self): B = self.batchsize S = self.seqlen @@ -146,30 +162,22 @@ def gqa(query, key, value): return gqa - def to_proto(self, model_script): - D = self.hidden_size - Dkv = self.kv_hidden_size - Dh = self.headsize - Hkv = self.kv_num_heads - return model_script.to_model_proto( - input_types=(FLOAT["B", "S", D], FLOAT["B", "S", Dkv], FLOAT["B", "S", Dkv]), - output_types=( - FLOAT["B", "S", D], - FLOAT["B", Hkv, "S", Dh], - FLOAT["B", Hkv, "S", Dh], - ), - ) - def test_equivalence(self): inputs = self.random_inputs() - fused_model = self.to_proto(self.fused_model_script()) + fused_model = self.fused_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) session = ort.InferenceSession( fused_model.SerializeToString(), providers=("CPUExecutionProvider",) ) outputs1 = session.run(None, inputs) - expanded_model = self.to_proto(self.expanded_model_script()) + expanded_model = self.expanded_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) session = ort.InferenceSession( expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) ) From e3dadc9a8777d24c67a5009f7dfe4820445eaf60 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 1 Apr 2025 16:23:07 -0700 Subject: [PATCH 22/28] Add causal mask pattern --- onnxscript/rewriter/ort_fusions/gqa.py | 40 ++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 1d1f451fd3..4724a9873f 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -45,6 +45,39 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) return True +def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111): + seq_len = op.Shape(input_ids, end=2, start=1) + seq_len_0D = op.Squeeze(seq_len) + + past_seq_len = op.Shape(past_kv_cache, end=3, start=2) + past_seq_len_0D = op.Squeeze(past_seq_len) + + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) + total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But using it for pattern-matching against + # generated onnx model. 
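The causal_mask_pattern being added is meant to match the mask subgraph that the exporter emits for Phi-style models: an all-minimum (S, total_len + 1) matrix, a Greater comparison between absolute positions (as a row) and the current query positions (as a column), and a trailing Slice that drops the extra column introduced by the +1. A hedged numpy rendering of what that subgraph computes (variable names are mine, not the pattern's):

import numpy as np

past_len, seq_len = 3, 2
total = past_len + seq_len
minval = -3.4028235e38  # float32 lowest, as in the pattern

mask_all_min = np.full((seq_len, total + 1), minval, dtype=np.float32)  # Expand
row = np.arange(total + 1)                            # Range(0, total + 1)
col = np.arange(past_len, total).reshape(-1, 1)       # Range(past, total) as a column
mask = mask_all_min * (row > col).astype(np.float32)  # Greater, Cast, Mul
mask = mask[None, None, :, :]                         # Unsqueeze to (1, 1, S, total + 1)
mask = mask[:, :, :, :total]                          # Slice off the extra column

# Row i (query position past_len + i) masks every position beyond past_len + i.
print(mask[0, 0])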
+ total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + mask_all_min = op.Expand(-3.4028235e38, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) + + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. + mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + return mask_B1ST + + class GroupQueryAttention(pattern.RewriteRuleClassBase): def __init__(self): super().__init__("GQA") @@ -56,14 +89,15 @@ def pattern( query_BSD, key_BSDkv, value_BSDkv, - mask, past_key, past_value, - # position_ids, + input_ids, past_seq_length, total_seq_length, cos, sin, + some_kv_cache, + shape_B111, ): # Reshape query from (B, S, D) to (B, S, H, D/H) query_BSHDh = op.Reshape( @@ -138,6 +172,8 @@ def pattern( value_seq_BHkvGSkvDh, _allow_other_inputs=True, _outputs=["value_seq_BHSkvDh"] ) + mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111) + attention_BHSDh = op.SDPA( query_BHSDh_rope, key_seq_BHDhSkv, From b7a13980604c7c5841e970c4f7e2ce4bbde76924 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 2 Apr 2025 12:43:13 -0700 Subject: [PATCH 23/28] Add test case --- onnxscript/optimizer/_constant_folding.py | 12 +- onnxscript/rewriter/ort_fusions/gqa.py | 36 +- .../rewriter/ort_fusions/gqa_basic_test.py | 422 --------------- onnxscript/rewriter/ort_fusions/gqa_test.py | 498 +++++++++++++++++- 4 files changed, 517 insertions(+), 451 deletions(-) delete mode 100644 onnxscript/rewriter/ort_fusions/gqa_basic_test.py diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 034724a3a8..cf856d5422 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -405,17 +405,13 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: shape = _get_input(node, 1) if input is None or shape is None: return None + input_shape = input.shape - if input_shape is None: - return None - # input_shape_dims = list(input_shape.dims) - # if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in input_shape_dims): - # return None shape_value = state.get_shape_value(shape) - if shape_value is None: + + if shape_value is None or input_shape is None: return None - # target_shape_dims = list(shape_value.dims) - # if input_shape_dims == target_shape_dims: + # No need to check for special values like -1, 0, etc. 
here if _same_shape(input_shape, shape_value): return op.Identity(input) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 4724a9873f..ccd12ddd07 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -196,13 +196,12 @@ def check( query_BSD, key_BSDkv, value_BSDkv, - mask, past_key, past_value, query_BHSDh_rope, key_BHkvSDh_rope, - # query_BSHDh, - # key_BSHkvDh, + query_BSHDh, + key_BSHkvDh, # value_BSHkvDh, **_, ): @@ -239,6 +238,16 @@ def check( # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: # or check Reshape's shape-input value + result = pattern.MatchResult() + num_heads = _ir_utils.get_dim(query_BSHDh, 2) + kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) + if not isinstance(num_heads, int): + return result.fail("Unable to determine num_heads value", query_BSHDh) + if not isinstance(kv_num_heads, int): + return result.fail("Unable to determine kv_num_heads value", key_BSHkvDh) + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + # Rotary embedding attributes query_rotary_attributes = query_BHSDh_rope.producer().attributes key_rotary_attributes = key_BHkvSDh_rope.producer().attributes @@ -258,18 +267,11 @@ def rewrite( value_BSDkv, past_key, past_value, - query_BSHDh, - key_BSHkvDh, total_seq_length, cos, sin, **_, ): - num_heads = _ir_utils.get_dim(query_BSHDh, 2) - kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) - if not isinstance(num_heads, int) or not isinstance(kv_num_heads, int): - return None - total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32) one_0D = op.Constant(value_int=1) one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32) @@ -288,10 +290,10 @@ def rewrite( cos, sin, # mask, # TODO: this is not a valid input for GQA - num_heads=num_heads, - kv_num_heads=kv_num_heads, + num_heads=self.num_heads, + kv_num_heads=self.kv_num_heads, do_rotary=1, - rotary_interleaved=self._interleaved.value, + rotary_interleaved=self._interleaved, # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap _domain="com.microsoft", _outputs=3, @@ -303,8 +305,10 @@ def rewrite( gqa_rules = pattern.RewriteRuleSet([_rule1]) -def fuse_gqa(model: ir.Model) -> int: +def fuse_gqa(model: ir.Model, debug: bool = False) -> int: count = gqa_rules.apply_to_model(model) - print(f"GQA count: {count}") - # remove_unused_nodes(model) + if debug and count == 0: + tracer = pattern.MatchingTracer() + gqa_rules.apply_to_model(model, tracer=tracer) + tracer.report() return count diff --git a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py b/onnxscript/rewriter/ort_fusions/gqa_basic_test.py deleted file mode 100644 index c53f4a6880..0000000000 --- a/onnxscript/rewriter/ort_fusions/gqa_basic_test.py +++ /dev/null @@ -1,422 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import math -import unittest - -import numpy as np -import onnx -import onnxruntime as ort -import torch - -import onnxscript -from onnxscript import FLOAT, script -from onnxscript import opset18 as op -from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose - -msft_op = onnxscript.values.Opset("com.microsoft", 1) - -# This is a basic test that verifies that a proposed expanded computation is equivalent to -# ORT's GQA (for the specific configuration considered). 
- -# Simple GQA: no rotary embedding, no past key/value, no cos/sin cache, no seqlens/total_seqlen - - -class GQA1(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.batchsize = 2 - self.seqlen = 8 - self.kv_seqlen = self.seqlen - self.headsize = 16 - self.num_heads = 20 - self.kv_num_heads = 10 - self.hidden_size = self.headsize * self.num_heads - self.kv_hidden_size = self.headsize * self.kv_num_heads - self.num_groups = self.num_heads // self.kv_num_heads - - # Abbreviations - D = self.hidden_size - Dkv = self.kv_hidden_size - Dh = self.headsize - Hkv = self.kv_num_heads - - query_type = FLOAT["B", "S", D] - key_type = FLOAT["B", "S", Dkv] - value_type = FLOAT["B", "S", Dkv] - self.input_types = (query_type, key_type, value_type) - - attention_type = FLOAT["B", "S", D] - past_key_type = FLOAT["B", Hkv, "S", Dh] - past_value_type = FLOAT["B", Hkv, "S", Dh] - self.output_types = (attention_type, past_key_type, past_value_type) - - def random_inputs(self): - B = self.batchsize - S = self.seqlen - D = self.hidden_size - Dkv = self.kv_hidden_size - query = np.random.rand(B, S, D).astype(np.float32) - key = np.random.rand(B, S, Dkv).astype(np.float32) - value = np.random.rand(B, S, Dkv).astype(np.float32) - return { - "query": query, - "key": key, - "value": value, - } - - def fused_model_script(self): - H = self.num_heads - Hkv = self.kv_num_heads - - @script() - def gqa(query, key, value): - # Generate seqlens_k and total_seqlen inputs for GQA: - # In this test case, all batch elements have same sequence length. - - total_seqlen = op.Shape(query, start=1, end=2) - total_seqlen_int32 = op.Cast(total_seqlen, to=6) - total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) - batchsize = op.Shape(query, start=0, end=1) - seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) - - attn, past_key, past_value = msft_op.GroupQueryAttention( - query, - key, - value, - None, - None, - seqlens_k, - total_seqlen_int32, - num_heads=H, - kv_num_heads=Hkv, - ) - return attn, past_key, past_value - - return gqa - - def expanded_model_script(self): - scale_factor = math.sqrt(math.sqrt(self.headsize)) - minval = torch.finfo(torch.float32).min - minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) - H = [self.num_heads] - Hkv = [self.kv_num_heads] - Dh = [self.headsize] - G = [self.num_groups] - minus_1 = [-1] # inferred dimension in Reshape op - - @script() - def gqa(query, key, value): - # Shapes used for Reshape ops. Note that we have a few different options on how shapes are - # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate - # existing dimension and one inferred dimension respectively). The following shapes are - # based on what is observed in Phi models generated by the exporter. - B = op.Shape(query, start=0, end=1) - S = op.Shape(query, start=1, end=2) - shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) - shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) - shape_BSD = op.Concat(B, S, minus_1, axis=0) - shape_BHkvGSDh = op.Concat(B, Hkv, G, S, Dh, axis=0) - shape_BHSDh = op.Concat(B, H, S, Dh, axis=0) - shape_SS = op.Concat(S, S, axis=0) - - # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. - # D is different for Q and K/V (not reflected in the names, unfortunately). - # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only - # one sequence length (S) for all Q, K, and V (with no cache). 
- query_BSHDh = op.Reshape(query, shape_BSHDh) - query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) - - key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) - key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) - key_BHkv1SDh = op.Unsqueeze(key_BHkvSDh, 2) - key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) - key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) - - value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) - value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) - value_BHkv1SDh = op.Unsqueeze(value_BHkvSDh, 2) - value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) - value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) - - # Generate a causal mask where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] - all_min = op.ConstantOfShape(shape_SS, value=minval_tp) - one = op.Constant(value_int=1) - mask = op.Trilu(all_min, one, upper=1) - - # Now, compute attention: - key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) - divisor = op.Constant(value_float=scale_factor) - scaled_query = op.Div(query_BHSDh, divisor) - scaled_key = op.Div(key_transposed, divisor) - attn_score = op.MatMul(scaled_query, scaled_key) - masked_attn_score = op.Add(attn_score, mask) - attn_weight = op.Softmax(masked_attn_score, axis=-1) - attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) - - # Reshape back to BSD format - attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) - attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) - return attention_BSD, key_BHkvSDh, value_BHkvSDh - - return gqa - - def test_equivalence(self): - inputs = self.random_inputs() - - fused_model = self.fused_model_script().to_model_proto( - input_types=self.input_types, - output_types=self.output_types, - ) - session = ort.InferenceSession( - fused_model.SerializeToString(), providers=("CPUExecutionProvider",) - ) - outputs1 = session.run(None, inputs) - - expanded_model = self.expanded_model_script().to_model_proto( - input_types=self.input_types, - output_types=self.output_types, - ) - session = ort.InferenceSession( - expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) - ) - outputs2 = session.run(None, inputs) - - self.assertEqual(len(outputs1), len(outputs2)) - assert_allclose(outputs1, outputs2) - - -class GQA2(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Config parameters - self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? - self.seqlen = 8 - self.kv_seqlen = self.seqlen - self.past_seqlen = 16 - self.headsize = 16 - self.num_heads = 20 - self.kv_num_heads = 10 - - # Computed config parameters - self.hidden_size = self.headsize * self.num_heads - self.kv_hidden_size = self.headsize * self.kv_num_heads - self.num_groups = self.num_heads // self.kv_num_heads - - # Abbreviations - D = self.hidden_size - Dkv = self.kv_hidden_size - Dh = self.headsize - Hkv = self.kv_num_heads - - # Input/output types have some parameters as dynamic (even though the - # test case instance has specific values above). 
- self.input_types = ( - FLOAT["B", "S", D], # query - FLOAT["B", "S", Dkv], # key - FLOAT["B", "S", Dkv], # value - FLOAT["B", Hkv, "Skvp", Dh], # past_key - FLOAT["B", Hkv, "Skvp", Dh], # past_value - FLOAT["max_seqlen", Dh // 2], # cos - FLOAT["max_seqlen", Dh // 2], # sin - ) - self.output_types = ( - FLOAT["B", "S", D], # attention - FLOAT["B", Hkv, "ST", Dh], # present_key - FLOAT["B", Hkv, "ST", Dh], # present_value - ) - - def random_inputs(self): - B = self.batchsize - S = self.seqlen - Skvp = self.past_seqlen - D = self.hidden_size - Dkv = self.kv_hidden_size - Dh = self.headsize - Hkv = self.kv_num_heads - total_seqlen = S + Skvp - max_seqlen = total_seqlen - query = np.random.rand(B, S, D).astype(np.float32) - key = np.random.rand(B, S, Dkv).astype(np.float32) - value = np.random.rand(B, S, Dkv).astype(np.float32) - past_key = np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32) - past_value = np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32) - cos = np.random.rand(max_seqlen, Dh // 2).astype(np.float32) - sin = np.random.rand(max_seqlen, Dh // 2).astype(np.float32) - - return { - "query": query, - "key": key, - "value": value, - "past_key": past_key, - "past_value": past_value, - "cos": cos, - "sin": sin, - } - - def fused_model_script(self): - H = self.num_heads - Hkv = self.kv_num_heads - - @script() - def gqa(query, key, value, past_key, past_value, cos, sin): - # Generate seqlens_k and total_seqlen inputs for GQA: - # In this test case, all batch elements have same sequence length. - S = op.Shape(query, start=1, end=2) - past_seq_length = op.Shape(past_key, start=2, end=3) - total_seq_length = op.Add(past_seq_length, S) - total_seqlen_int32 = op.Cast(total_seq_length, to=6) - total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) - batchsize = op.Shape(query, start=0, end=1) - seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) - - attn, past_key, past_value = msft_op.GroupQueryAttention( - query, - key, - value, - past_key, - past_value, - seqlens_k, - total_seqlen_int32, - cos, - sin, - num_heads=H, - kv_num_heads=Hkv, - do_rotary=1, - ) - return attn, past_key, past_value - - return gqa - - def expanded_model_script(self): - scale_factor = math.sqrt(math.sqrt(self.headsize)) - minval = torch.finfo(torch.float32).min - minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) - H = [self.num_heads] - Hkv = [self.kv_num_heads] - Dh = [self.headsize] - G = [self.num_groups] - minus_1 = [-1] # inferred dimension in Reshape op - plus_1 = [1] - - @script() - def gqa(query, key, value, past_key, past_value, cos, sin): - # Shapes used for Reshape ops. Note that we have a few different options on how shapes are - # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate - # existing dimension and one inferred dimension respectively). The following shapes are - # based on what is observed in Phi models generated by the exporter. 
- B = op.Shape(query, start=0, end=1) - S = op.Shape(query, start=1, end=2) - past_seq_length = op.Shape(past_key, start=2, end=3) - total_seq_length = op.Add(past_seq_length, S) - # past_seq_length = op.Squeeze(past_seq_length_1D, [0]) - # S_0D = op.Squeeze(S,[0]) - - shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) - shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) - shape_BSD = op.Concat(B, S, minus_1, axis=0) - shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0) - - shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0) - shape_ST = op.Concat(S, total_seq_length, axis=0) - - # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. - # D is different for Q and K/V (not reflected in the names, unfortunately). - # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only - # one sequence length (S) for all Q, K, and V (with no cache). - query_BSHDh = op.Reshape(query, shape_BSHDh) - query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) - - key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) - key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) - - value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) - value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) - - # Concat past and do rotary embedding - position_ids_1d = op.Range(past_seq_length, total_seq_length, 1) - position_ids_2d = op.Unsqueeze(position_ids_1d, [0]) - tile_B_1 = op.Concat(B, plus_1, axis=0) - position_ids = op.Tile(position_ids_2d, tile_B_1) - # position_ids_q = op.Tile(position_ids, B) # op.Unsqueeze(position_ids, [0]) - # position_ids_k = op.Tile(position_ids, B) # op.Unsqueeze(position_ids, [0]) - - query_BHSDh_rope = msft_op.RotaryEmbedding( - query_BHSDh, - position_ids, - cos, - sin, - ) - key_BHkvSDh_rope = msft_op.RotaryEmbedding( - key_BHkvSDh, - position_ids, - cos, - sin, - ) - key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) - - value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) - - # Now, expand from shared heads to all heads - key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) - key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) - key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) - - value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) - value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) - value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) - - # Generate causal mask: - # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] 
- all_min = op.ConstantOfShape(shape_ST, value=minval_tp) - past_0D = op.Squeeze(past_seq_length) - one = op.Constant(value_int=1) - past_plus_1 = op.Add(past_0D, one) - mask = op.Trilu(all_min, past_plus_1, upper=1) - - # Now, compute attention: - key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) - divisor = op.Constant(value_float=scale_factor) - scaled_query = op.Div(query_BHSDh_rope, divisor) - scaled_key = op.Div(key_transposed, divisor) - attn_score = op.MatMul(scaled_query, scaled_key) - masked_attn_score = op.Add(attn_score, mask) - attn_weight = op.Softmax(masked_attn_score, axis=-1) - attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) - - # Reshape back to BSD format - attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) - attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) - - return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh - - return gqa - - def test_equivalence(self): - inputs = self.random_inputs() - - fused_model = self.fused_model_script().to_model_proto( - input_types=self.input_types, - output_types=self.output_types, - ) - session = ort.InferenceSession( - fused_model.SerializeToString(), providers=("CPUExecutionProvider",) - ) - outputs1 = session.run(None, inputs) - - expanded_model = self.expanded_model_script().to_model_proto( - input_types=self.input_types, - output_types=self.output_types, - ) - session = ort.InferenceSession( - expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) - ) - outputs2 = session.run(None, inputs) - - self.assertEqual(len(outputs2), len(outputs1)) - assert_allclose(outputs2, outputs1) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 77cd5f346c..f1d380f8f1 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -1,11 +1,499 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations -"""Testing GQA fusion.""" +import math +import unittest -from onnxscript import script +import numpy as np +import onnx +import onnxruntime as ort +import torch +import onnxscript +import onnxscript.ir as ir +import onnxscript.ir.passes.common.shape_inference as shape_inference +import onnxscript.optimizer +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose +from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa +from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa -@script() -def _gqa_prompt_script(query, key, value): - pass +msft_op = onnxscript.values.Opset("com.microsoft", 1) + +# Test cases for GroupQueryAttention (GQA) fusion. 
+ +# Simple GQA: no rotary embedding, no past key/value, no cos/sin cache, no seqlens/total_seqlen + + +class GQA1(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.batchsize = 2 + self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.head_size = 16 + self.num_heads = 20 + self.kv_num_heads = 10 + self.hidden_size = self.head_size * self.num_heads + self.kv_hidden_size = self.head_size * self.kv_num_heads + self.num_groups = self.num_heads // self.kv_num_heads + + # Abbreviations + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.head_size + Hkv = self.kv_num_heads + + query_type = FLOAT["B", "S", D] + key_type = FLOAT["B", "S", Dkv] + value_type = FLOAT["B", "S", Dkv] + self.input_types = (query_type, key_type, value_type) + + attention_type = FLOAT["B", "S", D] + past_key_type = FLOAT["B", Hkv, "S", Dh] + past_value_type = FLOAT["B", Hkv, "S", Dh] + self.output_types = (attention_type, past_key_type, past_value_type) + + def random_inputs(self): + B = self.batchsize + S = self.seqlen + D = self.hidden_size + Dkv = self.kv_hidden_size + query = np.random.rand(B, S, D).astype(np.float32) + key = np.random.rand(B, S, Dkv).astype(np.float32) + value = np.random.rand(B, S, Dkv).astype(np.float32) + return { + "query": query, + "key": key, + "value": value, + } + + def fused_model_script(self): + H = self.num_heads + Hkv = self.kv_num_heads + + @script() + def gqa(query, key, value): + # Generate seqlens_k and total_seqlen inputs for GQA: + # In this test case, all batch elements have same sequence length. + + total_seqlen = op.Shape(query, start=1, end=2) + total_seqlen_int32 = op.Cast(total_seqlen, to=6) + total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) + batchsize = op.Shape(query, start=0, end=1) + seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) + + attn, past_key, past_value = msft_op.GroupQueryAttention( + query, + key, + value, + None, + None, + seqlens_k, + total_seqlen_int32, + num_heads=H, + kv_num_heads=Hkv, + ) + return attn, past_key, past_value + + return gqa + + def expanded_model_script(self): + scale_factor = math.sqrt(math.sqrt(self.head_size)) + minval = torch.finfo(torch.float32).min + minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) + H = [self.num_heads] + Hkv = [self.kv_num_heads] + Dh = [self.head_size] + G = [self.num_groups] + minus_1 = [-1] # inferred dimension in Reshape op + + @script() + def gqa(query, key, value): + # Shapes used for Reshape ops. Note that we have a few different options on how shapes are + # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate + # existing dimension and one inferred dimension respectively). The following shapes are + # based on what is observed in Phi models generated by the exporter. + B = op.Shape(query, start=0, end=1) + S = op.Shape(query, start=1, end=2) + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSD = op.Concat(B, S, minus_1, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, S, Dh, axis=0) + shape_BHSDh = op.Concat(B, H, S, Dh, axis=0) + shape_SS = op.Concat(S, S, axis=0) + + # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. + # D is different for Q and K/V (not reflected in the names, unfortunately). + # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only + # one sequence length (S) for all Q, K, and V (with no cache). 
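As the comment above says, the inputs arrive in BSD layout and are first split into heads and transposed into BHSDh before attention. A minimal numpy sketch of that layout change, using the GQA1 query dimensions (variable names are illustrative):

import numpy as np

B, S, H, Dh = 2, 8, 20, 16
D = H * Dh  # 320

x_BSD = np.random.rand(B, S, D).astype(np.float32)

# Reshape (B, S, D) -> (B, S, H, Dh), then Transpose to (B, H, S, Dh).
x_BSHDh = x_BSD.reshape(B, S, H, Dh)
x_BHSDh = x_BSHDh.transpose(0, 2, 1, 3)

print(x_BHSDh.shape)  # (2, 20, 8, 16)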
+ query_BSHDh = op.Reshape(query, shape_BSHDh) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + key_BHkv1SDh = op.Unsqueeze(key_BHkvSDh, 2) + key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) + key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) + + value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + value_BHkv1SDh = op.Unsqueeze(value_BHkvSDh, 2) + value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) + value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) + + # Generate a causal mask where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] + all_min = op.ConstantOfShape(shape_SS, value=minval_tp) + one = op.Constant(value_int=1) + mask = op.Trilu(all_min, one, upper=1) + + # Now, compute attention: + key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=scale_factor) + scaled_query = op.Div(query_BHSDh, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) + + # Reshape back to BSD format + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + return attention_BSD, key_BHkvSDh, value_BHkvSDh + + return gqa + + def test_equivalence(self): + inputs = self.random_inputs() + + fused_model = self.fused_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs1 = session.run(None, inputs) + + expanded_model = self.expanded_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs2 = session.run(None, inputs) + + self.assertEqual(len(outputs1), len(outputs2)) + assert_allclose(outputs1, outputs2) + + +class GQA2(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Config parameters + self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? + self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.past_seqlen = 16 + self.head_size = 16 + self.num_heads = 20 + self.kv_num_heads = 10 + + # Computed config parameters + self.hidden_size = self.head_size * self.num_heads + self.kv_hidden_size = self.head_size * self.kv_num_heads + assert (self.num_heads % self.kv_num_heads) == 0, ( + "num_heads must be divisible by kv_num_heads" + ) + self.num_groups = self.num_heads // self.kv_num_heads + + # Abbreviations + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.head_size + Hkv = self.kv_num_heads + + # Input/output types have some dimensions as dynamic (even though the + # test case instance has specific values above). 
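+        # Symbolic names are used for the dynamic axes: "B" (batch), "S" (current
+        # sequence length), "Skvp" (past key/value sequence length), "ST" (total,
+        # i.e. past + current, sequence length), and "max_seqlen" (length of the
+        # cos/sin cache). The remaining dimensions are fixed by the config above.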
+ self.input_types = ( + FLOAT["B", "S", D], # query + FLOAT["B", "S", Dkv], # key + FLOAT["B", "S", Dkv], # value + FLOAT["B", Hkv, "Skvp", Dh], # past_key + FLOAT["B", Hkv, "Skvp", Dh], # past_value + FLOAT["max_seqlen", Dh // 2], # cos + FLOAT["max_seqlen", Dh // 2], # sin + ) + self.output_types = ( + FLOAT["B", "S", D], # attention + FLOAT["B", Hkv, "ST", Dh], # present_key + FLOAT["B", Hkv, "ST", Dh], # present_value + ) + + def random_inputs(self): + B = self.batchsize + S = self.seqlen + Skvp = self.past_seqlen + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.head_size + Hkv = self.kv_num_heads + total_seqlen = S + Skvp + max_seqlen = total_seqlen + query = np.random.rand(B, S, D).astype(np.float32) + key = np.random.rand(B, S, Dkv).astype(np.float32) + value = np.random.rand(B, S, Dkv).astype(np.float32) + past_key = np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32) + past_value = np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32) + cos = np.random.rand(max_seqlen, Dh // 2).astype(np.float32) + sin = np.random.rand(max_seqlen, Dh // 2).astype(np.float32) + + return { + "query": query, + "key": key, + "value": value, + "past_key": past_key, + "past_value": past_value, + "cos": cos, + "sin": sin, + } + + def fused_model_script(self): + H = self.num_heads + Hkv = self.kv_num_heads + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin): + # Generate seqlens_k and total_seqlen inputs for GQA: + # In this test case, all batch elements have same sequence length. + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + total_seqlen_int32 = op.Cast(total_seq_length, to=6) + total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) + batchsize = op.Shape(query, start=0, end=1) + seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) + + attn, past_key, past_value = msft_op.GroupQueryAttention( + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_seqlen_int32, + cos, + sin, + num_heads=H, + kv_num_heads=Hkv, + do_rotary=1, + ) + return attn, past_key, past_value + + return gqa + + def expanded_model_script(self): + scale_factor = math.sqrt(math.sqrt(self.head_size)) + minval = torch.finfo(torch.float32).min + minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) + H = [self.num_heads] + Hkv = [self.kv_num_heads] + Dh = [self.head_size] + G = [self.num_groups] + minus_1 = [-1] # inferred dimension in Reshape op + plus_1 = [1] + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin): + # Shapes used for Reshape ops. Note that we have a few different options on how shapes are + # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate + # existing dimension and one inferred dimension respectively). The following shapes are + # based on what is observed in Phi models generated by the exporter. 
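+            # Each shape_* value below is a 1-D shape tensor assembled with Concat
+            # from per-axis sizes extracted via Shape plus constant entries; the -1
+            # entry lets Reshape infer the head-count axis from the element count.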
+ B = op.Shape(query, start=0, end=1) + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + # past_seq_length = op.Squeeze(past_seq_length_1D, [0]) + # S_0D = op.Squeeze(S,[0]) + + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSD = op.Concat(B, S, minus_1, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0) + + shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0) + shape_ST = op.Concat(S, total_seq_length, axis=0) + + # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. + # D is different for Q and K/V (not reflected in the names, unfortunately). + # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only + # one sequence length (S) for all Q, K, and V (with no cache). + query_BSHDh = op.Reshape(query, shape_BSHDh) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + + value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + # Concat past and do rotary embedding + position_ids_1d = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids_1d, [0]) + position_ids_k = op.Unsqueeze(position_ids_1d, [0]) + + # Note: The above code pattern for position-ids is from exported Phi model. + # However, for use with ORT's RotaryEmbedding it needs the following for batchsize > 1 + # But we currently target batchsize=1 since GQA requires it when there is a past key/value. + # + # position_ids_2d = op.Unsqueeze(position_ids_1d, [0]) + # tile_B_1 = op.Concat(B, plus_1, axis=0) + # position_ids = op.Tile(position_ids_2d, tile_B_1) + + query_BHSDh_rope = msft_op.RotaryEmbedding( + query_BHSDh, + position_ids_q, + cos, + sin, + ) + key_BHkvSDh_rope = msft_op.RotaryEmbedding( + key_BHkvSDh, + position_ids_k, + cos, + sin, + ) + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + + # Now, expand from shared heads to all heads + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) + key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) + + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) + value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) + + # Generate causal mask: + # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] + seq_len = op.Shape(query, end=2, start=1) + seq_len_0D = op.Squeeze(seq_len) + + past_seq_len_0D = op.Squeeze(past_seq_length) + + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) + total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But duplicating same logic here. 
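+            # The mask construction below works row by row: row i corresponds to
+            # query position past_seq_len + i, and its columns 0..total_seq_len are
+            # compared (Greater) against that position, so minval is placed exactly
+            # where a query would attend to a future key position; the final Slice
+            # drops the extra +1 column again.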
+ total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_val = op.Constant(value=minval_tp) + mask_all_min = op.Expand(min_val, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + shape_B111 = op.Concat(B, plus_1, plus_1, plus_1, axis=0) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) + + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. + mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + + # Now, compute attention: + key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=scale_factor) + scaled_query = op.Div(query_BHSDh_rope, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask_B1ST) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) + + # Reshape back to BSD format + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + + return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh + + return gqa + + def test_equivalence(self): + inputs = self.random_inputs() + + fused_model = self.fused_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs1 = session.run(None, inputs) + + expanded_model = self.expanded_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs2 = session.run(None, inputs) + + self.assertEqual(len(outputs2), len(outputs1)) + assert_allclose(outputs2, outputs1) + + # Shape inference doesn't handle ORT contrib ops: so, provide type/shape + # for outputs of RotaryEmbedding op. 
+ query_BHSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "query_BHSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.seqlen, self.head_size], + ) + key_BHkvSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "key_BHkvSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.kv_num_heads, self.seqlen, self.head_size], + ) + query_BSHDh_value_info = onnx.helper.make_tensor_value_info( + "query_BSHDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.num_heads, self.head_size], + ) + key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info( + "key_BSHkvDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.kv_num_heads, self.head_size], + ) + expanded_model.graph.value_info.extend( + [ + query_BHSDh_rope_value_info, + key_BHkvSDh_rope_value_info, + query_BSHDh_value_info, + key_BSHkvDh_value_info, + ] + ) + + expanded_model_ir = ir.serde.from_proto(expanded_model) + inferred_model = shape_inference.infer_shapes(expanded_model_ir) + onnxscript.optimizer.optimize(inferred_model) + + count = fuse_sdpa(inferred_model, debug=True) + self.assertEqual(count, 1) + + count = fuse_gqa(inferred_model, debug=True) + self.assertEqual(count, 1) + + +if __name__ == "__main__": + unittest.main() From 3bdc0b2ec14147d311fbffc226cada0a28558818 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 2 Apr 2025 15:07:39 -0700 Subject: [PATCH 24/28] Complete GQA tests --- onnxscript/rewriter/ort_fusions/gqa_test.py | 96 ++++++++++++--------- 1 file changed, 57 insertions(+), 39 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index f1d380f8f1..7de28359f2 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -70,7 +70,7 @@ def random_inputs(self): "value": value, } - def fused_model_script(self): + def target_model_script(self): H = self.num_heads Hkv = self.kv_num_heads @@ -100,7 +100,7 @@ def gqa(query, key, value): return gqa - def expanded_model_script(self): + def source_model_script(self): scale_factor = math.sqrt(math.sqrt(self.head_size)) minval = torch.finfo(torch.float32).min minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) @@ -169,26 +169,26 @@ def gqa(query, key, value): def test_equivalence(self): inputs = self.random_inputs() - fused_model = self.fused_model_script().to_model_proto( + source_model = self.source_model_script().to_model_proto( input_types=self.input_types, output_types=self.output_types, ) session = ort.InferenceSession( - fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + source_model.SerializeToString(), providers=("CPUExecutionProvider",) ) - outputs1 = session.run(None, inputs) + source_model_outputs = session.run(None, inputs) - expanded_model = self.expanded_model_script().to_model_proto( + target_model = self.target_model_script().to_model_proto( input_types=self.input_types, output_types=self.output_types, ) session = ort.InferenceSession( - expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) + target_model.SerializeToString(), providers=("CPUExecutionProvider",) ) - outputs2 = session.run(None, inputs) + target_model_outputs = session.run(None, inputs) - self.assertEqual(len(outputs1), len(outputs2)) - assert_allclose(outputs1, outputs2) + self.assertEqual(len(target_model_outputs), len(source_model_outputs)) + assert_allclose(target_model_outputs, source_model_outputs) class GQA2(unittest.TestCase): @@ -244,25 +244,18 @@ def 
random_inputs(self): Hkv = self.kv_num_heads total_seqlen = S + Skvp max_seqlen = total_seqlen - query = np.random.rand(B, S, D).astype(np.float32) - key = np.random.rand(B, S, Dkv).astype(np.float32) - value = np.random.rand(B, S, Dkv).astype(np.float32) - past_key = np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32) - past_value = np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32) - cos = np.random.rand(max_seqlen, Dh // 2).astype(np.float32) - sin = np.random.rand(max_seqlen, Dh // 2).astype(np.float32) return { - "query": query, - "key": key, - "value": value, - "past_key": past_key, - "past_value": past_value, - "cos": cos, - "sin": sin, + "query": np.random.rand(B, S, D).astype(np.float32), + "key": np.random.rand(B, S, Dkv).astype(np.float32), + "value": np.random.rand(B, S, Dkv).astype(np.float32), + "past_key": np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32), + "past_value": np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32), + "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), } - def fused_model_script(self): + def target_model_script(self): H = self.num_heads Hkv = self.kv_num_heads @@ -296,7 +289,7 @@ def gqa(query, key, value, past_key, past_value, cos, sin): return gqa - def expanded_model_script(self): + def source_model_script(self): scale_factor = math.sqrt(math.sqrt(self.head_size)) minval = torch.finfo(torch.float32).min minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) @@ -430,31 +423,47 @@ def gqa(query, key, value, past_key, past_value, cos, sin): return gqa def test_equivalence(self): + """Test that the source and target models produce the same outputs.""" inputs = self.random_inputs() - fused_model = self.fused_model_script().to_model_proto( + source_model = self.source_model_script().to_model_proto( input_types=self.input_types, output_types=self.output_types, ) session = ort.InferenceSession( - fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + source_model.SerializeToString(), providers=("CPUExecutionProvider",) ) - outputs1 = session.run(None, inputs) + source_model_outputs = session.run(None, inputs) - expanded_model = self.expanded_model_script().to_model_proto( + target_model = self.target_model_script().to_model_proto( input_types=self.input_types, output_types=self.output_types, ) session = ort.InferenceSession( - expanded_model.SerializeToString(), providers=("CPUExecutionProvider",) + target_model.SerializeToString(), providers=("CPUExecutionProvider",) ) - outputs2 = session.run(None, inputs) + target_model_outputs = session.run(None, inputs) - self.assertEqual(len(outputs2), len(outputs1)) - assert_allclose(outputs2, outputs1) + self.assertEqual(len(source_model_outputs), len(target_model_outputs)) + assert_allclose(source_model_outputs, target_model_outputs) - # Shape inference doesn't handle ORT contrib ops: so, provide type/shape - # for outputs of RotaryEmbedding op. + def test_fusion(self): + """Test that GQA fusion is successful on source model and produces an equivalent model.""" + inputs = self.random_inputs() + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + # Some shapes need to be present in input model for fusion to be successful. 
+ # (i) Shape inference doesn't handle handle ORT contrib ops. + # (ii) TODO: investigate if Reshape(..., ["B", "S", -1, Dh]) handled precisely + # by shape inference. query_BHSDh_rope_value_info = onnx.helper.make_tensor_value_info( "query_BHSDh_rope", onnx.TensorProto.FLOAT, @@ -475,7 +484,7 @@ def test_equivalence(self): onnx.TensorProto.FLOAT, ["B", self.seqlen, self.kv_num_heads, self.head_size], ) - expanded_model.graph.value_info.extend( + source_model.graph.value_info.extend( [ query_BHSDh_rope_value_info, key_BHkvSDh_rope_value_info, @@ -484,8 +493,8 @@ def test_equivalence(self): ] ) - expanded_model_ir = ir.serde.from_proto(expanded_model) - inferred_model = shape_inference.infer_shapes(expanded_model_ir) + source_model_ir = ir.serde.from_proto(source_model) + inferred_model = shape_inference.infer_shapes(source_model_ir) onnxscript.optimizer.optimize(inferred_model) count = fuse_sdpa(inferred_model, debug=True) @@ -494,6 +503,15 @@ def test_equivalence(self): count = fuse_gqa(inferred_model, debug=True) self.assertEqual(count, 1) + fused_model = ir.serde.to_proto(inferred_model) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs3 = session.run(None, inputs) + + self.assertEqual(len(outputs3), len(source_model_outputs)) + assert_allclose(outputs3, source_model_outputs) + if __name__ == "__main__": unittest.main() From 9545e5fcdbdaf6b8543eb53e62462bc262dd03b6 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 2 Apr 2025 15:52:17 -0700 Subject: [PATCH 25/28] Cleanup --- onnxscript/rewriter/ort_fusions/gqa_test.py | 204 ++------------------ 1 file changed, 16 insertions(+), 188 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 7de28359f2..6e8b1c8775 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -22,176 +22,10 @@ msft_op = onnxscript.values.Opset("com.microsoft", 1) -# Test cases for GroupQueryAttention (GQA) fusion. +# Test case for GroupQueryAttention (GQA) fusion. 
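+#
+# The "source" model script spells the GQA computation out as an ONNX subgraph
+# (reshape/transpose, rotary embeddings, KV-cache concatenation, group expansion,
+# masked scaled dot-product attention), while the "target" model script uses the
+# com.microsoft GroupQueryAttention op directly. The tests below check that the
+# two models agree numerically, and that applying the SDPA and GQA fusions to the
+# source model yields a model that still produces the same outputs.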
-# Simple GQA: no rotary embedding, no past key/value, no cos/sin cache, no seqlens/total_seqlen - -class GQA1(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.batchsize = 2 - self.seqlen = 8 - self.kv_seqlen = self.seqlen - self.head_size = 16 - self.num_heads = 20 - self.kv_num_heads = 10 - self.hidden_size = self.head_size * self.num_heads - self.kv_hidden_size = self.head_size * self.kv_num_heads - self.num_groups = self.num_heads // self.kv_num_heads - - # Abbreviations - D = self.hidden_size - Dkv = self.kv_hidden_size - Dh = self.head_size - Hkv = self.kv_num_heads - - query_type = FLOAT["B", "S", D] - key_type = FLOAT["B", "S", Dkv] - value_type = FLOAT["B", "S", Dkv] - self.input_types = (query_type, key_type, value_type) - - attention_type = FLOAT["B", "S", D] - past_key_type = FLOAT["B", Hkv, "S", Dh] - past_value_type = FLOAT["B", Hkv, "S", Dh] - self.output_types = (attention_type, past_key_type, past_value_type) - - def random_inputs(self): - B = self.batchsize - S = self.seqlen - D = self.hidden_size - Dkv = self.kv_hidden_size - query = np.random.rand(B, S, D).astype(np.float32) - key = np.random.rand(B, S, Dkv).astype(np.float32) - value = np.random.rand(B, S, Dkv).astype(np.float32) - return { - "query": query, - "key": key, - "value": value, - } - - def target_model_script(self): - H = self.num_heads - Hkv = self.kv_num_heads - - @script() - def gqa(query, key, value): - # Generate seqlens_k and total_seqlen inputs for GQA: - # In this test case, all batch elements have same sequence length. - - total_seqlen = op.Shape(query, start=1, end=2) - total_seqlen_int32 = op.Cast(total_seqlen, to=6) - total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) - batchsize = op.Shape(query, start=0, end=1) - seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) - - attn, past_key, past_value = msft_op.GroupQueryAttention( - query, - key, - value, - None, - None, - seqlens_k, - total_seqlen_int32, - num_heads=H, - kv_num_heads=Hkv, - ) - return attn, past_key, past_value - - return gqa - - def source_model_script(self): - scale_factor = math.sqrt(math.sqrt(self.head_size)) - minval = torch.finfo(torch.float32).min - minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) - H = [self.num_heads] - Hkv = [self.kv_num_heads] - Dh = [self.head_size] - G = [self.num_groups] - minus_1 = [-1] # inferred dimension in Reshape op - - @script() - def gqa(query, key, value): - # Shapes used for Reshape ops. Note that we have a few different options on how shapes are - # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate - # existing dimension and one inferred dimension respectively). The following shapes are - # based on what is observed in Phi models generated by the exporter. - B = op.Shape(query, start=0, end=1) - S = op.Shape(query, start=1, end=2) - shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) - shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) - shape_BSD = op.Concat(B, S, minus_1, axis=0) - shape_BHkvGSDh = op.Concat(B, Hkv, G, S, Dh, axis=0) - shape_BHSDh = op.Concat(B, H, S, Dh, axis=0) - shape_SS = op.Concat(S, S, axis=0) - - # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. - # D is different for Q and K/V (not reflected in the names, unfortunately). - # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only - # one sequence length (S) for all Q, K, and V (with no cache). 
- query_BSHDh = op.Reshape(query, shape_BSHDh) - query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) - - key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) - key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) - key_BHkv1SDh = op.Unsqueeze(key_BHkvSDh, 2) - key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) - key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) - - value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) - value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) - value_BHkv1SDh = op.Unsqueeze(value_BHkvSDh, 2) - value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) - value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) - - # Generate a causal mask where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] - all_min = op.ConstantOfShape(shape_SS, value=minval_tp) - one = op.Constant(value_int=1) - mask = op.Trilu(all_min, one, upper=1) - - # Now, compute attention: - key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) - divisor = op.Constant(value_float=scale_factor) - scaled_query = op.Div(query_BHSDh, divisor) - scaled_key = op.Div(key_transposed, divisor) - attn_score = op.MatMul(scaled_query, scaled_key) - masked_attn_score = op.Add(attn_score, mask) - attn_weight = op.Softmax(masked_attn_score, axis=-1) - attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) - - # Reshape back to BSD format - attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) - attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) - return attention_BSD, key_BHkvSDh, value_BHkvSDh - - return gqa - - def test_equivalence(self): - inputs = self.random_inputs() - - source_model = self.source_model_script().to_model_proto( - input_types=self.input_types, - output_types=self.output_types, - ) - session = ort.InferenceSession( - source_model.SerializeToString(), providers=("CPUExecutionProvider",) - ) - source_model_outputs = session.run(None, inputs) - - target_model = self.target_model_script().to_model_proto( - input_types=self.input_types, - output_types=self.output_types, - ) - session = ort.InferenceSession( - target_model.SerializeToString(), providers=("CPUExecutionProvider",) - ) - target_model_outputs = session.run(None, inputs) - - self.assertEqual(len(target_model_outputs), len(source_model_outputs)) - assert_allclose(target_model_outputs, source_model_outputs) - - -class GQA2(unittest.TestCase): +class GQAFusionTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Config parameters @@ -212,10 +46,15 @@ def __init__(self, *args, **kwargs): self.num_groups = self.num_heads // self.kv_num_heads # Abbreviations + B = self.batchsize + S = self.seqlen + P = self.past_seqlen D = self.hidden_size Dkv = self.kv_hidden_size Dh = self.head_size Hkv = self.kv_num_heads + total_seqlen = S + P + max_seqlen = total_seqlen # Input/output types have some dimensions as dynamic (even though the # test case instance has specific values above). 
@@ -223,34 +62,23 @@ def __init__(self, *args, **kwargs): FLOAT["B", "S", D], # query FLOAT["B", "S", Dkv], # key FLOAT["B", "S", Dkv], # value - FLOAT["B", Hkv, "Skvp", Dh], # past_key - FLOAT["B", Hkv, "Skvp", Dh], # past_value + FLOAT["B", Hkv, "P", Dh], # past_key + FLOAT["B", Hkv, "P", Dh], # past_value FLOAT["max_seqlen", Dh // 2], # cos FLOAT["max_seqlen", Dh // 2], # sin ) self.output_types = ( FLOAT["B", "S", D], # attention - FLOAT["B", Hkv, "ST", Dh], # present_key - FLOAT["B", Hkv, "ST", Dh], # present_value + FLOAT["B", Hkv, "T", Dh], # present_key + FLOAT["B", Hkv, "T", Dh], # present_value ) - def random_inputs(self): - B = self.batchsize - S = self.seqlen - Skvp = self.past_seqlen - D = self.hidden_size - Dkv = self.kv_hidden_size - Dh = self.head_size - Hkv = self.kv_num_heads - total_seqlen = S + Skvp - max_seqlen = total_seqlen - - return { + self.inputs = { "query": np.random.rand(B, S, D).astype(np.float32), "key": np.random.rand(B, S, Dkv).astype(np.float32), "value": np.random.rand(B, S, Dkv).astype(np.float32), - "past_key": np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32), - "past_value": np.random.rand(B, Hkv, Skvp, Dh).astype(np.float32), + "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32), "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), } @@ -424,7 +252,7 @@ def gqa(query, key, value, past_key, past_value, cos, sin): def test_equivalence(self): """Test that the source and target models produce the same outputs.""" - inputs = self.random_inputs() + inputs = self.inputs source_model = self.source_model_script().to_model_proto( input_types=self.input_types, @@ -449,7 +277,7 @@ def test_equivalence(self): def test_fusion(self): """Test that GQA fusion is successful on source model and produces an equivalent model.""" - inputs = self.random_inputs() + inputs = self.inputs source_model = self.source_model_script().to_model_proto( input_types=self.input_types, From 97bb1c27d6e046f3528f91b7b1a061d161bbba39 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 2 Apr 2025 16:10:22 -0700 Subject: [PATCH 26/28] Address copilot fixes --- onnxscript/rewriter/ort_fusions/gqa.py | 3 +-- onnxscript/rewriter/ort_fusions/gqa_test.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index ccd12ddd07..ae0ab9496c 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -80,8 +80,7 @@ def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111): class GroupQueryAttention(pattern.RewriteRuleClassBase): def __init__(self): - super().__init__("GQA") - self.remove_nodes = False + super().__init__("GQA", remove_nodes=False) def pattern( self, diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 6e8b1c8775..4f8f9ab8ba 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -147,7 +147,6 @@ def gqa(query, key, value, past_key, past_value, cos, sin): shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0) shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0) - shape_ST = op.Concat(S, total_seq_length, axis=0) # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. # D is different for Q and K/V (not reflected in the names, unfortunately). 
From c8bbb025f710fa7ecdf38b8670bb97b2c6691004 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 2 Apr 2025 22:09:40 -0700 Subject: [PATCH 27/28] Add checks --- onnxscript/rewriter/ort_fusions/gqa.py | 44 ++++++++++---------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index ae0ab9496c..38576f3fb2 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -201,36 +201,24 @@ def check( key_BHkvSDh_rope, query_BSHDh, key_BSHkvDh, - # value_BSHkvDh, **_, ): - # bindings: dict[str, Dim] = {} - - # def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - # return not _check_shape(bindings, val, dims) - - # if no_match(query_BSD, ["B", "S", "D"]): - # return False - # if no_match(key_BSDkv, ["B", "Skv", "D"]): - # return False - # if no_match(value_BSDkv, ["B", "Skv", "D"]): - # return False - - # if no_match(past_key, ["B", "H", "Spast", "Dh"]): - # return False - # if no_match(past_value, ["B", "H", "Spast", "Dv"]): - # return False - # if no_match(query_BSHDh, ["B", "S", "H", "Dh"]): - # return False - # if no_match(key_BSHkvDh, ["B", "S", "H", "Dh"]): - # return False - # if no_match(value_BSHkvDh, ["B", "S", "H", "Dh"]): - # return False - - # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St) - # But this also, unforunately, depends on ORT version. - # TODO: check that mask is causal. Latest ORT is adding support for - # non-causal masks, but not yet for all EPs. + bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _check_shape(bindings, val, dims) + + if no_match(query_BSD, ["B", "S", "D"]): + return False + if no_match(key_BSDkv, ["B", "S", "Dkv"]): + return False + if no_match(value_BSDkv, ["B", "S", "Dkv"]): + return False + + if no_match(past_key, ["B", "Hkv", "P", "Dh"]): + return False + if no_match(past_value, ["B", "Hkv", "P", "Dv"]): + return False # TODO: verify Reshapes: # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: From 43b336844d4818be08773c350fb54c2bcabebb91 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 3 Apr 2025 14:51:00 -0700 Subject: [PATCH 28/28] Minor cleanup --- onnxscript/rewriter/ort_fusions/gqa.py | 56 +++++++++----------------- 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 38576f3fb2..477bfed6a2 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -23,10 +23,7 @@ D: input embedding dimension (hidden size) = H * Dh Dkv: key/value hidden size = Hkv * Dh -Skv: key/value sequence length (after concatenation of past and current key/value) - -In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh). -The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh). 
+T: total sequence length (after concatenation of past and current key/value) """ Dim = Union[int, ir.SymbolicDim] @@ -99,31 +96,18 @@ def pattern( shape_B111, ): # Reshape query from (B, S, D) to (B, S, H, D/H) - query_BSHDh = op.Reshape( - query_BSD, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["query_BSHDh"], - ) + query_BSHDh = op.Reshape(query_BSD, _allow_other_inputs=True, _outputs=["query_BSHDh"]) # Transpose from (B, S, H, D/H) to (B, H, S, D/H) query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) - key_BSHkvDh = op.Reshape( - key_BSDkv, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["key_BSHkvDh"], - ) + key_BSHkvDh = op.Reshape(key_BSDkv, _allow_other_inputs=True, _outputs=["key_BSHkvDh"]) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) value_BSHkvDh = op.Reshape( - value_BSDkv, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["value_BSHkvDh"], + value_BSDkv, _allow_other_inputs=True, _outputs=["value_BSHkvDh"] ) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) @@ -150,33 +134,31 @@ def pattern( ) # Concatenate past_key cache and current key, expand across heads - # that share key/value and transpose to enable dot-product attention computation. + # that share key/value. - key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) - key_seq_BHkv1SkvDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) - key_seq_BHkvGSkvDh = op.Expand(key_seq_BHkv1SkvDh, _allow_other_inputs=True) - key_seq_BHSkvDh = op.Reshape( - key_seq_BHkvGSkvDh, _allow_other_inputs=True, _outputs=["key_seq_BHSkvDh"] - ) - key_seq_BHDhSkv = op.Transpose( - key_seq_BHSkvDh, _allow_other_inputs=True, _outputs=["key_seq_BHDhSkv"] + key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) + key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, _allow_other_inputs=True) + key_seq_BHTDh = op.Reshape( + key_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["key_seq_BHTDh"] ) # Concatenate past_value cache and current value, expand across heads # that share key/value. - value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) - value_seq_BHkv1SkvDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) - value_seq_BHkvGSkvDh = op.Expand(value_seq_BHkv1SkvDh, _allow_other_inputs=True) - value_seq_BHSkvDh = op.Reshape( - value_seq_BHkvGSkvDh, _allow_other_inputs=True, _outputs=["value_seq_BHSkvDh"] + value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) + value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, _allow_other_inputs=True) + value_seq_BHTDh = op.Reshape( + value_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["value_seq_BHTDh"] ) mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111) + key_seq_BHDhT = op.Transpose(key_seq_BHTDh, perm=[0, 1, 3, 2]) attention_BHSDh = op.SDPA( query_BHSDh_rope, - key_seq_BHDhSkv, - value_seq_BHSkvDh, + key_seq_BHDhT, + value_seq_BHTDh, mask, _domain="ai.onnxruntime.fusion", ) @@ -187,7 +169,7 @@ def pattern( attention_BSD = op.Reshape( attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"] ) - return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh + return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh def check( self,