|
1 | 1 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | | - |
3 | 2 | import math |
4 | 3 | from dataclasses import dataclass |
5 | 4 | from typing import List, Optional, Union |
|
11 | 10 | from megatron.core.fp8_utils import get_fp8_align_size |
12 | 11 | from megatron.core.process_groups_config import ProcessGroupCollection |
13 | 12 | from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name |
| 13 | +from megatron.core.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region |
14 | 14 | from megatron.core.transformer.cuda_graphs import is_graph_capturing |
15 | 15 | from megatron.core.transformer.enums import CudaGraphScope |
16 | 16 | from megatron.core.transformer.transformer_config import TransformerConfig |
@@ -120,18 +120,34 @@ def switch_load_balancing_loss_func( |
120 | 120 | return aux_loss |
121 | 121 |
|
122 | 122 |
|
123 | | -def z_loss_func(logits, z_loss_coeff): |
| 123 | +def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None): |
124 | 124 | """Encourages the router's logits to remain small to enhance stability. |
125 | 125 | Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details. |
126 | 126 |
|
127 | 127 | Args: |
128 | 128 | logits (torch.Tensor): The logits of the router. |
| 129 | + z_loss_coeff (float): The coefficient for the z-loss. |
| 130 | + padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions. |
| 131 | + Shape [num_tokens]. True = padding (exclude), |
| 132 | + False = valid (include). Defaults to None. |
129 | 133 |
|
130 | 134 | Returns: |
131 | 135 | torch.Tensor: The scalar z-loss value, scaled by z_loss_coeff.
132 | 136 | """ |
| 137 | + logsum = torch.logsumexp(logits, dim=-1) |
| 138 | + z_loss_values = torch.square(logsum) |
| 139 | + |
| 140 | + if padding_mask is not None: |
| 141 | + # Invert padding_mask: True (padding) -> 0, False (valid) -> 1 |
| 142 | + valid_mask = ~padding_mask |
| 143 | + # Only compute z_loss for valid (non-padding) tokens |
| 144 | + z_loss_values = z_loss_values * valid_mask |
| 145 | + # Compute mean over valid tokens only |
| 146 | + num_valid_tokens = valid_mask.sum() |
| 147 | + z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff |
| 148 | + else: |
| 149 | + z_loss = torch.mean(z_loss_values) * z_loss_coeff |
133 | 150 |
|
134 | | - z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff |
135 | 151 | return z_loss |
136 | 152 |
|
137 | 153 |
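A minimal standalone sketch (not part of the diff; the toy logits and pad tensors below are made-up inputs) of how the masked z-loss above differs from a plain mean over all tokens:

import torch

logits = torch.randn(4, 8)                       # [num_tokens, num_experts]
pad = torch.tensor([False, False, False, True])  # True marks a padding token

z = torch.square(torch.logsumexp(logits, dim=-1))       # per-token z-loss values
unmasked = z.mean()                                     # averages over all 4 tokens
masked = (z * ~pad).sum() / (~pad).sum().clamp(min=1)   # averages over the 3 valid tokens only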
|
@@ -171,6 +187,28 @@ def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_ |
171 | 187 | return capacity |
172 | 188 |
|
173 | 189 |
|
| 190 | +def get_tokens_per_expert_and_token_count( |
| 191 | + routing_map: torch.Tensor, |
| 192 | + reduce_group: torch.distributed.ProcessGroup, |
| 193 | + topk: Optional[int] = None,
| 194 | + with_padding_mask: bool = False, |
| 195 | +):
| 196 | + """Return (global_tokens_per_expert, local_num_tokens, total_num_tokens) for the given routing_map.
| 197 | + If with_padding_mask is True, the token counts are derived from the routing map (summed and divided by topk, which must then be provided) so that padding tokens are excluded.
| 198 | + """
| 199 | + local_tokens_per_expert = routing_map.sum(dim=0) |
| 200 | + global_tokens_per_expert = reduce_from_tensor_model_parallel_region( |
| 201 | + local_tokens_per_expert, reduce_group |
| 202 | + ) |
| 203 | + if with_padding_mask: |
| 204 | + local_num_tokens = local_tokens_per_expert.sum() / topk |
| 205 | + total_num_tokens = global_tokens_per_expert.sum() / topk |
| 206 | + else: |
| 207 | + local_num_tokens = routing_map.shape[0] |
| 208 | + total_num_tokens = local_num_tokens * reduce_group.size() |
| 209 | + return global_tokens_per_expert, local_num_tokens, total_num_tokens |
| 210 | + |
| 211 | + |
174 | 212 | class MoEAuxLossAutoScaler(torch.autograd.Function): |
175 | 213 | """An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss.""" |
176 | 214 |
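On a single rank, ignoring the cross-rank reduce done by reduce_from_tensor_model_parallel_region, the counting in the new helper above boils down to the following; the routing_map and topk values here are illustrative assumptions:

import torch

topk = 2
# [num_tokens, num_experts]; the last row is a padding token whose routing was masked out.
routing_map = torch.tensor([
    [True, True, False, False],
    [False, True, True, False],
    [False, False, False, False],
])

tokens_per_expert = routing_map.sum(dim=0)                # tensor([1, 2, 1, 0])
num_tokens_with_padding_mask = routing_map.sum() / topk   # 2.0 -> counts only the valid tokens
num_tokens_without_mask = routing_map.shape[0]            # 3   -> counts the padding row too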
|
@@ -629,35 +667,48 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): |
629 | 667 |
|
630 | 668 |
|
631 | 669 | def compute_routing_scores_for_aux_loss( |
632 | | - logits: torch.Tensor, topk: int, score_function: str, fused: bool = False |
| 670 | + logits: torch.Tensor, |
| 671 | + topk: int, |
| 672 | + score_function: str, |
| 673 | + fused: bool = False, |
| 674 | + padding_mask: Optional[torch.Tensor] = None, |
633 | 675 | ): |
634 | 676 | """Compute routing scores based on the score function. |
635 | 677 |
|
636 | 678 | Args: |
637 | 679 | logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts]. |
638 | | -
|
| 680 | + padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions. |
| 681 | + Shape [num_tokens]. True = padding (exclude), |
| 682 | + False = valid (include). Defaults to None. |
639 | 683 | Returns: |
640 | | - torch.Tensor: The normalized routing scores. |
| 684 | + Tuple[torch.Tensor, torch.Tensor]: The boolean routing_map and the routing scores, both of shape [num_tokens, num_experts].
641 | 685 | """ |
642 | 686 | if fused: |
643 | 687 | if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None: |
644 | 688 | raise ValueError( |
645 | 689 | "fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0." |
646 | 690 | ) |
647 | | - return fused_compute_score_for_moe_aux_loss( |
| 691 | + routing_map, scores = fused_compute_score_for_moe_aux_loss( |
648 | 692 | logits=logits, topk=topk, score_function=score_function |
649 | 693 | ) |
650 | | - |
651 | | - if score_function == "softmax": |
652 | | - scores = torch.softmax(logits, dim=-1, dtype=torch.float32) |
653 | | - elif score_function == "sigmoid": |
654 | | - scores = torch.sigmoid(logits) |
655 | | - scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) |
656 | 694 | else: |
657 | | - raise ValueError(f"Invalid score_function: {score_function}") |
| 695 | + if score_function == "softmax": |
| 696 | + scores = torch.softmax(logits, dim=-1, dtype=torch.float32) |
| 697 | + elif score_function == "sigmoid": |
| 698 | + scores = torch.sigmoid(logits) |
| 699 | + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) |
| 700 | + else: |
| 701 | + raise ValueError(f"Invalid score_function: {score_function}") |
| 702 | + |
| 703 | + _, top_indices = torch.topk(scores, k=topk, dim=1) |
| 704 | + routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() |
658 | 705 |
|
659 | | - _, top_indices = torch.topk(scores, k=topk, dim=1) |
660 | | - routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() |
| 706 | + # Apply the padding mask to routing_map and scores if provided
| 707 | + if padding_mask is not None: |
| 708 | + # Invert padding_mask so that True indicates valid (non-padding) tokens
| 709 | + valid_mask = (~padding_mask).unsqueeze(-1) |
| 710 | + routing_map = routing_map * valid_mask |
| 711 | + scores = scores * valid_mask |
661 | 712 | return routing_map, scores |
662 | 713 |
|
663 | 714 |
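For reference, a self-contained version of the unfused softmax path above with made-up inputs; after the mask is applied, padded rows have all-zero scores and an all-False routing_map row, so they contribute nothing to the aux loss:

import torch

logits = torch.randn(4, 8)
topk = 2
padding_mask = torch.tensor([False, False, False, True])  # last token is padding

scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

valid_mask = (~padding_mask).unsqueeze(-1)
routing_map = routing_map * valid_mask  # padded rows become all False
scores = scores * valid_mask            # padded rows become all zeros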
|
|