Merged
35 commits
c49ffc1
GPG sign off
HaochenYuan Nov 6, 2025
7a34303
Merge branch 'dev' into dev
BestJuly Nov 7, 2025
5781e60
format fix
HaochenYuan Nov 7, 2025
7e4a3f2
format fix
HaochenYuan Nov 7, 2025
f9d9c15
fix UT error
HaochenYuan Nov 7, 2025
2903bf9
fix some UT error
HaochenYuan Nov 13, 2025
f1b4e84
fix linting
HaochenYuan Nov 13, 2025
b97e762
refactor
HaochenYuan Dec 2, 2025
80a06de
refactor
HaochenYuan Dec 2, 2025
7664c4d
add comments for padding mask
HaochenYuan Dec 2, 2025
8acb7e9
refactor: simplify logic
HaochenYuan Dec 3, 2025
6218741
refactor
HaochenYuan Dec 11, 2025
9dd4373
refactor
HaochenYuan Dec 11, 2025
2a9f3c5
refactor
HaochenYuan Dec 11, 2025
cc34bcd
refactor
HaochenYuan Dec 11, 2025
886a613
fix bug
HaochenYuan Dec 11, 2025
97de067
fix bug in seq_aux_loss
HaochenYuan Dec 11, 2025
fdb7311
fix bug in seq_aux_loss
HaochenYuan Dec 11, 2025
c302132
fix linting
HaochenYuan Dec 11, 2025
aa3ac99
add Copyright header
HaochenYuan Dec 11, 2025
39e3db4
add Copyright header
HaochenYuan Dec 11, 2025
cebf69d
add padding_mask in fusedMLP
HaochenYuan Dec 12, 2025
b4884c7
Merge branch 'dev' into dev
HaochenYuan Dec 12, 2025
bc1a37c
slice padding_mask in SP
HaochenYuan Dec 18, 2025
bc78b60
fix linting
HaochenYuan Dec 18, 2025
a9f02b1
modify preprocess in 1f1b
HaochenYuan Dec 18, 2025
9ae9fc9
modify preprocess in 1f1b&mtp
HaochenYuan Dec 18, 2025
c1195c3
fix linting
HaochenYuan Dec 18, 2025
f9a6b3e
Merge branch 'dev' into dev
HaochenYuan Dec 18, 2025
cae4095
add UT in 1f1b & SP-S
HaochenYuan Dec 23, 2025
13676a9
fix linting
HaochenYuan Dec 23, 2025
6cf4ce6
fix error
HaochenYuan Dec 23, 2025
0c8159b
add copyright
HaochenYuan Dec 23, 2025
a0e1591
fix error
HaochenYuan Dec 23, 2025
84344df
fix linting
HaochenYuan Dec 23, 2025
2 changes: 1 addition & 1 deletion megatron/core/extensions/transformer_engine.py
@@ -1851,7 +1851,7 @@ def forward_post_hook(module, *_) -> None:
"TEFusedMLP module does not support submodules with post-backward hooks"
)

def forward(self, hidden_states: torch.Tensor) -> Tuple[Tensor, Optional[Tensor]]:
def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Optional[Tensor]]:
"""Forward."""

# Construct fused impl if needed
2 changes: 2 additions & 0 deletions megatron/core/models/common/model_chunk_schedule_plan.py
@@ -305,6 +305,7 @@ def __init__(
extra_block_kwargs=None,
runtime_gather_output: Optional[bool] = None,
loss_mask: Optional[Tensor] = None,
padding_mask=None,
):
"""Initialize the schedule plan of all Transformer layers' sub-modules.

@@ -347,6 +348,7 @@ def __init__(
self._model_chunk_state.mtp_hidden_states = None
self._model_chunk_state.loss_mask = loss_mask
self._model_chunk_state.packed_seq_params = packed_seq_params
self._model_chunk_state.padding_mask = padding_mask
self._model_chunk_state.extra_block_kwargs = extra_block_kwargs
self._model_chunk_state.runtime_gather_output = runtime_gather_output
self._model_chunk_state.model = model
21 changes: 14 additions & 7 deletions megatron/core/models/gpt/fine_grained_callables.py
@@ -120,13 +120,19 @@ def forward_impl(self):
if not self.gpt_model.pre_process:
self.chunk_state.decoder_input = self.gpt_model.decoder.input_tensor
# Run GPTModel._preprocess
decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
self.gpt_model._preprocess(
input_ids=self.chunk_state.input_ids,
position_ids=self.chunk_state.position_ids,
decoder_input=self.chunk_state.decoder_input,
packed_seq_params=self.chunk_state.packed_seq_params,
)
(
decoder_input,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
sequence_len_offset,
padding_mask,
) = self.gpt_model._preprocess(
input_ids=self.chunk_state.input_ids,
position_ids=self.chunk_state.position_ids,
decoder_input=self.chunk_state.decoder_input,
packed_seq_params=self.chunk_state.packed_seq_params,
padding_mask=self.chunk_state.padding_mask,
)

# Saved for later use
@@ -135,6 +141,7 @@ def forward_impl(self):
self.chunk_state.rotary_pos_cos = rotary_pos_cos
self.chunk_state.rotary_pos_sin = rotary_pos_sin
self.chunk_state.sequence_len_offset = sequence_len_offset
self.chunk_state.padding_mask = padding_mask
return decoder_input


37 changes: 33 additions & 4 deletions megatron/core/models/gpt/gpt_model.py
@@ -284,6 +284,7 @@ def _preprocess(
decoder_input: Tensor = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
padding_mask: Optional[Tensor] = None,
):
"""Preprocesses inputs for the transformer decoder.

@@ -300,7 +301,20 @@
if decoder_input is not None:
pass
elif self.pre_process:
if padding_mask is not None:
assert padding_mask.shape == input_ids.shape, (
f"padding_mask shape {padding_mask.shape} does not match "
f"input_ids shape {input_ids.shape}"
)
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
if padding_mask is not None and self.config.sequence_parallel:
padding_mask = (
tensor_parallel.scatter_to_sequence_parallel_region(
padding_mask.transpose(0, 1).contiguous()
)
.transpose(0, 1)
.contiguous()
)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
@@ -403,6 +417,7 @@ def _preprocess(
rotary_pos_cos,
rotary_pos_sin,
sequence_len_offset,
padding_mask,
)
if rotary_pos_cos_sin is not None:
# only in the case of flashinfer fused rope will we
@@ -446,6 +461,7 @@ def forward(
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
@@ -456,6 +472,9 @@
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
padding_mask (Tensor, optional): Padding mask for MoE routing.
Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
Used only by MoE layers to exclude padding tokens from routing computations.
"""
if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()
@@ -468,13 +487,19 @@
decoder_input=decoder_input,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
padding_mask=padding_mask,
)

(decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) = (
preproc_output[:5]
)
(
decoder_input,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
sequence_len_offset,
padding_mask,
) = preproc_output[:6]

rotary_pos_cos_sin = preproc_output[5] if len(preproc_output) == 6 else None
rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None

# Run decoder.
hidden_states = self.decoder(
Expand All @@ -487,6 +512,7 @@ def forward(
rotary_pos_cos_sin=rotary_pos_cos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
padding_mask=padding_mask,
**(extra_block_kwargs or {}),
)

@@ -724,6 +750,7 @@ def build_schedule_plan(
runtime_gather_output: Optional[bool] = None,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
padding_mask: Optional[Tensor] = None,
):
"""Builds a computation schedule plan for the model.

@@ -749,6 +776,7 @@
inference_params (InferenceParams, optional):
Parameters for inference. Defaults to None.
loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
padding_mask (Optional[Tensor], optional): Padding mask. Defaults to None.

Returns:
TransformerModelChunkSchedulePlan: The model chunk schedule plan.
@@ -770,6 +798,7 @@
extra_block_kwargs,
runtime_gather_output,
loss_mask,
padding_mask,
)

def sharded_state_dict(
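For reference, a minimal, self-contained sketch of how a caller might construct the padding_mask that this PR threads through GPTModel.forward, following the convention documented above ([bsz, seq_length], True = padding, False = valid). The pad_token_id value, the toy input_ids, and the commented-out model call are assumptions for illustration, not part of the diff.

import torch

# Hypothetical illustration: derive the padding mask from input_ids.
# pad_token_id is an assumed value; the actual id depends on the tokenizer in use.
pad_token_id = 0
input_ids = torch.tensor([[5, 9, 3, 0, 0],
                          [7, 2, 0, 0, 0]])   # [bsz=2, seq_length=5]
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0).expand_as(input_ids)
padding_mask = input_ids.eq(pad_token_id)     # True where the token is padding

# Under sequence parallelism, _preprocess transposes the mask to [seq_length, bsz],
# scatters it along the sequence dimension so each rank keeps only its local slice,
# and transposes it back, mirroring how the activations are partitioned.
# logits = gpt_model(input_ids, position_ids, attention_mask=None,
#                    padding_mask=padding_mask)  # gpt_model assumed to be a GPTModel instance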
2 changes: 1 addition & 1 deletion megatron/core/transformer/mlp.py
@@ -137,7 +137,7 @@ def __init__(
tp_group=tp_group,
)

def forward(self, hidden_states, per_token_scale=None):
def forward(self, hidden_states, per_token_scale=None, **kwargs):
"""Perform the forward pass through the MLP block."""
# [s, b, 4 * h/p]
nvtx_range_push(suffix="linear_fc1")
27 changes: 19 additions & 8 deletions megatron/core/transformer/moe/moe_layer.py
@@ -178,13 +178,13 @@ def __init__(
self.cudagraph_tensor_store = MoECudaGraphTensorStore()

@maybe_skip_or_early_return_by_cudagraph("route")
def route(self, hidden_states: torch.Tensor):
def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
"""Compute token routing for preprocessing.

This method uses the router to determine which experts to send each token to,
producing routing probabilities and a mapping.
"""
probs, routing_map = self.router(hidden_states)
probs, routing_map = self.router(hidden_states, padding_mask=padding_mask)
return probs, routing_map

@maybe_skip_or_early_return_by_cudagraph("preprocess")
@@ -270,7 +270,7 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
output = output + shared_expert_output
return output

def forward(self, hidden_states: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
"""Forward pass for the MoE layer.

The forward pass comprises four main steps:
@@ -280,7 +280,11 @@
4. Combine: The outputs from the experts are combined and returned.

Args:
hidden_states (torch.Tensor): The input tensor to the MoE layer.
hidden_states (torch.Tensor): The input tensor of shape [seq_length, bsz, hidden_size].
padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions,
used for correct auxiliary loss computation with packed sequences.
Shape = [bsz, seq_length]. True = padding (exclude), False = valid (include).
Defaults to None (all tokens are valid).

Returns:
A tuple containing the output tensor and the MLP bias, if any.
@@ -291,11 +295,15 @@
"are enabled without also enabling sequence parallelism."
)

# Transpose from [bsz, seq_length] to [seq_length, bsz] to align with hidden_states
if padding_mask is not None:
padding_mask = padding_mask.transpose(0, 1).bool()

# MoE forward: route -> dispatch -> compute -> combine
def custom_forward(hidden_states):
def custom_forward(hidden_states, padding_mask=None):
try:
shared_expert_output = self.shared_experts_compute(hidden_states)
probs, routing_map = self.route(hidden_states)
probs, routing_map = self.route(hidden_states, padding_mask=padding_mask)
hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
except MoECudaGraphPartialCaptureSignal as e:
# This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
@@ -318,11 +326,14 @@ def custom_forward(hidden_states):
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
padding_mask,
)
else:
outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
outputs = tensor_parallel.checkpoint(
custom_forward, False, hidden_states, padding_mask
)
else:
outputs = custom_forward(hidden_states)
outputs = custom_forward(hidden_states, padding_mask)

return outputs

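A standalone sketch (shapes and values assumed, not taken from the diff) of the mask alignment performed at the top of MoELayer.forward above: the [bsz, seq_length] padding_mask is transposed to match the [seq_length, bsz, hidden_size] hidden_states before tokens are flattened for routing.

import torch

# Toy shapes chosen for illustration only.
seq_length, bsz, hidden_size = 6, 2, 8
hidden_states = torch.randn(seq_length, bsz, hidden_size)
padding_mask = torch.zeros(bsz, seq_length, dtype=torch.bool)
padding_mask[:, 4:] = True                       # last two positions of each sample are padding

mask_sb = padding_mask.transpose(0, 1)           # [seq_length, bsz], aligned with hidden_states
tokens = hidden_states.view(-1, hidden_size)     # [seq_length * bsz, hidden_size]
token_is_padding = mask_sb.reshape(-1)           # per-token flag, True = exclude from routing
print(int(token_is_padding.sum()))               # 4 padded tokens in this toy example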
83 changes: 67 additions & 16 deletions megatron/core/transformer/moe/moe_utils.py
@@ -1,5 +1,4 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import math
from dataclasses import dataclass
from typing import List, Optional, Union
@@ -11,6 +10,7 @@
from megatron.core.fp8_utils import get_fp8_align_size
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
from megatron.core.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
from megatron.core.transformer.cuda_graphs import is_graph_capturing
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -120,18 +120,34 @@ def switch_load_balancing_loss_func(
return aux_loss


def z_loss_func(logits, z_loss_coeff):
def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

Args:
logits (torch.Tensor): The logits of the router.
z_loss_coeff (float): The coefficient for the z-loss.
padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
Shape [num_tokens]. True = padding (exclude),
False = valid (include). Defaults to None.

Returns:
torch.Tensor: The z-loss value.
"""
logsum = torch.logsumexp(logits, dim=-1)
z_loss_values = torch.square(logsum)

if padding_mask is not None:
# Invert padding_mask: True (padding) -> 0, False (valid) -> 1
valid_mask = ~padding_mask
# Only compute z_loss for valid (non-padding) tokens
z_loss_values = z_loss_values * valid_mask
# Compute mean over valid tokens only
num_valid_tokens = valid_mask.sum()
z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
else:
z_loss = torch.mean(z_loss_values) * z_loss_coeff

z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
return z_loss
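A self-contained sketch (with made-up logits) of the masked mean computed in the padding branch above: the squared log-sum-exp is zeroed for padding tokens and averaged over valid tokens only, with the denominator clamped so an all-padding batch cannot divide by zero.

import torch

logits = torch.randn(5, 4)                                   # [num_tokens, num_experts]
padding_mask = torch.tensor([False, False, True, True, False])
z_loss_coeff = 1e-3

z_loss_values = torch.square(torch.logsumexp(logits, dim=-1))
valid_mask = ~padding_mask
num_valid_tokens = valid_mask.sum().clamp(min=1)              # at least one token in the mean
z_loss = (z_loss_values * valid_mask).sum() / num_valid_tokens * z_loss_coeff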


@@ -171,6 +187,28 @@ def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_
return capacity


def get_tokens_per_expert_and_token_count(
routing_map: torch.Tensor,
reduce_group: torch.distributed.ProcessGroup,
topk: Optional[int] = None,
with_padding_mask: bool = False,
) -> tuple:
"""
Compute global_tokens_per_expert, local_num_tokens, and total_num_tokens. When
with_padding_mask is True, token counts are derived from the routing map, so padding
tokens (whose rows have already been zeroed) are excluded from the counts.
"""
local_tokens_per_expert = routing_map.sum(dim=0)
global_tokens_per_expert = reduce_from_tensor_model_parallel_region(
local_tokens_per_expert, reduce_group
)
if with_padding_mask:
local_num_tokens = local_tokens_per_expert.sum() / topk
total_num_tokens = global_tokens_per_expert.sum() / topk
else:
local_num_tokens = routing_map.shape[0]
total_num_tokens = local_num_tokens * reduce_group.size()
return global_tokens_per_expert, local_num_tokens, total_num_tokens
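A toy example (assumed values) of why get_tokens_per_expert_and_token_count recovers the valid-token count as routing_map.sum() / topk when a padding mask is applied: masked rows contribute no expert assignments, while every valid token contributes exactly topk.

import torch

topk = 2
routing_map = torch.tensor([[1, 1, 0, 0],    # valid token, routed to experts 0 and 1
                            [0, 1, 0, 1],    # valid token, routed to experts 1 and 3
                            [0, 0, 0, 0]],   # padding token, zeroed out by the mask
                           dtype=torch.bool)
local_tokens_per_expert = routing_map.sum(dim=0)            # tensor([1, 2, 0, 1])
local_num_tokens = local_tokens_per_expert.sum() / topk     # 4 assignments / topk = 2 valid tokens
print(local_tokens_per_expert.tolist(), float(local_num_tokens))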


class MoEAuxLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss."""

@@ -629,35 +667,48 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):


def compute_routing_scores_for_aux_loss(
logits: torch.Tensor, topk: int, score_function: str, fused: bool = False
logits: torch.Tensor,
topk: int,
score_function: str,
fused: bool = False,
padding_mask: Optional[torch.Tensor] = None,
):
"""Compute routing scores based on the score function.

Args:
logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].

padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
Shape [num_tokens]. True = padding (exclude),
False = valid (include). Defaults to None.
Returns:
torch.Tensor: The normalized routing scores.
Tuple[torch.Tensor, torch.Tensor]: routing_map and scores.
"""
if fused:
if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None:
raise ValueError(
"fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
)
return fused_compute_score_for_moe_aux_loss(
routing_map, scores = fused_compute_score_for_moe_aux_loss(
logits=logits, topk=topk, score_function=score_function
)

if score_function == "softmax":
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
else:
raise ValueError(f"Invalid score_function: {score_function}")
if score_function == "softmax":
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
else:
raise ValueError(f"Invalid score_function: {score_function}")

_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
# Apply padding mask to scores if provided
if padding_mask is not None:
# Invert padding_mask so that True indicates valid tokens
valid_mask = (~padding_mask).unsqueeze(-1)
routing_map = routing_map * valid_mask
scores = scores * valid_mask
return routing_map, scores


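Finally, a standalone sketch (random logits, assumed shapes) of the masking at the end of compute_routing_scores_for_aux_loss: after the top-k selection, rows belonging to padded tokens are zeroed in both routing_map and scores, so they drop out of the load-balancing statistics.

import torch

topk = 2
logits = torch.randn(4, 3)                                    # [num_tokens, num_experts]
padding_mask = torch.tensor([False, False, False, True])      # last token is padding

scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

valid_mask = (~padding_mask).unsqueeze(-1)                    # [num_tokens, 1]
routing_map = routing_map * valid_mask                        # padded rows become all False
scores = scores * valid_mask                                  # padded rows become all zeros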