Commit 0b6714e

[Dev] Remove calculation of padding token in moe routing loss (NVIDIA#2121)
Co-authored-by: Li Tao <[email protected]>
1 parent: 46b5505

15 files changed (+646, -90 lines)


megatron/core/extensions/transformer_engine.py

Lines changed: 1 addition & 1 deletion
@@ -1851,7 +1851,7 @@ def forward_post_hook(module, *_) -> None:
                 "TEFusedMLP module does not support submodules with post-backward hooks"
             )
 
-    def forward(self, hidden_states: torch.Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Optional[Tensor]]:
         """Forward."""
 
         # Construct fused impl if needed

megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 2 additions & 0 deletions
@@ -305,6 +305,7 @@ def __init__(
         extra_block_kwargs=None,
         runtime_gather_output: Optional[bool] = None,
         loss_mask: Optional[Tensor] = None,
+        padding_mask=None,
     ):
         """Initialize the schedule plan of all Transformer layers' sub-modules.
 
@@ -347,6 +348,7 @@ def __init__(
         self._model_chunk_state.mtp_hidden_states = None
         self._model_chunk_state.loss_mask = loss_mask
         self._model_chunk_state.packed_seq_params = packed_seq_params
+        self._model_chunk_state.padding_mask = padding_mask
         self._model_chunk_state.extra_block_kwargs = extra_block_kwargs
         self._model_chunk_state.runtime_gather_output = runtime_gather_output
         self._model_chunk_state.model = model

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 14 additions & 7 deletions
@@ -120,13 +120,19 @@ def forward_impl(self):
         if not self.gpt_model.pre_process:
             self.chunk_state.decoder_input = self.gpt_model.decoder.input_tensor
         # Run GPTModel._preprocess
-        decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
-            self.gpt_model._preprocess(
-                input_ids=self.chunk_state.input_ids,
-                position_ids=self.chunk_state.position_ids,
-                decoder_input=self.chunk_state.decoder_input,
-                packed_seq_params=self.chunk_state.packed_seq_params,
-            )
+        (
+            decoder_input,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
+            sequence_len_offset,
+            padding_mask,
+        ) = self.gpt_model._preprocess(
+            input_ids=self.chunk_state.input_ids,
+            position_ids=self.chunk_state.position_ids,
+            decoder_input=self.chunk_state.decoder_input,
+            packed_seq_params=self.chunk_state.packed_seq_params,
+            padding_mask=self.chunk_state.padding_mask,
         )
 
         # Saved for later use
@@ -135,6 +141,7 @@ def forward_impl(self):
         self.chunk_state.rotary_pos_cos = rotary_pos_cos
         self.chunk_state.rotary_pos_sin = rotary_pos_sin
         self.chunk_state.sequence_len_offset = sequence_len_offset
+        self.chunk_state.padding_mask = padding_mask
         return decoder_input

megatron/core/models/gpt/gpt_model.py

Lines changed: 33 additions & 4 deletions
@@ -284,6 +284,7 @@ def _preprocess(
         decoder_input: Tensor = None,
         inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
+        padding_mask: Optional[Tensor] = None,
     ):
         """Preprocesses inputs for the transformer decoder.
 
@@ -300,7 +301,20 @@ def _preprocess(
         if decoder_input is not None:
             pass
         elif self.pre_process:
+            if padding_mask is not None:
+                assert padding_mask.shape == input_ids.shape, (
+                    f"padding_mask shape {padding_mask.shape} does not match "
+                    f"input_ids shape {input_ids.shape}"
+                )
             decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
+            if padding_mask is not None and self.config.sequence_parallel:
+                padding_mask = (
+                    tensor_parallel.scatter_to_sequence_parallel_region(
+                        padding_mask.transpose(0, 1).contiguous()
+                    )
+                    .transpose(0, 1)
+                    .contiguous()
+                )
         else:
             # intermediate stage of pipeline
             # decoder will get hidden_states from encoder.input_tensor
@@ -403,6 +417,7 @@ def _preprocess(
             rotary_pos_cos,
             rotary_pos_sin,
             sequence_len_offset,
+            padding_mask,
         )
         if rotary_pos_cos_sin is not None:
             # only in the case of flashinfer fused rope will we
@@ -446,6 +461,7 @@ def forward(
         *,
         inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
+        padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Forward function of the GPT Model This function passes the input tensors
         through the embedding layer, and then the decoder and finally into the post
@@ -456,6 +472,9 @@ def forward(
         Args:
             runtime_gather_output (bool): Gather output at runtime. Default None means
                 `parallel_output` arg in the constructor will be used.
+            padding_mask (Tensor, optional): Padding mask for MoE routing.
+                Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
+                Only used for MoE layers to exclude padding tokens from routing computations.
         """
         if self.config.fine_grained_activation_offloading:
             self.preprocess_for_fine_grained_offloading()
@@ -468,13 +487,19 @@ def forward(
             decoder_input=decoder_input,
             inference_context=inference_context,
             packed_seq_params=packed_seq_params,
+            padding_mask=padding_mask,
         )
 
-        (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) = (
-            preproc_output[:5]
-        )
+        (
+            decoder_input,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
+            sequence_len_offset,
+            padding_mask,
+        ) = preproc_output[:6]
 
-        rotary_pos_cos_sin = preproc_output[5] if len(preproc_output) == 6 else None
+        rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None
 
         # Run decoder.
         hidden_states = self.decoder(
@@ -487,6 +512,7 @@ def forward(
             rotary_pos_cos_sin=rotary_pos_cos_sin,
             packed_seq_params=packed_seq_params,
             sequence_len_offset=sequence_len_offset,
+            padding_mask=padding_mask,
             **(extra_block_kwargs or {}),
         )
 
@@ -724,6 +750,7 @@ def build_schedule_plan(
         runtime_gather_output: Optional[bool] = None,
        inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
+        padding_mask: Optional[Tensor] = None,
     ):
         """Builds a computation schedule plan for the model.
 
@@ -749,6 +776,7 @@ def build_schedule_plan(
            inference_params (InferenceParams, optional):
                 Parameters for inference. Defaults to None.
             loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
+            padding_mask (Optional[Tensor], optional): Padding mask. Defaults to None.
 
         Returns:
             TransformerModelChunkSchedulePlan: The model chunk schedule plan.
@@ -770,6 +798,7 @@ def build_schedule_plan(
             extra_block_kwargs,
             runtime_gather_output,
             loss_mask,
+            padding_mask,
         )
 
     def sharded_state_dict(
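
The docstring above fixes the mask convention: shape [bsz, seq_length], with True marking a padding position to exclude and False a valid token. A minimal sketch of how a caller could build such a mask from input_ids, assuming a hypothetical pad_token_id (the commit itself does not prescribe how the mask is produced):

    import torch

    PAD_TOKEN_ID = 0  # hypothetical pad id; callers may derive the mask differently

    def build_padding_mask(input_ids: torch.Tensor, pad_token_id: int = PAD_TOKEN_ID) -> torch.Tensor:
        # [bsz, seq_length] bool mask: True = padding (exclude), False = valid (include)
        return input_ids == pad_token_id

    input_ids = torch.tensor([[11, 12, 13, 0, 0],
                              [21, 22, 0, 0, 0]])
    padding_mask = build_padding_mask(input_ids)
    assert padding_mask.shape == input_ids.shape  # matches the assert added in _preprocess
    # model(input_ids, position_ids, attention_mask, padding_mask=padding_mask)  # illustrative call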

megatron/core/transformer/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def __init__(
             tp_group=tp_group,
         )
 
-    def forward(self, hidden_states, per_token_scale=None):
+    def forward(self, hidden_states, per_token_scale=None, **kwargs):
         """Perform the forward pass through the MLP block."""
         # [s, b, 4 * h/p]
         nvtx_range_push(suffix="linear_fc1")

megatron/core/transformer/moe/moe_layer.py

Lines changed: 19 additions & 8 deletions
@@ -178,13 +178,13 @@ def __init__(
         self.cudagraph_tensor_store = MoECudaGraphTensorStore()
 
     @maybe_skip_or_early_return_by_cudagraph("route")
-    def route(self, hidden_states: torch.Tensor):
+    def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
         """Compute token routing for preprocessing.
 
         This method uses the router to determine which experts to send each token to,
         producing routing probabilities and a mapping.
         """
-        probs, routing_map = self.router(hidden_states)
+        probs, routing_map = self.router(hidden_states, padding_mask=padding_mask)
         return probs, routing_map
 
     @maybe_skip_or_early_return_by_cudagraph("preprocess")
@@ -270,7 +270,7 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
             output = output + shared_expert_output
         return output
 
-    def forward(self, hidden_states: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
         """Forward pass for the MoE layer.
 
         The forward pass comprises four main steps:
@@ -280,7 +280,11 @@ def forward(self, hidden_states: torch.Tensor):
         4. Combine: The outputs from the experts are combined and returned.
 
         Args:
-            hidden_states (torch.Tensor): The input tensor to the MoE layer.
+            hidden_states (torch.Tensor): The input tensor shape [seq_length, bsz, hidden_size].
+            padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
+                used for correct auxiliary loss computation for packed sequence.
+                Shape = [bsz, seq_length]. True = padding (exclude), False = valid (include).
+                Defaults to None (all tokens are valid).
 
         Returns:
             A tuple containing the output tensor and the MLP bias, if any.
@@ -291,11 +295,15 @@ def forward(self, hidden_states: torch.Tensor):
                 "are enabled without also enabling sequence parallelism."
             )
 
+        # Transpose from [bsz, seq_length] to [seq_length, bsz] to align with hidden_states
+        if padding_mask is not None:
+            padding_mask = padding_mask.transpose(0, 1).bool()
+
         # MoE forward: route -> dispatch -> compute -> combine
-        def custom_forward(hidden_states):
+        def custom_forward(hidden_states, padding_mask=None):
             try:
                 shared_expert_output = self.shared_experts_compute(hidden_states)
-                probs, routing_map = self.route(hidden_states)
+                probs, routing_map = self.route(hidden_states, padding_mask=padding_mask)
                 hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
             except MoECudaGraphPartialCaptureSignal as e:
                 # This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
@@ -318,11 +326,14 @@ def custom_forward(hidden_states):
                     tensor_parallel.random.get_cuda_rng_tracker,
                     parallel_state.get_tensor_model_parallel_group(),
                     hidden_states,
+                    padding_mask,
                 )
             else:
-                outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
+                outputs = tensor_parallel.checkpoint(
+                    custom_forward, False, hidden_states, padding_mask
+                )
         else:
-            outputs = custom_forward(hidden_states)
+            outputs = custom_forward(hidden_states, padding_mask)
 
         return outputs
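
Side note on the transpose above: hidden_states enters the MoE layer as [seq_length, bsz, hidden_size] while padding_mask arrives as [bsz, seq_length], so the mask is flipped to [seq_length, bsz] before routing. A small sketch of that alignment, with illustrative shapes only:

    import torch

    seq_length, bsz, hidden_size = 6, 2, 8
    hidden_states = torch.randn(seq_length, bsz, hidden_size)      # [s, b, h]
    padding_mask = torch.zeros(bsz, seq_length, dtype=torch.bool)  # [b, s], True = padding
    padding_mask[:, 4:] = True  # last two positions of every sample are padding

    # Same alignment step as in MoELayer.forward: [b, s] -> [s, b]
    padding_mask_sb = padding_mask.transpose(0, 1).bool()
    assert padding_mask_sb.shape == hidden_states.shape[:2]

    # Flattening tokens to [s * b, h] for routing keeps each mask entry aligned with one token.
    flat_mask = padding_mask_sb.reshape(-1)                # [s * b]
    flat_tokens = hidden_states.reshape(-1, hidden_size)   # [s * b, h]
    assert flat_mask.numel() == flat_tokens.shape[0]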

megatron/core/transformer/moe/moe_utils.py

Lines changed: 67 additions & 16 deletions
@@ -1,5 +1,4 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-
 import math
 from dataclasses import dataclass
 from typing import List, Optional, Union
@@ -11,6 +10,7 @@
 from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
+from megatron.core.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
 from megatron.core.transformer.cuda_graphs import is_graph_capturing
 from megatron.core.transformer.enums import CudaGraphScope
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -120,18 +120,34 @@ def switch_load_balancing_loss_func(
     return aux_loss
 
 
-def z_loss_func(logits, z_loss_coeff):
+def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None):
     """Encourages the router's logits to remain small to enhance stability.
     Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
 
     Args:
         logits (torch.Tensor): The logits of the router.
+        z_loss_coeff (float): The coefficient for the z-loss.
+        padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
+            Shape [num_tokens]. True = padding (exclude),
+            False = valid (include). Defaults to None.
 
     Returns:
         torch.Tensor: The logits after applying the z-loss.
     """
+    logsum = torch.logsumexp(logits, dim=-1)
+    z_loss_values = torch.square(logsum)
+
+    if padding_mask is not None:
+        # Invert padding_mask: True (padding) -> 0, False (valid) -> 1
+        valid_mask = ~padding_mask
+        # Only compute z_loss for valid (non-padding) tokens
+        z_loss_values = z_loss_values * valid_mask
+        # Compute mean over valid tokens only
+        num_valid_tokens = valid_mask.sum()
+        z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
+    else:
+        z_loss = torch.mean(z_loss_values) * z_loss_coeff
 
-    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
     return z_loss

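A standalone sketch of the masked z-loss behaviour introduced above, assuming logits of shape [num_tokens, num_experts]; with a mask, the mean is taken over valid tokens only:

    import torch

    def masked_z_loss(logits: torch.Tensor, z_loss_coeff: float, padding_mask=None) -> torch.Tensor:
        # Square of logsumexp per token, as in z_loss_func
        z = torch.square(torch.logsumexp(logits, dim=-1))
        if padding_mask is None:
            return z.mean() * z_loss_coeff
        valid = ~padding_mask                 # True = valid token
        num_valid = valid.sum().clamp(min=1)  # avoid division by zero
        return (z * valid).sum() / num_valid * z_loss_coeff

    logits = torch.randn(8, 4)
    mask = torch.tensor([False] * 6 + [True] * 2)    # last two tokens are padding
    loss_all = masked_z_loss(logits, 1e-3)           # mean over all 8 tokens
    loss_valid = masked_z_loss(logits, 1e-3, mask)   # mean over the 6 valid tokens only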

@@ -171,6 +187,28 @@ def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_
     return capacity
 
 
+def get_tokens_per_expert_and_token_count(
+    routing_map: torch.Tensor,
+    reduce_group: torch.distributed.ProcessGroup,
+    topk: int = None,
+    with_padding_mask: bool = False,
+) -> torch.Tensor:
+    """
+    Compute global_tokens_per_expert, local_num_tokens and total_num_tokens with padding mask.
+    """
+    local_tokens_per_expert = routing_map.sum(dim=0)
+    global_tokens_per_expert = reduce_from_tensor_model_parallel_region(
+        local_tokens_per_expert, reduce_group
+    )
+    if with_padding_mask:
+        local_num_tokens = local_tokens_per_expert.sum() / topk
+        total_num_tokens = global_tokens_per_expert.sum() / topk
+    else:
+        local_num_tokens = routing_map.shape[0]
+        total_num_tokens = local_num_tokens * reduce_group.size()
+    return global_tokens_per_expert, local_num_tokens, total_num_tokens
+
+
 class MoEAuxLossAutoScaler(torch.autograd.Function):
     """An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss."""
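
The with_padding_mask branch above recovers the valid-token count from the routing map itself: once padded rows are zeroed out, every remaining token contributes exactly topk entries, so summing the per-expert counts and dividing by topk gives the number of valid tokens. A single-rank sketch (the cross-rank reduce degenerates to an identity here):

    import torch

    num_tokens, num_experts, topk = 6, 4, 2
    routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
    routing_map[:4, :topk] = True  # 4 valid tokens, each routed to topk experts;
                                   # the last 2 rows stay all-False, as after masking

    tokens_per_expert = routing_map.sum(dim=0)          # per-expert token counts
    num_valid_tokens = tokens_per_expert.sum() / topk   # 8 / 2 = 4 valid tokens
    num_all_tokens = routing_map.shape[0]               # 6, the unmasked count
    print(num_valid_tokens.item(), num_all_tokens)      # 4.0 6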

@@ -629,35 +667,48 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
 
 
 def compute_routing_scores_for_aux_loss(
-    logits: torch.Tensor, topk: int, score_function: str, fused: bool = False
+    logits: torch.Tensor,
+    topk: int,
+    score_function: str,
+    fused: bool = False,
+    padding_mask: Optional[torch.Tensor] = None,
 ):
     """Compute routing scores based on the score function.
 
     Args:
         logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
-
+        padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
+            Shape [num_tokens]. True = padding (exclude),
+            False = valid (include). Defaults to None.
     Returns:
-        torch.Tensor: The normalized routing scores.
+        Tuple[torch.Tensor, torch.Tensor]: routing_map and scores.
     """
     if fused:
         if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None:
             raise ValueError(
                 "fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
             )
-        return fused_compute_score_for_moe_aux_loss(
+        routing_map, scores = fused_compute_score_for_moe_aux_loss(
            logits=logits, topk=topk, score_function=score_function
         )
-
-    if score_function == "softmax":
-        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
-    elif score_function == "sigmoid":
-        scores = torch.sigmoid(logits)
-        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
     else:
-        raise ValueError(f"Invalid score_function: {score_function}")
+        if score_function == "softmax":
+            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
+        elif score_function == "sigmoid":
+            scores = torch.sigmoid(logits)
+            scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
+        else:
+            raise ValueError(f"Invalid score_function: {score_function}")
+
+        _, top_indices = torch.topk(scores, k=topk, dim=1)
+        routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
 
-    _, top_indices = torch.topk(scores, k=topk, dim=1)
-    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
+    # Apply padding mask to scores if provided
+    if padding_mask is not None:
+        # Invert padding_mask and make True indicates valid tokens
+        valid_mask = (~padding_mask).unsqueeze(-1)
+        routing_map = routing_map * valid_mask
+        scores = scores * valid_mask
     return routing_map, scores