@@ -6,31 +6,22 @@
 
 from tensorrt_llm._torch.utils import (fp4_utils,
                                        get_last_power_of_2_num_tokens_buckets,
-                                       last_positive_power_of_2)
+                                       last_positive_power_of_2,
+                                       next_positive_power_of_2)
 
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)
 
 
 def calculate_tile_tokens_dim(num_tokens: int, num_experts: int,
                               top_k: int) -> int:
+    # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = num_tokens * top_k // num_experts
 
-    # Equivalent to the following:
-    # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    #
-    # Torch dynamo cannot correctly track next_positive_power_of_2. Each shape
-    # passed to next_positive_power_of_2 will trigger a new recompile.
-    # Following code still triggers recompile. But it at most produces 4 additional recompiles.
-    if num_tokens_per_expert <= 8:
-        tile_tokens_dim = 8
-    elif num_tokens_per_expert <= 16:
-        tile_tokens_dim = 16
-    elif num_tokens_per_expert <= 32:
-        tile_tokens_dim = 32
-    else:
-        tile_tokens_dim = 64
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
 
     return tile_tokens_dim
 
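For reference, a minimal, self-contained sketch of the updated tiling logic. The local _next_positive_power_of_2 helper below is a hypothetical stand-in for the next_positive_power_of_2 utility imported from tensorrt_llm._torch.utils in the actual change, so the real implementation may differ; the formula and the 8-64 clamp mirror the diff above.

def _next_positive_power_of_2(x: int) -> int:
    # Hypothetical stand-in for next_positive_power_of_2 from
    # tensorrt_llm._torch.utils: smallest power of 2 that is >= x (at least 1).
    if x <= 1:
        return 1
    return 1 << (x - 1).bit_length()


def calculate_tile_tokens_dim(num_tokens: int, num_experts: int,
                              top_k: int) -> int:
    # Tokens per expert under the perfectly uniform routing assumption.
    num_tokens_per_expert = num_tokens * top_k // num_experts
    # Round up to a power of 2, then clamp to the 8-64 range the kernel supports.
    tile_tokens_dim = _next_positive_power_of_2(num_tokens_per_expert)
    return min(max(tile_tokens_dim, 8), 64)


# 128 tokens routed over 8 experts with top_k=2 -> 32 tokens/expert -> tile of 32.
assert calculate_tile_tokens_dim(128, 8, 2) == 32
# Very small workloads clamp up to 8; very large ones cap at 64.
assert calculate_tile_tokens_dim(4, 64, 1) == 8
assert calculate_tile_tokens_dim(4096, 8, 2) == 64

For tokens-per-expert counts up to 64 this reproduces the branchy version removed by the diff, as the old inline comment already noted the two forms were equivalent.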