
Commit 1068d77

Revert "[Dev] Remove calculation of padding token in moe routing loss (#2121)" (#2747)
Signed-off-by: Charlie Truong <[email protected]>
1 parent 0b6714e commit 1068d77

File tree: 15 files changed (+90, -646 lines)
megatron/core/extensions/transformer_engine.py

Lines changed: 1 addition & 1 deletion
@@ -1851,7 +1851,7 @@ def forward_post_hook(module, *_) -> None:
                 "TEFusedMLP module does not support submodules with post-backward hooks"
             )

-    def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[Tensor, Optional[Tensor]]:
         """Forward."""

         # Construct fused impl if needed

megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 0 additions & 2 deletions
@@ -305,7 +305,6 @@ def __init__(
         extra_block_kwargs=None,
         runtime_gather_output: Optional[bool] = None,
         loss_mask: Optional[Tensor] = None,
-        padding_mask=None,
     ):
         """Initialize the schedule plan of all Transformer layers' sub-modules.

@@ -348,7 +347,6 @@
         self._model_chunk_state.mtp_hidden_states = None
         self._model_chunk_state.loss_mask = loss_mask
         self._model_chunk_state.packed_seq_params = packed_seq_params
-        self._model_chunk_state.padding_mask = padding_mask
         self._model_chunk_state.extra_block_kwargs = extra_block_kwargs
         self._model_chunk_state.runtime_gather_output = runtime_gather_output
         self._model_chunk_state.model = model

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 7 additions & 14 deletions
@@ -120,19 +120,13 @@ def forward_impl(self):
         if not self.gpt_model.pre_process:
             self.chunk_state.decoder_input = self.gpt_model.decoder.input_tensor
         # Run GPTModel._preprocess
-        (
-            decoder_input,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            sequence_len_offset,
-            padding_mask,
-        ) = self.gpt_model._preprocess(
-            input_ids=self.chunk_state.input_ids,
-            position_ids=self.chunk_state.position_ids,
-            decoder_input=self.chunk_state.decoder_input,
-            packed_seq_params=self.chunk_state.packed_seq_params,
-            padding_mask=self.chunk_state.padding_mask,
+        decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
+            self.gpt_model._preprocess(
+                input_ids=self.chunk_state.input_ids,
+                position_ids=self.chunk_state.position_ids,
+                decoder_input=self.chunk_state.decoder_input,
+                packed_seq_params=self.chunk_state.packed_seq_params,
+            )
         )

         # Saved for later use
@@ -141,7 +135,6 @@ def forward_impl(self):
         self.chunk_state.rotary_pos_cos = rotary_pos_cos
         self.chunk_state.rotary_pos_sin = rotary_pos_sin
         self.chunk_state.sequence_len_offset = sequence_len_offset
-        self.chunk_state.padding_mask = padding_mask
         return decoder_input


megatron/core/models/gpt/gpt_model.py

Lines changed: 4 additions & 33 deletions
@@ -284,7 +284,6 @@ def _preprocess(
         decoder_input: Tensor = None,
         inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
-        padding_mask: Optional[Tensor] = None,
     ):
         """Preprocesses inputs for the transformer decoder.

@@ -301,20 +300,7 @@
         if decoder_input is not None:
             pass
         elif self.pre_process:
-            if padding_mask is not None:
-                assert padding_mask.shape == input_ids.shape, (
-                    f"padding_mask shape {padding_mask.shape} does not match "
-                    f"input_ids shape {input_ids.shape}"
-                )
             decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
-            if padding_mask is not None and self.config.sequence_parallel:
-                padding_mask = (
-                    tensor_parallel.scatter_to_sequence_parallel_region(
-                        padding_mask.transpose(0, 1).contiguous()
-                    )
-                    .transpose(0, 1)
-                    .contiguous()
-                )
         else:
             # intermediate stage of pipeline
             # decoder will get hidden_states from encoder.input_tensor
@@ -417,7 +403,6 @@ def _preprocess(
             rotary_pos_cos,
             rotary_pos_sin,
             sequence_len_offset,
-            padding_mask,
         )
         if rotary_pos_cos_sin is not None:
             # only in the case of flashinfer fused rope will we
@@ -461,7 +446,6 @@ def forward(
         *,
         inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
-        padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Forward function of the GPT Model This function passes the input tensors
         through the embedding layer, and then the decoder and finally into the post
@@ -472,9 +456,6 @@
         Args:
             runtime_gather_output (bool): Gather output at runtime. Default None means
                 `parallel_output` arg in the constructor will be used.
-            padding_mask (Tensor, optional): Padding mask for MoE routing.
-                Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
-                Only used for MoE layers to exclude padding tokens from routing computations.
         """
         if self.config.fine_grained_activation_offloading:
             self.preprocess_for_fine_grained_offloading()
@@ -487,19 +468,13 @@
             decoder_input=decoder_input,
             inference_context=inference_context,
             packed_seq_params=packed_seq_params,
-            padding_mask=padding_mask,
         )

-        (
-            decoder_input,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            sequence_len_offset,
-            padding_mask,
-        ) = preproc_output[:6]
+        (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) = (
+            preproc_output[:5]
+        )

-        rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None
+        rotary_pos_cos_sin = preproc_output[5] if len(preproc_output) == 6 else None

         # Run decoder.
         hidden_states = self.decoder(
@@ -512,7 +487,6 @@ def forward(
             rotary_pos_cos_sin=rotary_pos_cos_sin,
             packed_seq_params=packed_seq_params,
             sequence_len_offset=sequence_len_offset,
-            padding_mask=padding_mask,
             **(extra_block_kwargs or {}),
         )

@@ -750,7 +724,6 @@ def build_schedule_plan(
         runtime_gather_output: Optional[bool] = None,
         inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
-        padding_mask: Optional[Tensor] = None,
     ):
         """Builds a computation schedule plan for the model.

@@ -776,7 +749,6 @@
             inference_params (InferenceParams, optional):
                 Parameters for inference. Defaults to None.
             loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
-            padding_mask (Optional[Tensor], optional): Padding mask. Defaults to None.

         Returns:
             TransformerModelChunkSchedulePlan: The model chunk schedule plan.
@@ -798,7 +770,6 @@
             extra_block_kwargs,
             runtime_gather_output,
             loss_mask,
-            padding_mask,
         )

     def sharded_state_dict(
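For readers tracking the interface change above: after the revert, GPTModel._preprocess returns five fixed items plus an optional sixth (rotary_pos_cos_sin). A minimal sketch of the unpacking pattern restored in forward, using a hypothetical stand-in tuple in place of the real preprocessing output:

# Sketch only: fake_preproc_output stands in for the value returned by
# GPTModel._preprocess; the variable names mirror the diff above.
fake_preproc_output = ("decoder_input", "rotary_pos_emb", "rotary_pos_cos",
                       "rotary_pos_sin", "sequence_len_offset")  # 5 or 6 items

(decoder_input, rotary_pos_emb, rotary_pos_cos,
 rotary_pos_sin, sequence_len_offset) = fake_preproc_output[:5]

# The sixth slot is only present when fused RoPE cos/sin values are precomputed.
rotary_pos_cos_sin = fake_preproc_output[5] if len(fake_preproc_output) == 6 else None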

megatron/core/transformer/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def __init__(
             tp_group=tp_group,
         )

-    def forward(self, hidden_states, per_token_scale=None, **kwargs):
+    def forward(self, hidden_states, per_token_scale=None):
         """Perform the forward pass through the MLP block."""
         # [s, b, 4 * h/p]
         nvtx_range_push(suffix="linear_fc1")

megatron/core/transformer/moe/moe_layer.py

Lines changed: 8 additions & 19 deletions
@@ -178,13 +178,13 @@ def __init__(
         self.cudagraph_tensor_store = MoECudaGraphTensorStore()

     @maybe_skip_or_early_return_by_cudagraph("route")
-    def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
+    def route(self, hidden_states: torch.Tensor):
         """Compute token routing for preprocessing.

         This method uses the router to determine which experts to send each token to,
         producing routing probabilities and a mapping.
         """
-        probs, routing_map = self.router(hidden_states, padding_mask=padding_mask)
+        probs, routing_map = self.router(hidden_states)
         return probs, routing_map

     @maybe_skip_or_early_return_by_cudagraph("preprocess")
@@ -270,7 +270,7 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
         output = output + shared_expert_output
         return output

-    def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
+    def forward(self, hidden_states: torch.Tensor):
         """Forward pass for the MoE layer.

         The forward pass comprises four main steps:
@@ -280,11 +280,7 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens
         4. Combine: The outputs from the experts are combined and returned.

         Args:
-            hidden_states (torch.Tensor): The input tensor shape [seq_length, bsz, hidden_size].
-            padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
-                used for correct auxiliary loss computation for packed sequence.
-                Shape = [bsz, seq_length]. True = padding (exclude), False = valid (include).
-                Defaults to None (all tokens are valid).
+            hidden_states (torch.Tensor): The input tensor to the MoE layer.

         Returns:
             A tuple containing the output tensor and the MLP bias, if any.
@@ -295,15 +291,11 @@ def forward(self, hidden_states: torch.Tens
                 "are enabled without also enabling sequence parallelism."
             )

-        # Transpose from [bsz, seq_length] to [seq_length, bsz] to align with hidden_states
-        if padding_mask is not None:
-            padding_mask = padding_mask.transpose(0, 1).bool()
-
         # MoE forward: route -> dispatch -> compute -> combine
-        def custom_forward(hidden_states, padding_mask=None):
+        def custom_forward(hidden_states):
             try:
                 shared_expert_output = self.shared_experts_compute(hidden_states)
-                probs, routing_map = self.route(hidden_states, padding_mask=padding_mask)
+                probs, routing_map = self.route(hidden_states)
                 hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
             except MoECudaGraphPartialCaptureSignal as e:
                 # This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
@@ -326,14 +318,11 @@ def custom_forward(hidden_states, padding_mask=None):
                     tensor_parallel.random.get_cuda_rng_tracker,
                     parallel_state.get_tensor_model_parallel_group(),
                     hidden_states,
-                    padding_mask,
                 )
             else:
-                outputs = tensor_parallel.checkpoint(
-                    custom_forward, False, hidden_states, padding_mask
-                )
+                outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
         else:
-            outputs = custom_forward(hidden_states, padding_mask)
+            outputs = custom_forward(hidden_states)

         return outputs
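The forward docstring above describes a four-step pipeline: route, dispatch, experts compute, combine. The sketch below traces that flow using only the method names visible in this diff (route, preprocess, shared_experts_compute, combine); the expert-compute stage is elided because it is not part of this change, and the layer argument is a hypothetical object standing in for the real MoELayer.

# Hedged sketch of MoELayer's post-revert forward flow; not the actual implementation.
def moe_forward_sketch(layer, hidden_states):
    # Shared-expert branch runs alongside the routed experts.
    shared_expert_output = layer.shared_experts_compute(hidden_states)
    # 1. Route: score each token and build the token-to-expert routing map.
    probs, routing_map = layer.route(hidden_states)
    # 2. Dispatch/preprocess: rearrange tokens toward their assigned experts.
    hidden_states, probs, residual = layer.preprocess(hidden_states, probs, routing_map)
    # 3. Experts compute: omitted here; not shown in this diff.
    expert_output = hidden_states  # placeholder for the experts' output
    # 4. Combine: merge routed-expert output with the shared-expert branch.
    return layer.combine(expert_output, shared_expert_output)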

megatron/core/transformer/moe/moe_utils.py

Lines changed: 16 additions & 67 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 import math
 from dataclasses import dataclass
 from typing import List, Optional, Union
@@ -10,7 +11,6 @@
 from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
-from megatron.core.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
 from megatron.core.transformer.cuda_graphs import is_graph_capturing
 from megatron.core.transformer.enums import CudaGraphScope
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -120,34 +120,18 @@ def switch_load_balancing_loss_func(
     return aux_loss


-def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None):
+def z_loss_func(logits, z_loss_coeff):
     """Encourages the router's logits to remain small to enhance stability.
     Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

     Args:
         logits (torch.Tensor): The logits of the router.
-        z_loss_coeff (float): The coefficient for the z-loss.
-        padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
-            Shape [num_tokens]. True = padding (exclude),
-            False = valid (include). Defaults to None.

     Returns:
         torch.Tensor: The logits after applying the z-loss.
     """
-    logsum = torch.logsumexp(logits, dim=-1)
-    z_loss_values = torch.square(logsum)
-
-    if padding_mask is not None:
-        # Invert padding_mask: True (padding) -> 0, False (valid) -> 1
-        valid_mask = ~padding_mask
-        # Only compute z_loss for valid (non-padding) tokens
-        z_loss_values = z_loss_values * valid_mask
-        # Compute mean over valid tokens only
-        num_valid_tokens = valid_mask.sum()
-        z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
-    else:
-        z_loss = torch.mean(z_loss_values) * z_loss_coeff

+    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
     return z_loss


@@ -187,28 +171,6 @@ def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_
     return capacity


-def get_tokens_per_expert_and_token_count(
-    routing_map: torch.Tensor,
-    reduce_group: torch.distributed.ProcessGroup,
-    topk: int = None,
-    with_padding_mask: bool = False,
-) -> torch.Tensor:
-    """
-    Compute global_tokens_per_expert, local_num_tokens and total_num_tokens with padding mask.
-    """
-    local_tokens_per_expert = routing_map.sum(dim=0)
-    global_tokens_per_expert = reduce_from_tensor_model_parallel_region(
-        local_tokens_per_expert, reduce_group
-    )
-    if with_padding_mask:
-        local_num_tokens = local_tokens_per_expert.sum() / topk
-        total_num_tokens = global_tokens_per_expert.sum() / topk
-    else:
-        local_num_tokens = routing_map.shape[0]
-        total_num_tokens = local_num_tokens * reduce_group.size()
-    return global_tokens_per_expert, local_num_tokens, total_num_tokens
-
-
 class MoEAuxLossAutoScaler(torch.autograd.Function):
     """An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss."""

@@ -667,48 +629,35 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):


 def compute_routing_scores_for_aux_loss(
-    logits: torch.Tensor,
-    topk: int,
-    score_function: str,
-    fused: bool = False,
-    padding_mask: Optional[torch.Tensor] = None,
+    logits: torch.Tensor, topk: int, score_function: str, fused: bool = False
 ):
     """Compute routing scores based on the score function.

     Args:
         logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
-        padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
-            Shape [num_tokens]. True = padding (exclude),
-            False = valid (include). Defaults to None.
+
     Returns:
-        Tuple[torch.Tensor, torch.Tensor]: routing_map and scores.
+        torch.Tensor: The normalized routing scores.
     """
     if fused:
         if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None:
             raise ValueError(
                 "fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
             )
-        routing_map, scores = fused_compute_score_for_moe_aux_loss(
+        return fused_compute_score_for_moe_aux_loss(
            logits=logits, topk=topk, score_function=score_function
         )
-    else:
-        if score_function == "softmax":
-            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        elif score_function == "sigmoid":
-            scores = torch.sigmoid(logits)
-            scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
-        else:
-            raise ValueError(f"Invalid score_function: {score_function}")

-        _, top_indices = torch.topk(scores, k=topk, dim=1)
-        routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
+    if score_function == "softmax":
+        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
+    elif score_function == "sigmoid":
+        scores = torch.sigmoid(logits)
+        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
+    else:
+        raise ValueError(f"Invalid score_function: {score_function}")

-        # Apply padding mask to scores if provided
-        if padding_mask is not None:
-            # Invert padding_mask and make True indicates valid tokens
-            valid_mask = (~padding_mask).unsqueeze(-1)
-            routing_map = routing_map * valid_mask
-            scores = scores * valid_mask
+    _, top_indices = torch.topk(scores, k=topk, dim=1)
+    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
     return routing_map, scores
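For reference, the reverted z-loss is the plain ST-MoE form: the squared logsumexp of the router logits, averaged over every token (padding tokens included again) and scaled by the coefficient. A self-contained sketch with made-up shapes:

import torch

def z_loss(logits: torch.Tensor, z_loss_coeff: float) -> torch.Tensor:
    # logits: [num_tokens, num_experts]; matches the single-line body restored above.
    return torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff

logits = torch.randn(8, 4)  # 8 tokens routed over 4 experts (illustrative shapes only)
print(z_loss(logits, z_loss_coeff=1e-3))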
