Commit 5e06672 (1 parent: 2c81f8a)

Revert "TRTLLM-6142: set torch recompile_limit based on cuda_graph_batch_sizes and refactored (#119)"

This reverts commit 2c81f8a.

File tree: 5 files changed, +29 -52 lines
examples/auto_deploy/build_and_run_ad.py (3 additions, 0 deletions)

@@ -23,6 +23,9 @@
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.sampling_params import SamplingParams
 
+# Global torch config, set the torch compile cache to fix up to llama 405B
+torch._dynamo.config.cache_size_limit = 20
+
 
 class PromptConfig(BaseModel):
     """Prompt configuration.

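For context, torch._dynamo keeps a cache of compiled entries per code object, and cache_size_limit caps how many entries it keeps before it stops compiling and falls back to eager; the revert moves this global bump back into the example script. A minimal, illustrative sketch (toy module, not from the repo):

    # Illustrative only: raise the Dynamo cache limit, then compile a toy module.
    # Shape changes below may trigger recompilations that count against the limit.
    import torch
    import torch.nn as nn

    torch._dynamo.config.cache_size_limit = 20  # same global the script now sets

    model = torch.compile(nn.Linear(8, 8))
    for bs in (1, 2, 3, 4):
        model(torch.randn(bs, 8))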
tensorrt_llm/_torch/auto_deploy/compile/backends/torch_compile.py (0 additions, 8 deletions)

@@ -3,19 +3,11 @@
 import torch
 import torch.nn as nn
 
-from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
-
 from ..compiler import BackendCompiler, BackendRegistry
 
 
 @BackendRegistry.register("torch-compile")
 class TorchCompileCompiler(BackendCompiler):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # Global torch config, set the torch compile cache to fix up to llama 405B
-        torch._dynamo.config.cache_size_limit = 20
-        ad_logger.info(f"Setting cache size limit to {torch._dynamo.config.cache_size_limit}")
-
     def compile(self) -> nn.Module:
         """Compile the model using torch.compile."""
         return torch.compile(self.gm, dynamic=True)

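With the per-backend config removed, the torch-compile backend reduces to a single dynamic-shape torch.compile call. A self-contained usage sketch (the module below is an illustrative stand-in for self.gm):

    # Sketch: compile once with dynamic=True so the same compiled module can
    # serve different batch sizes; module and shapes are illustrative.
    import torch
    import torch.nn as nn

    gm = nn.Sequential(nn.Linear(16, 16), nn.ReLU())  # stand-in for self.gm
    compiled = torch.compile(gm, dynamic=True)
    print(compiled(torch.randn(2, 16)).shape)
    print(compiled(torch.randn(8, 16)).shape)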
tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py (24 additions, 26 deletions)

@@ -18,23 +18,27 @@ def __init__(
         model: nn.Module,
         in_spec: TreeSpec,
         out_spec: TreeSpec,
-        cuda_graph_batch_sizes: List[int],
+        max_batch_size: int,
+        cuda_graph_batch_sizes: List[int] = None,
         num_batched_inputs: Optional[int] = 1,  # number of batched, dynamic inputs...
     ):
         super().__init__()
         self._in_spec = in_spec
         self._out_spec = out_spec
         self.model = model
-        self.max_batch_size = max(cuda_graph_batch_sizes)
-        ad_logger.info(f"Setting max batch size to {self.max_batch_size}")
+        self.max_batch_size = max_batch_size
         self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
         self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
         self._input_buffers: List[torch.Tensor] = [
             torch.empty(0, 1) for _ in range(self.num_batched_inputs)
         ]
         self._out_buffer_flat: List[torch.Tensor] = None
         self._args_hash: Optional[Tuple[int, ...]] = None
-        self.cuda_graph_batch_sizes = sorted(cuda_graph_batch_sizes, reverse=True)
+        self.cuda_graph_batch_sizes = (
+            sorted(cuda_graph_batch_sizes, reverse=True)
+            if cuda_graph_batch_sizes is not None
+            else self._get_graph_batch_sizes(self.max_batch_size)
+        )
         self._cuda_graph_mem_pool = None
 
     def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:

@@ -73,6 +77,20 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
         self._cuda_graph_mem_pool = self._cuda_graph_mem_pool or graph.pool()
         return graph
 
+    @staticmethod
+    def _get_graph_batch_sizes(
+        max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
+    ) -> List[int]:
+        """Heuristic to set batch sizes for graph capture."""
+        # do 1, max_bs, and extra as special batch sizes
+        batch_sizes = {1, max_bs, *(extra or [])}
+
+        # add all multiples of multiplier up to max_bs
+        batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
+
+        # return as sorted list
+        return sorted(batch_sizes, reverse=True)
+
     def capture_graph(self, *args, **kwargs):
         """Capture and pre-fetch the graph for variable batch size."""
         # flatten args, kwargs

@@ -159,21 +177,15 @@ def forward(self, *args, **kwargs) -> Any:
 class TorchCudagraphCompiler(BackendCompiler):
     """Compiler that uses only CUDA graphs."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.cuda_graph_batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes")
-        if not self.cuda_graph_batch_sizes:
-            self.cuda_graph_batch_sizes = self._get_graph_batch_sizes(self.max_batch_size)
-        ad_logger.info(f"Setting cuda_graph_batch_sizes to {self.cuda_graph_batch_sizes}")
-
     def _init_captured_graph(
         self, gm: nn.Module, in_spec: TreeSpec, out_spec: TreeSpec
     ) -> CapturedGraph:
         return CapturedGraph(
             gm,
             in_spec=in_spec,
             out_spec=out_spec,
-            cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
+            max_batch_size=self.max_batch_size,
+            cuda_graph_batch_sizes=self.compiler_kwargs.get("cuda_graph_batch_sizes"),
             num_batched_inputs=self.compiler_kwargs.get("num_batched_inputs"),
         )
 

@@ -186,17 +198,3 @@ def compile(self) -> CapturedGraph:
         captured_model.capture_graph(*self.args, **self.kwargs)
 
         return captured_model
-
-    @staticmethod
-    def _get_graph_batch_sizes(
-        max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
-    ) -> List[int]:
-        """Heuristic to set batch sizes for graph capture."""
-        # do 1, max_bs, and extra as special batch sizes
-        batch_sizes = {1, max_bs, *(extra or [])}
-
-        # add all multiples of multiplier up to max_bs
-        batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
-
-        # return as sorted list
-        return sorted(batch_sizes, reverse=True)

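To make the restored heuristic concrete, here is a standalone copy of _get_graph_batch_sizes with a worked example (the max_bs value of 300 is illustrative):

    # Standalone copy of the heuristic shown in the diff above.
    from typing import List, Optional

    def _get_graph_batch_sizes(
        max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
    ) -> List[int]:
        """Heuristic to set batch sizes for graph capture."""
        # always include 1, max_bs, and any extras
        batch_sizes = {1, max_bs, *(extra or [])}
        # plus all multiples of `multiplier` up to max_bs
        batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
        return sorted(batch_sizes, reverse=True)

    print(_get_graph_batch_sizes(300))  # -> [300, 256, 128, 1]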
tensorrt_llm/_torch/auto_deploy/compile/backends/torch_opt.py (0 additions, 13 deletions)

@@ -2,8 +2,6 @@
 
 import torch
 
-from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
-
 from ..compiler import BackendRegistry
 from .torch_cudagraph import CapturedGraph, TorchCudagraphCompiler
 

@@ -12,17 +10,6 @@
 class TorchOptCompiler(TorchCudagraphCompiler):
     """Compiler that uses both torch.compile and CUDA graphs."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        torch._dynamo.config.recompile_limit = max(
-            len(self.cuda_graph_batch_sizes), torch._dynamo.config.recompile_limit
-        )
-        ad_logger.info(f"Setting recompile limit to {torch._dynamo.config.recompile_limit}")
-
-        # Global torch config, set the torch compile cache to fix up to llama 405B
-        torch._dynamo.config.cache_size_limit = 20
-        ad_logger.info(f"Setting cache size limit to {torch._dynamo.config.cache_size_limit}")
-
     def _init_captured_graph(self, gm, in_spec, out_spec) -> CapturedGraph:
         gm = torch.compile(gm, dynamic=True)
         return super()._init_captured_graph(gm, in_spec, out_spec)

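For reference, the removed __init__ logic amounted to the sketch below; recompile_limit is the attribute name taken from the diff above (present only in recent PyTorch), and the batch-size list is illustrative. After the revert, torch-opt relies on the global cache_size_limit set in build_and_run_ad.py instead.

    # Sketch of the removed behavior: allow at least one compiled variant per
    # CUDA-graph batch size before Dynamo's recompile guard kicks in.
    import torch

    cuda_graph_batch_sizes = [256, 128, 1]  # example only
    torch._dynamo.config.recompile_limit = max(
        len(cuda_graph_batch_sizes), torch._dynamo.config.recompile_limit
    )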
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py (2 additions, 5 deletions)

@@ -2,7 +2,6 @@
 import tempfile
 from pathlib import Path
 
-import pytest
 import yaml
 from _model_test_utils import _hf_model_dir_or_hub_id
 from click.testing import CliRunner

@@ -65,8 +64,7 @@ def run_benchmark(model_name: str, dataset_path: str, temp_dir: str):
     assert result.exit_code == 0
 
 
-@pytest.mark.parametrize("compile_backend", ["torch-compile", "torch-opt", "torch-cudagraph"])
-def test_trtllm_bench(llm_root, compile_backend):  # noqa: F811
+def test_trtllm_bench(llm_root):  # noqa: F811
     model_name = _hf_model_dir_or_hub_id(
         f"{llm_models_root()}/TinyLlama-1.1B-Chat-v1.0", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     )

@@ -76,9 +74,8 @@ def test_trtllm_bench(llm_root, compile_backend):  # noqa: F811
         yaml.dump(
             {
                 "model_kwargs": {"num_hidden_layers": 2},
-                "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
+                "cuda_graph_batch_sizes": [1, 2],
                 "max_batch_size": 128,
             },
             f,
         )

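After the revert, the extra-options YAML written by the test boils down to the dict from the diff above; the compile_backend key and the backend parametrization are gone, so whichever backend is the default gets exercised. A small reproduction sketch:

    # Sketch reproducing the config the reverted test dumps to YAML.
    import yaml

    config = {
        "model_kwargs": {"num_hidden_layers": 2},
        "cuda_graph_batch_sizes": [1, 2],
        "max_batch_size": 128,
    }
    print(yaml.dump(config))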