[Dev] Remove calculation of padding token in moe routing loss #2121
Conversation
Could you please add a UT to the

Done

/ok to test 96b0d00

Please submit a mirror PR to main as well

Done
Force-pushed from 1a58ae5 to a389f33

/ok to test 7a34303

/ok to test 84344df
…NVIDIA#2121)" This reverts commit 0b6714e. Signed-off-by: Charlie Truong <[email protected]>
…#2121)" (#2747) Signed-off-by: Charlie Truong <[email protected]>
What does this PR do?
Related issue: 1982
This PR excludes padding tokens from the MoE auxiliary (aux) loss calculation.
PR to the main branch: #2142
Usage
For `GPTModel`: pass `padding_mask` to the `forward()` method. The mask should have the same shape as the input sequence.

Note: When using `GPTModel`, the `padding_mask` is automatically propagated through `TransformerBlock` → `TransformerLayer` → `MoELayer`, so you only need to pass it once at the `GPTModel` level. If you are using `MoELayer` directly (without `GPTModel`), pass `padding_mask` to its `forward()` method.
1. TL;DR (3-5 sentences)

Propagates a `padding_mask` throughout the MoE routing and auxiliary loss computation pipeline, allowing padding tokens to be explicitly excluded from load balancing and z-loss calculations. The new optional parameter is threaded through `GPTModel.forward()`, `TransformerBlock`, `TransformerLayer`, `MLP`, `TEFusedMLP`, `MoELayer`, and `TopKRouter`.

2. Big Picture
2.1 Before vs After Architecture
```mermaid
graph TB
    subgraph "Before: Padding Tokens Included"
        A1[GPTModel.forward] --> B1[TransformerBlock]
        B1 --> C1[TransformerLayer]
        C1 --> D1[MoELayer.forward]
        D1 --> E1[Router.routing]
        E1 --> F1["Aux Loss Computation<br/>(ALL tokens counted)"]
        style F1 fill:#ffcccc
    end
    subgraph "After: Padding Tokens Excluded"
        A2[GPTModel.forward<br/>+ padding_mask] --> B2[TransformerBlock<br/>+ padding_mask]
        B2 --> C2[TransformerLayer<br/>+ padding_mask]
        C2 --> D2[MoELayer.forward<br/>+ padding_mask]
        D2 --> E2[Router.routing<br/>+ padding_mask]
        E2 --> F2["Aux Loss Computation<br/>(VALID tokens only)"]
        E2 --> G2["get_tokens_per_expert_and_token_count"]
        style G2 fill:#90EE90
        style F2 fill:#90EE90
    end
```

2.2 Change Scope Summary
- `megatron/core/models/gpt/gpt_model.py`: adds `padding_mask` parameter to `forward()`
- `megatron/core/transformer/transformer_block.py`: threads `padding_mask` through checkpointing and layer iteration
- `megatron/core/transformer/transformer_layer.py`: passes `padding_mask` to the MoE layer in `_forward_mlp()`
- `megatron/core/transformer/mlp.py`: accepts `padding_mask` in `MLP.forward()`
- `megatron/core/extensions/transformer_engine.py`: accepts `padding_mask` in `TEFusedMLP.forward()`
- `megatron/core/transformer/moe/moe_layer.py`: accepts `padding_mask` in `route()` and `forward()` and passes it to the router
- `megatron/core/transformer/moe/router.py`: accepts `padding_mask` in `forward()` and `routing()` and uses it in the aux/z-loss computation
- `megatron/core/transformer/moe/moe_utils.py`: adds `compute_tokens_per_expert_with_padding()` helper; updates loss functions
- `tests/unit_tests/transformer/moe/test_aux_loss.py`: adds `TestPaddingMaskAuxLoss` class with comprehensive tests
- `tests/unit_tests/transformer/moe/test_routers.py`: adds `test_router_with_padding_mask()`

3. Design Rationale
3.1 Problem Background
Original Limitation:
In batches with variable-length sequences, shorter sequences are padded to match the longest sequence in the batch. Previously, these padding tokens were treated identically to real tokens during MoE auxiliary loss computation.
This caused several issues: padding tokens inflated the per-expert token counts and probability averages that feed the load-balancing loss, and their logits contributed to the z-loss, so batches with many short sequences produced skewed auxiliary objectives.
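A toy illustration (not the Megatron implementation) of how padding tokens shift the switch load-balancing loss, which scales with `num_experts * sum_i(f_i * P_i)`, where `f_i` is the fraction of tokens routed to expert `i` and `P_i` is the mean router probability for expert `i`:

```python
import torch

E = 4                                                 # number of experts
probs = torch.softmax(torch.randn(8, E), dim=-1)      # router probs for 8 tokens
routing = torch.zeros(8, E)
routing[torch.arange(8), probs.argmax(dim=-1)] = 1.0  # top-1 routing map
padding = torch.tensor([False] * 5 + [True] * 3)      # last 3 tokens are padding

def lb_loss(p, r, mask=None):
    if mask is not None:
        p, r = p[~mask], r[~mask]   # drop padding rows before aggregating
    f = r.mean(dim=0)               # fraction of tokens routed to each expert
    P = p.mean(dim=0)               # mean router probability per expert
    return E * (f * P).sum()

print("all tokens :", lb_loss(probs, routing).item())
print("valid only :", lb_loss(probs, routing, padding).item())
```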
3.2 Solution Approach
Design Decision: Explicit Padding Mask Propagation
The chosen approach propagates a boolean `padding_mask` tensor through the entire forward pass:

- Shape: `[batch_size, seq_length]` at input, transposed to `[seq_length, batch_size]` for MoE
- Values: `False` for valid tokens, `True` for padding tokens
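For illustration, a mask following this convention can be built from per-sample sequence lengths; the `build_padding_mask` helper below is hypothetical and not part of this PR:

```python
import torch

def build_padding_mask(seq_lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a [batch_size, max_len] bool mask: False = valid token, True = padding."""
    positions = torch.arange(max_len, device=seq_lengths.device)
    return positions.unsqueeze(0) >= seq_lengths.unsqueeze(1)

mask = build_padding_mask(torch.tensor([5, 3]), max_len=6)
# tensor([[False, False, False, False, False,  True],
#         [False, False, False,  True,  True,  True]])
```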
Alternatives Considered:

Trade-offs:

- Existing call sites continue to work unchanged, because `padding_mask` is an optional parameter

3.3 Key Design Points
New Abstraction: `get_tokens_per_expert_and_token_count()`

This unified function handles:

- `aux_loss`: token counts across the full batch
- `seq_aux_loss`: per-sequence token counts with the batch dimension folded into experts
- `global_aux_loss`: token counts across the global batch
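A simplified, hypothetical version of the masked counting for the plain `aux_loss` case (the real helper also covers the `seq_aux_loss` / `global_aux_loss` variants and the distributed reductions):

```python
import torch

def count_valid_tokens_per_expert(routing_map: torch.Tensor,
                                  padding_mask: torch.Tensor | None = None):
    """routing_map: [num_tokens, num_experts] bool; padding_mask: [num_tokens] bool (True = padding)."""
    if padding_mask is not None:
        routing_map = routing_map & ~padding_mask.unsqueeze(-1)  # zero out padding rows
        num_tokens = (~padding_mask).sum()
    else:
        num_tokens = torch.tensor(routing_map.size(0))
    tokens_per_expert = routing_map.sum(dim=0)                   # [num_experts]
    return tokens_per_expert, num_tokens
```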
Interface Changes:

- `Router.forward(input, padding_mask=None)`: accepts an optional padding mask
- `MoELayer.forward(hidden_states, padding_mask=None)`: passes the mask to the router
- `switch_load_balancing_loss_func(..., padding_mask=None)`: masks the probability aggregation
- `z_loss_func(logits, z_loss_coeff, padding_mask=None)`: excludes padding from the mean
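As one example, a masked z-loss matching the `z_loss_func` signature above could look like the following simplified stand-in (not the exact Megatron implementation):

```python
import torch

def masked_z_loss(logits: torch.Tensor, z_loss_coeff: float,
                  padding_mask: torch.Tensor | None = None) -> torch.Tensor:
    """logits: [num_tokens, num_experts]; padding_mask: [num_tokens] bool (True = padding)."""
    z = torch.logsumexp(logits, dim=-1).square()   # per-token z term
    if padding_mask is None:
        return z_loss_coeff * z.mean()
    valid = ~padding_mask
    # Average only over valid tokens so padding rows neither dilute nor inflate the loss.
    return z_loss_coeff * (z * valid).sum() / valid.sum().clamp(min=1)
```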
4. Execution Path Deep Dive
4.1 Entry Point
The new code path is triggered when a user passes `padding_mask` to `GPTModel.forward()`, as shown in the Usage section above.

4.2 Call Chain Visualization
```mermaid
sequenceDiagram
    participant User
    participant GPTModel
    participant TransformerBlock
    participant TransformerLayer
    participant MoELayer
    participant Router
    participant AuxLoss as Aux Loss Functions

    User->>GPTModel: forward(padding_mask=[bsz, seq])
    GPTModel->>TransformerBlock: forward(padding_mask)
    loop For each layer
        TransformerBlock->>TransformerLayer: forward(padding_mask)
        TransformerLayer->>TransformerLayer: _forward_attention()
        TransformerLayer->>TransformerLayer: _forward_mlp(padding_mask)
        alt is_moe_layer
            TransformerLayer->>MoELayer: forward(hidden_states, padding_mask)
            MoELayer->>MoELayer: transpose padding_mask to [seq, bsz]
            MoELayer->>Router: forward(hidden_states, padding_mask)
            Router->>Router: flatten padding_mask to [num_tokens]
            Router->>Router: apply_z_loss(logits, padding_mask)
            Router->>AuxLoss: z_loss_func(..., padding_mask)
            AuxLoss-->>Router: z_loss (masked mean)
            Router->>Router: routing(logits, padding_mask)
            Router->>AuxLoss: get_tokens_per_expert_and_token_count(...)
            AuxLoss-->>Router: (global_tokens_per_expert, local_num_tokens, total_num_tokens)
            Router->>AuxLoss: switch_load_balancing_loss_func(..., global_tokens_per_expert)
            AuxLoss-->>Router: aux_loss (valid tokens only)
            Router->>Router: _apply_expert_bias(routing_map, padding_mask)
            Router-->>MoELayer: (probs, routing_map)
        end
    end
    TransformerBlock-->>User: output
```

4.3 Data Flow
```mermaid
graph TD
    A["padding_mask<br/>[bsz, seq_length]<br/>bool tensor"] -->|GPTModel| B["padding_mask<br/>[bsz, seq_length]"]
    B -->|TransformerBlock| C["padding_mask<br/>[bsz, seq_length]"]
    C -->|TransformerLayer| D["padding_mask<br/>[bsz, seq_length]"]
    D -->|MoELayer| E["padding_mask_flat<br/>[num_tokens]<br/>where num_tokens = seq * bsz"]
    E -->|"Router.apply_z_loss()"| F["z_loss = masked_mean(logsumexp²)<br/>Only valid tokens contribute"]
    E -->|"get_tokens_per_expert_and_token_count()"| G["global_tokens_per_expert[num_experts]<br/>Only valid token routings counted<br/>total_num_tokens, local_num_tokens<br/>(derives from all-reduced counts)"]
    E -->|"switch_load_balancing_loss_func()"| H["probs masked before sum<br/>padding probs zeroed"]

    style A fill:#e1f5fe
    style E fill:#fff3e0
    style F fill:#e8f5e9
    style H fill:#e8f5e9
```
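To make the reshaping above concrete, the sketch below (tensor names and the exact reshape order are assumptions) converts a `[bsz, seq_length]` mask into the flattened `[num_tokens]` form consumed by the router-level loss functions:

```python
import torch

bsz, seq = 2, 4
padding_mask = torch.zeros(bsz, seq, dtype=torch.bool)   # [bsz, seq], True = padding
padding_mask[0, 3] = True                                 # last token of sample 0 is padding

# MoE operates on [seq, bsz, hidden]-shaped activations, so transpose first ...
mask_sb = padding_mask.transpose(0, 1).contiguous()       # [seq, bsz]
# ... then flatten to one entry per token for the router and loss functions.
padding_mask_flat = mask_sb.view(-1)                      # [num_tokens] = [seq * bsz]

print(padding_mask_flat.shape, padding_mask_flat.sum().item())  # torch.Size([8]) 1
```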