Conversation

@HaochenYuan (Contributor) commented Nov 4, 2025

What does this PR do?

Related issue: 1982
This PR excludes padding tokens from the aux loss calculation.
PR to the main branch: #2142


Usage

For GPTModel, pass padding_mask to the forward() method. The mask should have the same shape as the input sequence:

# padding_mask shape: [batch_size, sequence_length], True = padding token, False = valid token
output = gpt_model.forward(..., padding_mask=padding_mask)
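A minimal sketch of constructing such a mask from the input ids (pad_token_id here is an assumed tokenizer attribute, not something this PR introduces):

import torch

pad_token_id = 0  # assumption: use whatever id your tokenizer pads with
tokens = torch.tensor([[11, 42, 7, pad_token_id, pad_token_id],
                       [23, 5, 19, 8, pad_token_id]])
padding_mask = tokens.eq(pad_token_id)  # [batch_size, sequence_length], True = padding token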

Note: When using GPTModel, the padding_mask is automatically propagated through TransformerBlock → TransformerLayer → MoELayer, so you only need to pass it once at the GPTModel level. If you are using MoELayer directly (without GPTModel), pass padding_mask to its forward() method:
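A minimal sketch of that direct call (the (output, bias) return pair shown here follows the usual Megatron-Core MLP convention and is an assumption, not something spelled out in this PR):

# padding_mask shape: [batch_size, sequence_length], True = padding token, False = valid token
output, mlp_bias = moe_layer(hidden_states, padding_mask=padding_mask)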

1. TL;DR (3-5 sentences)

  • What: This PR adds support for padding_mask throughout the MoE routing and auxiliary loss computation pipeline, allowing padding tokens to be explicitly excluded from load balancing and z-loss calculations.
  • Why: Previously, padding tokens were included in MoE auxiliary loss computations, which artificially skewed the load balancing metrics and could lead to suboptimal expert load distribution during training (Issue #1984).
  • Impact: Users training MoE models with variable-length sequences now get more accurate auxiliary losses, improving training stability and expert utilization. The change affects GPTModel.forward(), TransformerBlock, TransformerLayer, MLP, TEFusedMLP, MoELayer, and TopKRouter.

2. Big Picture

2.1 Before vs After Architecture

graph TB
    subgraph "Before: Padding Tokens Included"
        A1[GPTModel.forward] --> B1[TransformerBlock]
        B1 --> C1[TransformerLayer]
        C1 --> D1[MoELayer.forward]
        D1 --> E1[Router.routing]
        E1 --> F1["Aux Loss Computation<br/>(ALL tokens counted)"]
        style F1 fill:#ffcccc
    end
    
    subgraph "After: Padding Tokens Excluded"
        A2[GPTModel.forward<br/>+ padding_mask] --> B2[TransformerBlock<br/>+ padding_mask]
        B2 --> C2[TransformerLayer<br/>+ padding_mask]
        C2 --> D2[MoELayer.forward<br/>+ padding_mask]
        D2 --> E2[Router.routing<br/>+ padding_mask]
        E2 --> F2["Aux Loss Computation<br/>(VALID tokens only)"]
        E2 --> G2["get_tokens_per_expert_and_token_count"]
        style G2 fill:#90EE90
        style F2 fill:#90EE90
    end

2.2 Change Scope Summary

| Category | Files | Description |
| --- | --- | --- |
| Modified | megatron/core/models/gpt/gpt_model.py | Added padding_mask parameter to forward() |
| Modified | megatron/core/transformer/transformer_block.py | Propagates padding_mask through checkpointing and layer iteration |
| Modified | megatron/core/transformer/transformer_layer.py | Passes padding_mask to the MoE layer in _forward_mlp() |
| Modified | megatron/core/transformer/mlp.py | Accepts padding_mask in MLP.forward() |
| Modified | megatron/core/extensions/transformer_engine.py | Accepts padding_mask in TEFusedMLP.forward() |
| Modified | megatron/core/transformer/moe/moe_layer.py | Accepts padding_mask in route() and forward(); passes it to the router |
| Modified | megatron/core/transformer/moe/router.py | Core logic: applies the padding mask to aux loss, z-loss, and expert bias |
| Modified | megatron/core/transformer/moe/moe_utils.py | New compute_tokens_per_expert_with_padding() helper; updated loss functions |
| New Test | tests/unit_tests/transformer/moe/test_aux_loss.py | TestPaddingMaskAuxLoss class with comprehensive tests |
| Modified Test | tests/unit_tests/transformer/moe/test_routers.py | Added test_router_with_padding_mask() |

3. Design Rationale

3.1 Problem Background

Original Limitation:

In batches of variable-length sequences, shorter sequences are padded to match the longest sequence in the batch. Previously, these padding tokens were treated identically to real tokens during MoE auxiliary loss computation:

# BEFORE: All tokens counted (including padding)
tokens_per_expert = routing_map.sum(dim=0)  # counts padding tokens
num_tokens = routing_map.shape[0]           # total including padding
total_num_tokens = num_tokens * self.tp_cp_group.size()

This caused several issues:

  1. Skewed load balancing: Padding tokens artificially inflated token counts for certain experts
  2. Incorrect z-loss: The router's logits at padding positions contributed to the stability (z) loss
  3. Misleading metrics: Logged aux loss values didn't reflect actual training dynamics
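For contrast, a minimal sketch of the masked counting this PR introduces (illustrative only; the real logic lives in the updated moe_utils.py helpers, and the name padding_mask_flat follows the data-flow diagram in section 4.3):

# AFTER (sketch): padding tokens excluded from the counts
valid = ~padding_mask_flat                                   # [num_tokens] bool, True = valid token
tokens_per_expert = (routing_map & valid.unsqueeze(-1)).sum(dim=0)
local_num_tokens = valid.sum()                               # count valid tokens only
total_num_tokens = local_num_tokens.clone()
torch.distributed.all_reduce(total_num_tokens, group=self.tp_cp_group)  # sum valid counts across ranks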

3.2 Solution Approach

Design Decision: Explicit Padding Mask Propagation

The chosen approach propagates a boolean padding_mask tensor through the entire forward pass:

  • Shape: [batch_size, seq_length] at input, transposed to [seq_length, batch_size] for MoE
  • Semantics: False for valid tokens, True for padding tokens
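A minimal shape-handling sketch of that convention (illustrative only; the transpose and flatten are performed inside MoELayer and the router, not by the user):

import torch

bsz, seq = 2, 5
padding_mask = torch.zeros(bsz, seq, dtype=torch.bool)   # False = valid token
padding_mask[0, 3:] = True                                # last two positions of sample 0 are padding

mask_sb = padding_mask.transpose(0, 1).contiguous()       # [seq_length, batch_size], MoE layout
mask_flat = mask_sb.reshape(-1)                            # [num_tokens], num_tokens = seq * bsz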

Alternatives Considered:

  1. Automatic padding detection (e.g., from attention mask): Rejected because attention masks have different semantics and shapes (e.g., causal masks)
  2. Token-level dropping in router: Rejected as it would change routing behavior, not just loss computation
  3. Post-hoc loss correction: Rejected as it requires tracking additional metadata

Trade-offs:

  • (+) Explicit control over which tokens are masked
  • (+) No changes to routing decisions (only loss computation)
  • (+) Compatible with all parallelism strategies (TP, EP, CP)
  • (-) Requires user to provide padding_mask (optional parameter)

3.3 Key Design Points

New Abstraction: get_tokens_per_expert_and_token_count()

def get_tokens_per_expert_and_token_count(
    routing_map: torch.Tensor,
    reduce_group: torch.distributed.ProcessGroup,
    topk: int = None,
    with_padding_mask: bool = False,
) -> torch.Tensor:
    """
    Compute global_tokens_per_expert, local_num_tokens and total_num_tokens with padding mask.
    """

This unified function handles:

  • Standard aux_loss: Token counts across full batch
  • seq_aux_loss: Per-sequence token counts with batch dimension folded into experts
  • global_aux_loss: Token counts across global batch

Interface Changes:

  • Router.forward(input, padding_mask=None) - accepts optional padding mask
  • MoELayer.forward(hidden_states, padding_mask=None) - passes mask to router
  • switch_load_balancing_loss_func(..., padding_mask=None) - masks probability aggregation
  • z_loss_func(logits, z_loss_coeff, padding_mask=None) - excludes padding from mean
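To make the masked-mean idea concrete, here is a standalone sketch of a z-loss with optional padding exclusion (not the actual moe_utils.py implementation; the masking strategy is inferred from the data-flow description in section 4.3):

import torch

def masked_z_loss(logits, z_loss_coeff, padding_mask=None):
    """Sketch: mean of logsumexp(logits)^2 over valid tokens only."""
    z = torch.logsumexp(logits, dim=-1).square()        # [num_tokens]
    if padding_mask is None:
        return z.mean() * z_loss_coeff
    valid = ~padding_mask                                # True = valid token
    return (z * valid).sum() / valid.sum().clamp(min=1) * z_loss_coeff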

4. Execution Path Deep Dive

4.1 Entry Point

The new code path is triggered when a user passes padding_mask to GPTModel.forward():

# User code
output = gpt_model.forward(
    input_ids=tokens,
    position_ids=positions,
    attention_mask=attn_mask,
    padding_mask=padding_mask,  # <-- NEW: Shape [bsz, seq_length]
)

4.2 Call Chain Visualization

sequenceDiagram
    participant User
    participant GPTModel
    participant TransformerBlock
    participant TransformerLayer
    participant MoELayer
    participant Router
    participant AuxLoss as Aux Loss Functions
    
    User->>GPTModel: forward(padding_mask=[bsz, seq])
    GPTModel->>TransformerBlock: forward(padding_mask)
    
    loop For each layer
        TransformerBlock->>TransformerLayer: forward(padding_mask)
        TransformerLayer->>TransformerLayer: _forward_attention()
        TransformerLayer->>TransformerLayer: _forward_mlp(padding_mask)
        
        alt is_moe_layer
            TransformerLayer->>MoELayer: forward(hidden_states, padding_mask)
            MoELayer->>MoELayer: transpose padding_mask to [seq, bsz]
            MoELayer->>Router: forward(hidden_states, padding_mask)
            Router->>Router: flatten padding_mask to [num_tokens]
            Router->>Router: apply_z_loss(logits, padding_mask)
            Router->>AuxLoss: z_loss_func(..., padding_mask)
            AuxLoss-->>Router: z_loss (masked mean)
            Router->>Router: routing(logits, padding_mask)
            Router->>AuxLoss: get_tokens_per_expert_and_token_count(...)
            AuxLoss-->>Router: (global_tokens_per_expert, local_num_tokens, total_num_tokens)
            Router->>AuxLoss: switch_load_balancing_loss_func(..., global_tokens_per_expert)
            AuxLoss-->>Router: aux_loss (valid tokens only)
            Router->>Router: _apply_expert_bias(routing_map, padding_mask)
            Router-->>MoELayer: (probs, routing_map)
        end
    end
    
    TransformerBlock-->>User: output

4.3 Data Flow

graph TD
    A["padding_mask<br/>[bsz, seq_length]<br/>bool tensor"] -->|GPTModel| B["padding_mask<br/>[bsz, seq_length]"]
    B -->|TransformerBlock| C["padding_mask<br/>[bsz, seq_length]"]
    C -->|TransformerLayer| D["padding_mask<br/>[bsz, seq_length]"]
    D -->|MoELayer| E["padding_mask_flat<br/>[num_tokens]<br/>where num_tokens = seq * bsz"]
    
    E -->|"Router.apply_z_loss()"| F["z_loss = masked_mean(logsumexp²)<br/>Only valid tokens contribute"]
    E -->|"get_tokens_per_expert_and_token_count()"| G["global_tokens_per_expert[num_experts]<br/>Only valid token routings counted<br/>total_num_tokens, local_num_tokens<br/>(derives from all-reduced counts)"]
    E -->|"switch_load_balancing_loss_func()"| H["probs masked before sum<br/>padding probs zeroed"]
    
    style A fill:#e1f5fe
    style E fill:#fff3e0
    style F fill:#e8f5e9
    style H fill:#e8f5e9

@HaochenYuan HaochenYuan requested review from a team as code owners November 4, 2025 03:00
@copy-pr-bot (bot) commented Nov 4, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HaochenYuan HaochenYuan added the dev branch and module: moe labels Nov 4, 2025
@Victarry (Contributor) commented Nov 4, 2025

Could you please add a UT to megatron-lm/tests/unit_tests/transformer/moe/test_routers.py and megatron-lm/tests/unit_tests/transformer/moe/test_aux_loss.py?

@HaochenYuan (Contributor Author) commented:

> Could you please add a UT to megatron-lm/tests/unit_tests/transformer/moe/test_routers.py and megatron-lm/tests/unit_tests/transformer/moe/test_aux_loss.py?

Done

@Victarry (Contributor) commented Nov 4, 2025

/ok to test 96b0d00

@Victarry Victarry added this to the Core 0.16 milestone Nov 4, 2025
@yanring (Contributor) commented Nov 5, 2025

Please submit a mirror PR to main as well

@HaochenYuan (Contributor Author) commented:

> Please submit a mirror PR to main as well

Done

@HaochenYuan HaochenYuan force-pushed the dev branch 2 times, most recently from 1a58ae5 to a389f33 on November 6, 2025 10:21
@BestJuly (Contributor) commented Nov 7, 2025

/ok to test 7a34303

@Victarry (Contributor) commented:

/ok to test 84344df

@Victarry Victarry added this pull request to the merge queue Dec 24, 2025
@Victarry Victarry removed this pull request from the merge queue due to a manual request Dec 24, 2025
@Victarry Victarry added this pull request to the merge queue Dec 24, 2025
Merged via the queue into NVIDIA:dev with commit 0b6714e Dec 24, 2025
43 of 45 checks passed
chtruong814 added a commit to chtruong814/Megatron-LM that referenced this pull request Dec 24, 2025
chtruong814 added a commit that referenced this pull request Dec 24, 2025