@@ -1,4 +1,5 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 import math
 from dataclasses import dataclass
 from typing import List, Optional, Union
@@ -10,7 +11,6 @@
 from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
-from megatron.core.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
 from megatron.core.transformer.cuda_graphs import is_graph_capturing
 from megatron.core.transformer.enums import CudaGraphScope
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -120,34 +120,18 @@ def switch_load_balancing_loss_func( |
     return aux_loss
 
 
-def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None):
+def z_loss_func(logits, z_loss_coeff):
     """Encourages the router's logits to remain small to enhance stability.
     Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
 
     Args:
         logits (torch.Tensor): The logits of the router.
-        z_loss_coeff (float): The coefficient for the z-loss.
-        padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
-            Shape [num_tokens]. True = padding (exclude),
-            False = valid (include). Defaults to None.
 
     Returns:
         torch.Tensor: The logits after applying the z-loss.
     """
-    logsum = torch.logsumexp(logits, dim=-1)
-    z_loss_values = torch.square(logsum)
-
-    if padding_mask is not None:
-        # Invert padding_mask: True (padding) -> 0, False (valid) -> 1
-        valid_mask = ~padding_mask
-        # Only compute z_loss for valid (non-padding) tokens
-        z_loss_values = z_loss_values * valid_mask
-        # Compute mean over valid tokens only
-        num_valid_tokens = valid_mask.sum()
-        z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
-    else:
-        z_loss = torch.mean(z_loss_values) * z_loss_coeff
 
+    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
     return z_loss
 
 
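As a quick reference for the simplified z-loss above: it is just the mean over tokens of the squared log-sum-exp of the router logits, scaled by `z_loss_coeff`. A minimal standalone sketch in plain PyTorch (the shapes and coefficient below are illustrative only, not taken from this commit):

```python
import torch

def z_loss_reference(logits: torch.Tensor, z_loss_coeff: float) -> torch.Tensor:
    # Mean over tokens of (logsumexp over experts)^2, scaled by the coefficient.
    return torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff

# Illustrative check: 4 tokens routed over 8 experts.
logits = torch.randn(4, 8)
per_token = torch.square(torch.logsumexp(logits, dim=-1))  # shape [4]
assert torch.allclose(z_loss_reference(logits, 1e-3), per_token.mean() * 1e-3)
```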
@@ -187,28 +171,6 @@ def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_ |
     return capacity
 
 
-def get_tokens_per_expert_and_token_count(
-    routing_map: torch.Tensor,
-    reduce_group: torch.distributed.ProcessGroup,
-    topk: int = None,
-    with_padding_mask: bool = False,
-) -> torch.Tensor:
-    """
-    Compute global_tokens_per_expert, local_num_tokens and total_num_tokens with padding mask.
-    """
-    local_tokens_per_expert = routing_map.sum(dim=0)
-    global_tokens_per_expert = reduce_from_tensor_model_parallel_region(
-        local_tokens_per_expert, reduce_group
-    )
-    if with_padding_mask:
-        local_num_tokens = local_tokens_per_expert.sum() / topk
-        total_num_tokens = global_tokens_per_expert.sum() / topk
-    else:
-        local_num_tokens = routing_map.shape[0]
-        total_num_tokens = local_num_tokens * reduce_group.size()
-    return global_tokens_per_expert, local_num_tokens, total_num_tokens
-
-
 class MoEAuxLossAutoScaler(torch.autograd.Function):
     """An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss."""
 
@@ -667,48 +629,35 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): |
 
 
 def compute_routing_scores_for_aux_loss(
-    logits: torch.Tensor,
-    topk: int,
-    score_function: str,
-    fused: bool = False,
-    padding_mask: Optional[torch.Tensor] = None,
+    logits: torch.Tensor, topk: int, score_function: str, fused: bool = False
 ):
     """Compute routing scores based on the score function.
 
     Args:
         logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
-        padding_mask (torch.Tensor, optional): Boolean mask indicating padding positions.
-            Shape [num_tokens]. True = padding (exclude),
-            False = valid (include). Defaults to None.
+
     Returns:
-        Tuple[torch.Tensor, torch.Tensor]: routing_map and scores.
+        torch.Tensor: The normalized routing scores.
     """
     if fused:
         if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None:
             raise ValueError(
                 "fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
             )
-        routing_map, scores = fused_compute_score_for_moe_aux_loss(
+        return fused_compute_score_for_moe_aux_loss(
             logits=logits, topk=topk, score_function=score_function
         )
-    else:
-        if score_function == "softmax":
-            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        elif score_function == "sigmoid":
-            scores = torch.sigmoid(logits)
-            scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
-        else:
-            raise ValueError(f"Invalid score_function: {score_function}")
 
-        _, top_indices = torch.topk(scores, k=topk, dim=1)
-        routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
+    if score_function == "softmax":
+        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
+    elif score_function == "sigmoid":
+        scores = torch.sigmoid(logits)
+        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
+    else:
+        raise ValueError(f"Invalid score_function: {score_function}")
 
-    # Apply padding mask to scores if provided
-    if padding_mask is not None:
-        # Invert padding_mask and make True indicates valid tokens
-        valid_mask = (~padding_mask).unsqueeze(-1)
-        routing_map = routing_map * valid_mask
-        scores = scores * valid_mask
+    _, top_indices = torch.topk(scores, k=topk, dim=1)
+    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
     return routing_map, scores
 
 
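For readers who want to exercise the unfused routing path without Megatron or Transformer Engine installed, here is a self-contained sketch that mirrors the new unfused branch of compute_routing_scores_for_aux_loss. The helper name, example shapes, and top-k value are illustrative assumptions; the score math and the 1e-20 epsilon follow the diff above:

```python
import torch

def routing_scores_sketch(logits: torch.Tensor, topk: int, score_function: str):
    # Score every token/expert pair with the chosen score function.
    if score_function == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits)
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)  # normalize per token
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    # Boolean [num_tokens, num_experts] map marking each token's top-k experts.
    _, top_indices = torch.topk(scores, k=topk, dim=1)
    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    return routing_map, scores

# Illustrative usage: 4 tokens, 8 experts, top-2 routing.
routing_map, scores = routing_scores_sketch(torch.randn(4, 8), topk=2, score_function="softmax")
assert routing_map.shape == (4, 8) and scores.shape == (4, 8)
assert routing_map.sum(dim=-1).eq(2).all()  # each token selects exactly 2 experts
```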