Skip to content

Commit 4e653f6

Browse files
committed
Use torch-compile-compatible next_positive_power_of_2
Signed-off-by: Jin Li <[email protected]>
1 parent e7c25d2 commit 4e653f6

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,22 @@
66

77
from tensorrt_llm._torch.utils import (fp4_utils,
88
get_last_power_of_2_num_tokens_buckets,
9-
last_positive_power_of_2)
9+
last_positive_power_of_2,
10+
next_positive_power_of_2)
1011

1112
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
1213
OptimizationProfile, TunableRunner, TuningConfig)
1314

1415

1516
def calculate_tile_tokens_dim(num_tokens: int, num_experts: int,
1617
top_k: int) -> int:
18+
# Guess tokens per expert assuming perfect expert distribution first.
1719
num_tokens_per_expert = num_tokens * top_k // num_experts
1820

19-
# Equivalent to the following:
20-
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
21-
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
22-
#
23-
# Torch dynamo cannot correctly track next_positive_power_of_2. Each shape
24-
# passed to next_positive_power_of_2 will trigger a new recompile.
25-
# Following code still triggers recompile. But it at most produces 4 additional recompiles.
26-
if num_tokens_per_expert <= 8:
27-
tile_tokens_dim = 8
28-
elif num_tokens_per_expert <= 16:
29-
tile_tokens_dim = 16
30-
elif num_tokens_per_expert <= 32:
31-
tile_tokens_dim = 32
32-
else:
33-
tile_tokens_dim = 64
21+
# And pad the number to the next power of 2.
22+
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
23+
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
24+
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
3425

3526
return tile_tokens_dim
3627

tensorrt_llm/_torch/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,17 @@ def next_positive_power_of_2(x: int) -> int:
196196
if x < 1:
197197
return 1
198198

199-
return 1 << (x - 1).bit_length()
199+
# Following code is equivalent to 1 << (x - 1).bit_length()
200+
# But this impl does not contain bit_length() so can be used by torch compile.
201+
# It can correctly handle 64bit number which should be enough for now.
202+
n = x - 1
203+
n |= n >> 1
204+
n |= n >> 2
205+
n |= n >> 4
206+
n |= n >> 8
207+
n |= n >> 16
208+
n |= n >> 32
209+
return n + 1
200210

201211

202212
def last_positive_power_of_2(x: int) -> int:

0 commit comments

Comments
 (0)