@@ -47,29 +47,29 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
47
47
def run_single_rank (
48
48
tensor_parallel_size ,
49
49
single_rank_forward_func ,
50
- input ,
51
- residual ,
50
+ input_list ,
51
+ residual_list ,
52
52
norm_weight ,
53
53
eps ,
54
54
hidden_size ,
55
55
dtype ,
56
56
fused_add_norm ,
57
- reference_output ,
57
+ reference_output_list ,
58
58
):
59
59
rank = tensorrt_llm .mpi_rank ()
60
60
torch .cuda .set_device (rank )
61
61
try :
62
62
single_rank_forward_func (
63
- input ,
64
- residual ,
63
+ input_list ,
64
+ residual_list ,
65
65
norm_weight ,
66
66
eps ,
67
67
hidden_size ,
68
68
dtype ,
69
69
tensor_parallel_size ,
70
70
rank ,
71
71
fused_add_norm ,
72
- reference_output ,
72
+ reference_output_list ,
73
73
)
74
74
except Exception :
75
75
traceback .print_exc ()
@@ -79,25 +79,30 @@ def run_single_rank(
79
79
80
80
@torch .inference_mode ()
81
81
def row_linear_residual_norm_fusion_forward (
82
- x : torch .Tensor ,
83
- residual : torch .Tensor ,
82
+ x_list : list [ torch .Tensor ] ,
83
+ residual_list : list [ torch .Tensor ] ,
84
84
norm_weight : torch .Tensor ,
85
85
eps : float ,
86
86
hidden_size : int ,
87
87
dtype : torch .dtype ,
88
88
tensor_parallel_size : int ,
89
89
tensor_parallel_rank : int ,
90
90
fusion : bool ,
91
- reference_output : tuple [torch .Tensor , ...],
91
+ reference_output_list : list [ tuple [torch .Tensor , ...] ],
92
92
):
93
93
94
- x = x .cuda ()
95
- residual = residual .cuda ()
94
+ # Move all tensors to GPU
95
+ x_list = [x .cuda () for x in x_list ]
96
+ residual_list = [residual .cuda () for residual in residual_list ]
96
97
norm_weight = norm_weight .cuda ()
97
- reference_output = tuple (t .cuda () for t in reference_output )
98
+ reference_output_list = [
99
+ tuple (t .cuda () for t in ref_output )
100
+ for ref_output in reference_output_list
101
+ ]
98
102
99
103
MPI .COMM_WORLD .barrier ()
100
104
105
+ # Create a single AllReduce instance to be reused for all sequence lengths
101
106
allreduce = AllReduce (
102
107
mapping = Mapping (
103
108
world_size = tensor_parallel_size ,
@@ -119,72 +124,116 @@ def func(input, residual, norm_weight, eps, enable_fusion):
119
124
residual = residual ,
120
125
norm_weight = norm_weight ,
121
126
eps = eps ,
122
- ))
127
+ ),
128
+ )
123
129
return (output , residual )
124
130
else :
125
131
output = allreduce (input )
126
132
return (output , )
127
133
128
- output = func (x .clone (), residual .clone (), norm_weight , eps , fusion )
134
+ # Process each sequence length using the same AllReduce instance
135
+ for i , (x , residual , reference_output ) in enumerate (
136
+ zip (x_list , residual_list , reference_output_list )):
137
+ output = func (x .clone (), residual .clone (), norm_weight , eps , fusion )
129
138
130
- torch .testing .assert_close (
131
- output [0 ],
132
- reference_output [0 ],
133
- rtol = 0.05 ,
134
- atol = 0.15 ,
135
- )
136
-
137
- if fusion :
138
139
torch .testing .assert_close (
139
- output [1 ],
140
- reference_output [1 ],
140
+ output [0 ],
141
+ reference_output [0 ],
141
142
rtol = 0.05 ,
142
143
atol = 0.15 ,
143
144
)
144
145
146
+ if fusion :
147
+ torch .testing .assert_close (
148
+ output [1 ],
149
+ reference_output [1 ],
150
+ rtol = 0.05 ,
151
+ atol = 0.15 ,
152
+ )
153
+
145
154
146
155
@skip_pre_blackwell
147
156
@pytest .mark .skipif (torch .cuda .device_count () < 2 ,
148
157
reason = "needs 2 GPUs to run this test" )
149
- @pytest .mark .parametrize ("seq_len" , [1 , 4 , 32 , 128 ],
150
- ids = lambda x : f"seqlen:{ x } " )
158
+ @pytest .mark .parametrize (
159
+ "seq_len" ,
160
+ [
161
+ [
162
+ 1 ,
163
+ ],
164
+ [
165
+ 4 ,
166
+ ],
167
+ [
168
+ 15 ,
169
+ ],
170
+ [
171
+ 32 ,
172
+ ],
173
+ [
174
+ 128 ,
175
+ ],
176
+ [31 , 11 , 27 , 4 ],
177
+ ],
178
+ ids = lambda x : f"seqlen:{ x } " ,
179
+ )
151
180
@pytest .mark .parametrize ("hidden_size" , [7168 ], ids = lambda x : f"hidden:{ x } " )
181
+ @pytest .mark .parametrize ("dtype" ,
182
+ [torch .float16 , torch .bfloat16 , torch .float32 ],
183
+ ids = lambda x : f"dtype:{ torch .finfo (x ).dtype } " )
152
184
@pytest .mark .parametrize (
153
185
"fusion" ,
154
186
[True , False ],
155
187
ids = ["fusion" , "no_fusion" ],
156
188
)
157
- def test_row_linear_residual_norm_fusion (seq_len , hidden_size , fusion ):
189
+ def test_row_linear_residual_norm_fusion (seq_len , hidden_size , dtype , fusion ):
158
190
159
191
torch .manual_seed (42 )
160
- dtype = torch .bfloat16
161
192
tensor_parallel_size = 2
162
193
163
- x = torch .randn ((tensor_parallel_size , seq_len , hidden_size ), dtype = dtype )
164
- residual = torch .randn ((seq_len , hidden_size ), dtype = dtype )
194
+ # Create norm_weight once (same for all sequence lengths)
165
195
norm_weight = torch .randn ((hidden_size , ), dtype = dtype )
166
196
eps = 1e-5
167
- reference_output = (torch .sum (x , dim = 0 ), )
168
- if fusion :
169
- residual_out = reference_output [0 ] + residual
170
- reference_output = (rms_norm (residual_out .to (torch .float32 ),
171
- norm_weight , eps ).to (dtype ), residual_out )
197
+
198
+ # Create lists of tensors for each sequence length
199
+ x_list = []
200
+ residual_list = []
201
+ reference_output_list = []
202
+
203
+ for seq_len_val in seq_len :
204
+ x = torch .randn ((tensor_parallel_size , seq_len_val , hidden_size ),
205
+ dtype = dtype )
206
+ residual = torch .randn ((seq_len_val , hidden_size ), dtype = dtype )
207
+ reference_output = (torch .sum (x , dim = 0 ), )
208
+ if fusion :
209
+ residual_out = reference_output [0 ] + residual
210
+ reference_output = (rms_norm (residual_out .to (torch .float32 ),
211
+ norm_weight ,
212
+ eps ).to (dtype ), residual_out )
213
+
214
+ x_list .append (x )
215
+ residual_list .append (residual )
216
+ reference_output_list .append (reference_output )
172
217
173
218
with MPIPoolExecutor (max_workers = tensor_parallel_size ) as executor :
174
219
results = executor .map (
175
220
run_single_rank ,
176
- * zip (* [(
177
- tensor_parallel_size ,
178
- row_linear_residual_norm_fusion_forward ,
179
- x [i , :, :],
180
- residual ,
181
- norm_weight ,
182
- eps ,
183
- hidden_size ,
184
- dtype ,
185
- fusion ,
186
- reference_output ,
187
- ) for i in range (tensor_parallel_size )]),
221
+ * zip (* [
222
+ (
223
+ tensor_parallel_size ,
224
+ row_linear_residual_norm_fusion_forward ,
225
+ [
226
+ x [i , :, :] for x in x_list
227
+ ], # Extract the i-th rank's data from each sequence length
228
+ residual_list ,
229
+ norm_weight ,
230
+ eps ,
231
+ hidden_size ,
232
+ dtype ,
233
+ fusion ,
234
+ reference_output_list ,
235
+ ) for i in range (tensor_parallel_size )
236
+ ]),
188
237
)
189
238
for r in results :
190
239
assert r is True
0 commit comments