
Commit 2c81f8a

MrGeva authored and lucaslie committed
TRTLLM-6142: set torch recompile_limit based on cuda_graph_batch_sizes and refactored (#119)
* refactored compile_limit
  Signed-off-by: Eran Geva <[email protected]>
* removed changes made to TorchCompileCompiler
  Signed-off-by: Eran Geva <[email protected]>
* set cache_size_limit in TorchCompileCompiler
  Signed-off-by: Eran Geva <[email protected]>
---------
Signed-off-by: Eran Geva <[email protected]>
1 parent b075c27 commit 2c81f8a

File tree

5 files changed, +52 -29 lines changed


examples/auto_deploy/build_and_run_ad.py

Lines changed: 0 additions & 3 deletions
@@ -23,9 +23,6 @@
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.sampling_params import SamplingParams
 
-# Global torch config, set the torch compile cache to fix up to llama 405B
-torch._dynamo.config.cache_size_limit = 20
-
 
 class PromptConfig(BaseModel):
     """Prompt configuration.

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_compile.py

Lines changed: 8 additions & 0 deletions
@@ -3,11 +3,19 @@
 import torch
 import torch.nn as nn
 
+from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
+
 from ..compiler import BackendCompiler, BackendRegistry
 
 
 @BackendRegistry.register("torch-compile")
 class TorchCompileCompiler(BackendCompiler):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Global torch config, set the torch compile cache to fix up to llama 405B
+        torch._dynamo.config.cache_size_limit = 20
+        ad_logger.info(f"Setting cache size limit to {torch._dynamo.config.cache_size_limit}")
+
     def compile(self) -> nn.Module:
         """Compile the model using torch.compile."""
         return torch.compile(self.gm, dynamic=True)
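Note on the knob this backend now sets: torch._dynamo.config.cache_size_limit caps how many compiled specializations Dynamo keeps per compiled function before it stops recompiling and falls back to eager execution. A minimal sketch (not part of this diff) of the behavior the setting governs:

import torch

torch._dynamo.config.cache_size_limit = 20  # same value TorchCompileCompiler now sets

@torch.compile(dynamic=True)
def double(x):
    return x * 2

# Distinct input shapes can trigger fresh specializations; once the limit is
# reached, Dynamo falls back to eager for that function instead of recompiling.
for bs in (1, 2, 4, 8):
    double(torch.randn(bs, 16))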

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Lines changed: 26 additions & 24 deletions
@@ -18,27 +18,23 @@ def __init__(
         model: nn.Module,
         in_spec: TreeSpec,
         out_spec: TreeSpec,
-        max_batch_size: int,
-        cuda_graph_batch_sizes: List[int] = None,
+        cuda_graph_batch_sizes: List[int],
         num_batched_inputs: Optional[int] = 1,  # number of batched, dynamic inputs...
     ):
         super().__init__()
         self._in_spec = in_spec
         self._out_spec = out_spec
         self.model = model
-        self.max_batch_size = max_batch_size
+        self.max_batch_size = max(cuda_graph_batch_sizes)
+        ad_logger.info(f"Setting max batch size to {self.max_batch_size}")
         self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
         self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
         self._input_buffers: List[torch.Tensor] = [
             torch.empty(0, 1) for _ in range(self.num_batched_inputs)
         ]
         self._out_buffer_flat: List[torch.Tensor] = None
         self._args_hash: Optional[Tuple[int, ...]] = None
-        self.cuda_graph_batch_sizes = (
-            sorted(cuda_graph_batch_sizes, reverse=True)
-            if cuda_graph_batch_sizes is not None
-            else self._get_graph_batch_sizes(self.max_batch_size)
-        )
+        self.cuda_graph_batch_sizes = sorted(cuda_graph_batch_sizes, reverse=True)
         self._cuda_graph_mem_pool = None
 
     def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:

@@ -77,20 +73,6 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
         self._cuda_graph_mem_pool = self._cuda_graph_mem_pool or graph.pool()
         return graph
 
-    @staticmethod
-    def _get_graph_batch_sizes(
-        max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
-    ) -> List[int]:
-        """Heuristic to set batch sizes for graph capture."""
-        # do 1, max_bs, and extra as special batch sizes
-        batch_sizes = {1, max_bs, *(extra or [])}
-
-        # add all multiples of multiplier up to max_bs
-        batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
-
-        # return as sorted list
-        return sorted(batch_sizes, reverse=True)
-
     def capture_graph(self, *args, **kwargs):
         """Capture and pre-fetch the graph for variable batch size."""
         # flatten args, kwargs

@@ -177,15 +159,21 @@ def forward(self, *args, **kwargs) -> Any:
 class TorchCudagraphCompiler(BackendCompiler):
     """Compiler that uses only CUDA graphs."""
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.cuda_graph_batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes")
+        if not self.cuda_graph_batch_sizes:
+            self.cuda_graph_batch_sizes = self._get_graph_batch_sizes(self.max_batch_size)
+        ad_logger.info(f"Setting cuda_graph_batch_sizes to {self.cuda_graph_batch_sizes}")
+
     def _init_captured_graph(
         self, gm: nn.Module, in_spec: TreeSpec, out_spec: TreeSpec
     ) -> CapturedGraph:
         return CapturedGraph(
             gm,
             in_spec=in_spec,
             out_spec=out_spec,
-            max_batch_size=self.max_batch_size,
-            cuda_graph_batch_sizes=self.compiler_kwargs.get("cuda_graph_batch_sizes"),
+            cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
             num_batched_inputs=self.compiler_kwargs.get("num_batched_inputs"),
         )
 
@@ -198,3 +186,17 @@ def compile(self) -> CapturedGraph:
         captured_model.capture_graph(*self.args, **self.kwargs)
 
         return captured_model
+
+    @staticmethod
+    def _get_graph_batch_sizes(
+        max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
+    ) -> List[int]:
+        """Heuristic to set batch sizes for graph capture."""
+        # do 1, max_bs, and extra as special batch sizes
+        batch_sizes = {1, max_bs, *(extra or [])}
+
+        # add all multiples of multiplier up to max_bs
+        batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
+
+        # return as sorted list
+        return sorted(batch_sizes, reverse=True)
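With this refactor the batch-size heuristic lives on TorchCudagraphCompiler, which resolves the list once (from compiler_kwargs or the heuristic) and passes it to CapturedGraph; CapturedGraph now derives max_batch_size from that list instead of taking it as a separate argument. A standalone sketch of the same heuristic, with hypothetical max batch sizes, to show what it produces:

from typing import List, Optional

def get_graph_batch_sizes(
    max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
) -> List[int]:
    # 1, max_bs, and any extras are always captured ...
    batch_sizes = {1, max_bs, *(extra or [])}
    # ... plus every multiple of `multiplier` up to max_bs
    batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
    return sorted(batch_sizes, reverse=True)

print(get_graph_batch_sizes(300))       # [300, 256, 128, 1]
print(get_graph_batch_sizes(128))       # [128, 1]
print(get_graph_batch_sizes(64, [32]))  # [64, 32, 1]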

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_opt.py

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 import torch
 
+from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
+
 from ..compiler import BackendRegistry
 from .torch_cudagraph import CapturedGraph, TorchCudagraphCompiler
 
@@ -10,6 +12,17 @@
 class TorchOptCompiler(TorchCudagraphCompiler):
     """Compiler that uses both torch.compile and CUDA graphs."""
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        torch._dynamo.config.recompile_limit = max(
+            len(self.cuda_graph_batch_sizes), torch._dynamo.config.recompile_limit
+        )
+        ad_logger.info(f"Setting recompile limit to {torch._dynamo.config.recompile_limit}")
+
+        # Global torch config, set the torch compile cache to fix up to llama 405B
+        torch._dynamo.config.cache_size_limit = 20
+        ad_logger.info(f"Setting cache size limit to {torch._dynamo.config.cache_size_limit}")
+
     def _init_captured_graph(self, gm, in_spec, out_spec) -> CapturedGraph:
         gm = torch.compile(gm, dynamic=True)
         return super()._init_captured_graph(gm, in_spec, out_spec)
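This is the core TRTLLM-6142 change: torch-opt runs torch.compile on the graph and then captures one CUDA graph per batch size, so each capture size can cost a Dynamo recompile. Raising recompile_limit to at least the number of capture sizes avoids silently falling back to eager partway through capture. A small sketch of the arithmetic with hypothetical values (recompile_limit is the name used by recent PyTorch; older releases expose cache_size_limit instead):

import torch

# hypothetical batch-size list resolved by TorchCudagraphCompiler.__init__
cuda_graph_batch_sizes = [128, 64, 32, 16, 8, 4, 2, 1]

# never lower the limit below whatever it already is
torch._dynamo.config.recompile_limit = max(
    len(cuda_graph_batch_sizes), torch._dynamo.config.recompile_limit
)
print(torch._dynamo.config.recompile_limit)  # 8, unless the default was already higher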

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py

Lines changed: 5 additions & 2 deletions
@@ -2,6 +2,7 @@
 import tempfile
 from pathlib import Path
 
+import pytest
 import yaml
 from _model_test_utils import _hf_model_dir_or_hub_id
 from click.testing import CliRunner

@@ -64,7 +65,8 @@ def run_benchmark(model_name: str, dataset_path: str, temp_dir: str):
     assert result.exit_code == 0
 
 
-def test_trtllm_bench(llm_root):  # noqa: F811
+@pytest.mark.parametrize("compile_backend", ["torch-compile", "torch-opt", "torch-cudagraph"])
+def test_trtllm_bench(llm_root, compile_backend):  # noqa: F811
     model_name = _hf_model_dir_or_hub_id(
         f"{llm_models_root()}/TinyLlama-1.1B-Chat-v1.0", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     )

@@ -74,8 +76,9 @@ def test_trtllm_bench(llm_root):  # noqa: F811
         yaml.dump(
             {
                 "model_kwargs": {"num_hidden_layers": 2},
-                "cuda_graph_batch_sizes": [1, 2],
+                "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
                 "max_batch_size": 128,
+                "compile_backend": compile_backend,
             },
             f,
         )
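For reference, the benchmark test now exercises all three backends and a larger capture list; roughly, the extra-options dict it serializes looks like this (variable name is illustrative, values are taken from the diff above):

import yaml

# what the parametrized test writes to its temporary YAML config
extra_options = {
    "model_kwargs": {"num_hidden_layers": 2},
    "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
    "max_batch_size": 128,
    "compile_backend": "torch-opt",  # parametrized: also torch-compile, torch-cudagraph
}
print(yaml.dump(extra_options))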
