Commit 3a82827

Enhance MNNVL unittest.
Signed-off-by: Shiyu Li <[email protected]>
1 parent: 9b9c631

1 file changed: +96 -47 lines


tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py

@@ -47,29 +47,29 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
 def run_single_rank(
     tensor_parallel_size,
     single_rank_forward_func,
-    input,
-    residual,
+    input_list,
+    residual_list,
     norm_weight,
     eps,
     hidden_size,
     dtype,
     fused_add_norm,
-    reference_output,
+    reference_output_list,
 ):
     rank = tensorrt_llm.mpi_rank()
     torch.cuda.set_device(rank)
     try:
         single_rank_forward_func(
-            input,
-            residual,
+            input_list,
+            residual_list,
             norm_weight,
             eps,
             hidden_size,
             dtype,
             tensor_parallel_size,
             rank,
             fused_add_norm,
-            reference_output,
+            reference_output_list,
         )
     except Exception:
         traceback.print_exc()
@@ -79,25 +79,30 @@ def run_single_rank(
 
 @torch.inference_mode()
 def row_linear_residual_norm_fusion_forward(
-    x: torch.Tensor,
-    residual: torch.Tensor,
+    x_list: list[torch.Tensor],
+    residual_list: list[torch.Tensor],
     norm_weight: torch.Tensor,
     eps: float,
     hidden_size: int,
     dtype: torch.dtype,
     tensor_parallel_size: int,
     tensor_parallel_rank: int,
     fusion: bool,
-    reference_output: tuple[torch.Tensor, ...],
+    reference_output_list: list[tuple[torch.Tensor, ...]],
 ):
 
-    x = x.cuda()
-    residual = residual.cuda()
+    # Move all tensors to GPU
+    x_list = [x.cuda() for x in x_list]
+    residual_list = [residual.cuda() for residual in residual_list]
     norm_weight = norm_weight.cuda()
-    reference_output = tuple(t.cuda() for t in reference_output)
+    reference_output_list = [
+        tuple(t.cuda() for t in ref_output)
+        for ref_output in reference_output_list
+    ]
 
     MPI.COMM_WORLD.barrier()
 
+    # Create a single AllReduce instance to be reused for all sequence lengths
     allreduce = AllReduce(
         mapping=Mapping(
             world_size=tensor_parallel_size,
@@ -119,72 +124,116 @@ def func(input, residual, norm_weight, eps, enable_fusion):
                     residual=residual,
                     norm_weight=norm_weight,
                     eps=eps,
-                ))
+                ),
+            )
             return (output, residual)
         else:
             output = allreduce(input)
             return (output, )
 
-    output = func(x.clone(), residual.clone(), norm_weight, eps, fusion)
+    # Process each sequence length using the same AllReduce instance
+    for i, (x, residual, reference_output) in enumerate(
+            zip(x_list, residual_list, reference_output_list)):
+        output = func(x.clone(), residual.clone(), norm_weight, eps, fusion)
 
-    torch.testing.assert_close(
-        output[0],
-        reference_output[0],
-        rtol=0.05,
-        atol=0.15,
-    )
-
-    if fusion:
         torch.testing.assert_close(
-            output[1],
-            reference_output[1],
+            output[0],
+            reference_output[0],
             rtol=0.05,
             atol=0.15,
         )
 
+        if fusion:
+            torch.testing.assert_close(
+                output[1],
+                reference_output[1],
                rtol=0.05,
+                atol=0.15,
+            )
+
 
 @skip_pre_blackwell
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="needs 2 GPUs to run this test")
-@pytest.mark.parametrize("seq_len", [1, 4, 32, 128],
-                         ids=lambda x: f"seqlen:{x}")
+@pytest.mark.parametrize(
+    "seq_len",
+    [
+        [
+            1,
+        ],
+        [
+            4,
+        ],
+        [
+            15,
+        ],
+        [
+            32,
+        ],
+        [
+            128,
+        ],
+        [31, 11, 27, 4],
+    ],
+    ids=lambda x: f"seqlen:{x}",
+)
 @pytest.mark.parametrize("hidden_size", [7168], ids=lambda x: f"hidden:{x}")
+@pytest.mark.parametrize("dtype",
+                         [torch.float16, torch.bfloat16, torch.float32],
+                         ids=lambda x: f"dtype:{torch.finfo(x).dtype}")
 @pytest.mark.parametrize(
     "fusion",
     [True, False],
     ids=["fusion", "no_fusion"],
 )
-def test_row_linear_residual_norm_fusion(seq_len, hidden_size, fusion):
+def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, fusion):
 
     torch.manual_seed(42)
-    dtype = torch.bfloat16
     tensor_parallel_size = 2
 
-    x = torch.randn((tensor_parallel_size, seq_len, hidden_size), dtype=dtype)
-    residual = torch.randn((seq_len, hidden_size), dtype=dtype)
+    # Create norm_weight once (same for all sequence lengths)
     norm_weight = torch.randn((hidden_size, ), dtype=dtype)
     eps = 1e-5
-    reference_output = (torch.sum(x, dim=0), )
-    if fusion:
-        residual_out = reference_output[0] + residual
-        reference_output = (rms_norm(residual_out.to(torch.float32),
-                                     norm_weight, eps).to(dtype), residual_out)
+
+    # Create lists of tensors for each sequence length
+    x_list = []
+    residual_list = []
+    reference_output_list = []
+
+    for seq_len_val in seq_len:
+        x = torch.randn((tensor_parallel_size, seq_len_val, hidden_size),
+                        dtype=dtype)
+        residual = torch.randn((seq_len_val, hidden_size), dtype=dtype)
+        reference_output = (torch.sum(x, dim=0), )
+        if fusion:
+            residual_out = reference_output[0] + residual
+            reference_output = (rms_norm(residual_out.to(torch.float32),
+                                         norm_weight,
+                                         eps).to(dtype), residual_out)
+
+        x_list.append(x)
+        residual_list.append(residual)
+        reference_output_list.append(reference_output)
 
     with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
         results = executor.map(
             run_single_rank,
-            *zip(*[(
-                tensor_parallel_size,
-                row_linear_residual_norm_fusion_forward,
-                x[i, :, :],
-                residual,
-                norm_weight,
-                eps,
-                hidden_size,
-                dtype,
-                fusion,
-                reference_output,
-            ) for i in range(tensor_parallel_size)]),
+            *zip(*[
+                (
+                    tensor_parallel_size,
+                    row_linear_residual_norm_fusion_forward,
+                    [
+                        x[i, :, :] for x in x_list
+                    ],  # Extract the i-th rank's data from each sequence length
+                    residual_list,
+                    norm_weight,
+                    eps,
+                    hidden_size,
+                    dtype,
+                    fusion,
+                    reference_output_list,
+                ) for i in range(tensor_parallel_size)
+            ]),
        )
        for r in results:
            assert r is True
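
For context, here is a minimal single-process sketch of the reference computation this test builds, assuming the file's rms_norm helper (its signature is visible in the first hunk header; its body is not part of this diff) implements standard RMSNorm. The tensor-parallel all-reduce is emulated by summing the per-rank shards over the rank dimension, and the fused path adds the residual and applies RMSNorm in float32, mirroring how reference_output is constructed above:

import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
    # Assumed body: standard RMSNorm matching the signature in the hunk header.
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return y * weight if weight is not None else y


tp_size, seq_len, hidden_size = 2, 4, 8  # toy shapes, not the test's 7168
x = torch.randn((tp_size, seq_len, hidden_size))  # one shard per rank
residual = torch.randn((seq_len, hidden_size))
norm_weight = torch.randn((hidden_size, ))

allreduce_out = torch.sum(x, dim=0)  # all-reduce is a sum over the rank dim
residual_out = allreduce_out + residual  # fused residual add
normed = rms_norm(residual_out.to(torch.float32), norm_weight, 1e-5)
print(normed.shape)  # torch.Size([4, 8])

Because the updated test feeds a list of sequence lengths (including the mixed [31, 11, 27, 4] case) through one AllReduce instance, this reference pair is precomputed per shape before the workers are spawned.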

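The executor.map(run_single_rank, *zip(*[...])) pattern is also worth a note: the test builds one argument tuple per rank, but map wants one iterable per parameter, so zip(*...) transposes rank-major tuples into parameter-major sequences. A toy illustration with hypothetical values, using plain map in place of MPIPoolExecutor.map:

def run_single_rank(tp_size, data, rank_tag):
    # Hypothetical stand-in for the real per-rank worker.
    return f"tp={tp_size} data={data} {rank_tag}"


tp_size = 2
per_rank_args = [(tp_size, [10, 20], f"rank{i}") for i in range(tp_size)]

# zip(*per_rank_args) -> (2, 2), ([10, 20], [10, 20]), ('rank0', 'rank1'):
# one iterable per parameter, which is the shape map/executor.map expects.
for result in map(run_single_rank, *zip(*per_rank_args)):
    print(result)
# tp=2 data=[10, 20] rank0
# tp=2 data=[10, 20] rank1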
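One last idiom from the new dtype parametrization: torch.finfo(x).dtype returns the dtype's name as a plain string, so the generated test ids read dtype:float16, dtype:bfloat16, dtype:float32 rather than the noisier repr of a torch.dtype object:

import torch

for dt in (torch.float16, torch.bfloat16, torch.float32):
    print(f"dtype:{torch.finfo(dt).dtype}")  # dtype:float16, dtype:bfloat16, dtype:float32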