import torch

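# Disable TF32 so the FP32 PyTorch reference is numerically comparable to the
# TensorRT engines below, which are compiled with disable_tf32=True.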
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import TestCase
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention, Gemma3DecoderLayer
from transformers.models.gemma3.configuration_gemma3 import Gemma3Config
from transformers import AutoModelForCausalLM
import torch_tensorrt
from contextlib import nullcontext
import argparse
import sys
import os

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from register_sdpa import *


ATOL = 1e-5
RTOL = 1e-5

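# Load a single-layer Gemma3 model once; the tests below exercise its attention and decoder blocks.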
gemma3_model_name = "google/gemma-3-1b-it"
gemma3_model = AutoModelForCausalLM.from_pretrained(
    gemma3_model_name,
    use_cache=False,
    attn_implementation="sdpa",
    num_hidden_layers=1,
).eval().cuda()
GEMMA3_CONFIG = gemma3_model.config

def print_diff(tensor1, tensor2, prefix=""):
    """
    Print the mean absolute difference between two tensors
    """
    print(f"[{prefix}] Mean absolute diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")

def test_gemma3_attention(args):

    DTYPE = torch.float32
    if args.precision == "FP16":
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16

    # Set precision specific flags
    use_fp32_acc = False
    use_explicit_typing = False
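    # For FP16 we keep enabled_precisions at FP32 and rely on explicit typing (strongly typed
    # network) plus FP32 accumulation; BF16 is enabled directly as a TensorRT precision.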
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}

    model = gemma3_model.model.layers[0].self_attn.to(DTYPE)

    # gemma3
    hidden_states = torch.randn((1, 5, 1152), dtype=DTYPE).cuda()
    position_embeddings = (torch.randn((1, 5, 256), dtype=DTYPE).cuda(), torch.randn((1, 5, 256), dtype=DTYPE).cuda())

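    # Run the PyTorch reference, then export with a dynamic sequence-length dimension (dim 1)
    # so the compiled engine accepts variable-length inputs.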
    pyt_output = model(hidden_states, position_embeddings, None)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[hidden_states, position_embeddings, None],
            enabled_precisions=enabled_precisions,
            disable_tf32=True,
            use_fp32_acc=use_fp32_acc,
            use_explicit_typing=use_explicit_typing,
            debug=args.debug,
        )
    trt_output = trt_model(hidden_states, position_embeddings, None)

    if isinstance(pyt_output, tuple):
        print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
    else:
        print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
        assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)

def test_gemma3_attention_with_static_cache(args):

    import static_cache_v2
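    # Importing static_cache_v2 registers a lowering pass (defined in static_cache_v2.py) that adds
    # KV-cache tensors plus start/end indices and an is_causal flag as inputs/outputs of the compiled graph.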
    DTYPE = torch.float32
    model = gemma3_model.model.layers[0].self_attn.to(DTYPE)

    # Inputs
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda()
    position_embeddings = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda())
    key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    start_idx = 0
    end_idx = ISL
    is_causal = True

    pyt_output = model(hidden_states, position_embeddings, None)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal],
            enabled_precisions={torch.float32},
            disable_tf32=True,
            debug=args.debug,
            # offload_module_to_cpu=True,
            use_python_runtime=True,
        )

    # Test Prefill
    trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal)
    print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]")

    # Test Generate
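    # Decode one token at a time through the TRT KV cache and compare each step against a
    # no-cache PyTorch forward pass over the full (growing) sequence.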
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda()
        position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda())
        # Concatenate the current hidden_states with the previous ones
        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
        position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))

        is_causal = False
        out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None)
        out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal)
        out_pyt = out_no_cache[:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")

        hidden_states = hidden_states_full
        position_embeddings = position_embeddings_full

def test_gemma3_decoder(args):

    DTYPE = torch.float32
    if args.precision == "FP16":
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16
    model = gemma3_model.model.layers[0].to(DTYPE)
    # model.self_attn.is_sliding = False

    # gemma3
    hidden_states = torch.randn((1, 6, 1152), dtype=DTYPE).cuda()
    position_embeddings_global = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda())
    position_embeddings_local = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda())

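    # The Gemma3 decoder layer takes separate rotary embeddings for the global and local
    # (sliding-window) attention paths; export with a dynamic sequence length on all inputs.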
    pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), ({1: seq_len}, {1: seq_len}))
    ep = torch.export.export(model, (hidden_states, position_embeddings_global, position_embeddings_local), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[hidden_states, position_embeddings_global, position_embeddings_local],
            enabled_precisions={torch.float32},
            debug=args.debug,
        )
    trt_output = trt_model(hidden_states, position_embeddings_global, position_embeddings_local)

    print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}")
    # breakpoint()
    assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)

def test_gemma3_decoder_with_static_cache(args):

    class Gemma3DecoderLayerBlock(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.config = GEMMA3_CONFIG
            self.decoder = Gemma3DecoderLayer(
                config=self.config,
                layer_idx=0)
            self.model = model

        def forward(self, hidden_states, position_embeddings):
            # The wrapped Gemma3DecoderLayer expects separate global and local rotary embeddings;
            # this test reuses the same pair for both paths.
            return self.model(
                hidden_states,
                position_embeddings_global=position_embeddings,
                position_embeddings_local=position_embeddings,
            )

    DTYPE = torch.float32
    model = Gemma3DecoderLayerBlock(gemma3_model.model.layers[0].to(DTYPE))

    import static_cache_v2
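    # As above, importing static_cache_v2 registers the static KV-cache lowering pass.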
    # Inputs
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda()
    position_embeddings = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda())
    key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    start_idx = 0
    end_idx = ISL
    is_causal = True

    pyt_output = model(hidden_states, position_embeddings)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
    ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal],
            enabled_precisions={torch.float32},
            disable_tf32=True,
            debug=args.debug,
            # offload_module_to_cpu=True,
            use_python_runtime=True,
        )

    # Test Prefill
    trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal)
    print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]")

    # Test Generate
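    # Decode one token at a time with the TRT KV cache and compare each step against a
    # no-cache PyTorch forward over the full sequence.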
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda()
        position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda())
        # Concatenate the current hidden_states with the previous ones
        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
        position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))

        is_causal = False
        out_no_cache = model(hidden_states_full, position_embeddings_full)

        out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
        out_pyt = out_no_cache[0][:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
        hidden_states = hidden_states_full
        position_embeddings = position_embeddings_full

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run test cases for gemma3 attention and decoder"
    )
    arg_parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug (default: False)"
    )
    arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
    args = arg_parser.parse_args()
    with torch.inference_mode():
        # test_gemma3_attention(args)
        # test_gemma3_attention_with_static_cache(args)
        test_gemma3_decoder(args)
        # test_gemma3_decoder_with_static_cache(args)