Commit 8bbd45e

init

Signed-off-by: Superjomn <[email protected]>

1 parent d853811 commit 8bbd45e

File tree

2 files changed: +289 -53 lines changed

tensorrt_llm/llmapi/llm_args.py

Lines changed: 63 additions & 52 deletions
@@ -61,7 +61,16 @@
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import


-class CudaGraphConfig(BaseModel):
+class StrictBaseModel(BaseModel):
+    """
+    A base model that forbids arbitrary fields.
+    """
+
+    class Config:
+        extra = "forbid"  # globally forbid arbitrary fields
+
+
+class CudaGraphConfig(StrictBaseModel):
     """
     Configuration for CUDA graphs.
     """
@@ -88,8 +97,40 @@ def validate_cuda_graph_max_batch_size(cls, v):
                 "cuda_graph_config.max_batch_size must be non-negative")
         return v

+    @staticmethod
+    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
+                                         enable_padding: bool) -> List[int]:
+        """Generate a list of batch sizes for CUDA graphs.
+
+        Args:
+            max_batch_size: Maximum batch size to generate up to
+            enable_padding: Whether padding is enabled, which affects the batch size distribution
+
+        Returns:
+            List of batch sizes to create CUDA graphs for
+        """
+        if enable_padding:
+            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+        else:
+            batch_sizes = list(range(1, 32)) + [32, 64, 128]
+
+        # Add powers of 2 up to max_batch_size
+        batch_sizes += [
+            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
+        ]
+
+        # Filter and sort batch sizes
+        batch_sizes = sorted(
+            [size for size in batch_sizes if size <= max_batch_size])
+
+        # Add max_batch_size if not already included
+        if max_batch_size != batch_sizes[-1]:
+            batch_sizes.append(max_batch_size)
+
+        return batch_sizes
+

-class MoeConfig(BaseModel):
+class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
     """
@@ -194,7 +235,7 @@ def to_mapping(self) -> Mapping:
             auto_parallel=self.auto_parallel)


-class CalibConfig(BaseModel):
+class CalibConfig(StrictBaseModel):
     """
     Calibration configuration.
     """
@@ -246,7 +287,7 @@ class _ModelFormatKind(Enum):
     TLLM_ENGINE = 2


-class DecodingBaseConfig(BaseModel):
+class DecodingBaseConfig(StrictBaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None

@@ -267,6 +308,7 @@ def from_dict(cls, data: dict):
         config_class = config_classes.get(decoding_type)
         if config_class is None:
             raise ValueError(f"Invalid decoding type: {decoding_type}")
+        data.pop("decoding_type")

         return config_class(**data)

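The new data.pop("decoding_type") is a direct consequence of the strict base class: decoding_type is only a dispatch key in the incoming dict, not a field on any decoding config, so forwarding it to config_class(**data) would now trip the extra = "forbid" check. A toy reproduction (DraftConfig is a stand-in, not the real class hierarchy):

from pydantic import BaseModel, ValidationError


class DraftConfig(BaseModel):
    class Config:
        extra = "forbid"

    max_draft_len: int = 0


data = {"decoding_type": "Draft", "max_draft_len": 4}
try:
    DraftConfig(**data)  # dispatch key still in the dict -> ValidationError
except ValidationError:
    data.pop("decoding_type")  # what from_dict now does before construction
    print(DraftConfig(**data))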

@@ -465,7 +507,7 @@ def mirror_pybind_fields(pybind_class):
     """

     def decorator(cls):
-        assert issubclass(cls, BaseModel)
+        assert issubclass(cls, StrictBaseModel)
         # Get all non-private fields from the C++ class
         cpp_fields = PybindMirror.get_pybind_variable_fields(pybind_class)
         python_fields = set(cls.model_fields.keys())
@@ -566,7 +608,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
-class DynamicBatchConfig(BaseModel, PybindMirror):
+class DynamicBatchConfig(StrictBaseModel, PybindMirror):
     """Dynamic batch configuration.

     Controls how batch size and token limits are dynamically adjusted at runtime.
@@ -592,7 +634,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_SchedulerConfig)
-class SchedulerConfig(BaseModel, PybindMirror):
+class SchedulerConfig(StrictBaseModel, PybindMirror):
     capacity_scheduler_policy: CapacitySchedulerPolicy = Field(
         default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
         description="The capacity scheduler policy to use")
@@ -614,7 +656,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_PeftCacheConfig)
-class PeftCacheConfig(BaseModel, PybindMirror):
+class PeftCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the PEFT cache.
     """
@@ -742,7 +784,7 @@ def supports_backend(self, backend: str) -> bool:


 @PybindMirror.mirror_pybind_fields(_KvCacheConfig)
-class KvCacheConfig(BaseModel, PybindMirror):
+class KvCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the KV cache.
     """
@@ -825,7 +867,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
-class ExtendedRuntimePerfKnobConfig(BaseModel, PybindMirror):
+class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for extended runtime performance knobs.
     """
@@ -856,7 +898,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig)
-class CacheTransceiverConfig(BaseModel, PybindMirror):
+class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the cache transceiver.
     """
@@ -916,7 +958,7 @@ def model_name(self) -> Union[str, Path]:
         return self.model if isinstance(self.model, str) else None


-class BaseLlmArgs(BaseModel):
+class BaseLlmArgs(StrictBaseModel):
     """
     Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.
     """
@@ -1305,7 +1347,8 @@ def init_build_config(self):
         """
         Creating a default BuildConfig if none is provided
         """
-        if self.build_config is None:
+        build_config = getattr(self, "build_config", None)
+        if build_config is None:
             kwargs = {}
             if self.max_batch_size:
                 kwargs["max_batch_size"] = self.max_batch_size
@@ -1318,10 +1361,10 @@ def init_build_config(self):
             if self.max_input_len:
                 kwargs["max_input_len"] = self.max_input_len
             self.build_config = BuildConfig(**kwargs)
-
-        assert isinstance(
-            self.build_config, BuildConfig
-        ), f"build_config is not initialized: {self.build_config}"
+        else:
+            assert isinstance(
+                build_config,
+                BuildConfig), f"build_config is not initialized: {build_config}"
         return self

     @model_validator(mode="after")
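The getattr guard is likely also strictness fallout: on a pydantic model, reading an attribute that is not a declared field raises AttributeError rather than returning None, so the old self.build_config access could crash when the validator runs on an args class that does not declare the field. A minimal illustration (Args is hypothetical):

from pydantic import BaseModel


class Args(BaseModel):
    class Config:
        extra = "forbid"

    max_batch_size: int = 8


args = Args()
# args.build_config would raise AttributeError: not a declared field.
# The guarded read degrades to the "create a default" branch instead:
build_config = getattr(args, "build_config", None)
print(build_config)  # None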
@@ -1758,7 +1801,7 @@ class LoadFormat(Enum):
     DUMMY = 1


-class TorchCompileConfig(BaseModel):
+class TorchCompileConfig(StrictBaseModel):
     """
     Configuration for torch.compile.
     """
@@ -1972,38 +2015,6 @@ def validate_checkpoint_format(self):

         return self

-    @staticmethod
-    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
-                                         enable_padding: bool) -> List[int]:
-        """Generate a list of batch sizes for CUDA graphs.
-
-        Args:
-            max_batch_size: Maximum batch size to generate up to
-            enable_padding: Whether padding is enabled, which affects the batch size distribution
-
-        Returns:
-            List of batch sizes to create CUDA graphs for
-        """
-        if enable_padding:
-            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
-        else:
-            batch_sizes = list(range(1, 32)) + [32, 64, 128]
-
-        # Add powers of 2 up to max_batch_size
-        batch_sizes += [
-            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
-        ]
-
-        # Filter and sort batch sizes
-        batch_sizes = sorted(
-            [size for size in batch_sizes if size <= max_batch_size])
-
-        # Add max_batch_size if not already included
-        if max_batch_size != batch_sizes[-1]:
-            batch_sizes.append(max_batch_size)
-
-        return batch_sizes
-
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
         from .._torch import MoeLoadBalancerConfig
@@ -2040,7 +2051,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
         if config.batch_sizes:
             config.batch_sizes = sorted(config.batch_sizes)
             if config.max_batch_size != 0:
-                if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
+                if config.batch_sizes != CudaGraphConfig._generate_cuda_graph_batch_sizes(
                         config.max_batch_size, config.enable_padding):
                     raise ValueError(
                         "Please don't set both cuda_graph_config.batch_sizes "
@@ -2052,7 +2063,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
                 config.max_batch_size = max(config.batch_sizes)
         else:
             max_batch_size = config.max_batch_size or 128
-            generated_sizes = self._generate_cuda_graph_batch_sizes(
+            generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                 max_batch_size, config.enable_padding)
             config.batch_sizes = generated_sizes
             config.max_batch_size = max_batch_size
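With the helper now on CudaGraphConfig, the validator's contract is unchanged and easy to check by hand (assumes tensorrt_llm is importable; expected values derived from the listing, not from a run):

from tensorrt_llm.llmapi.llm_args import CudaGraphConfig

gen = CudaGraphConfig._generate_cuda_graph_batch_sizes

# Neither field set: the validator falls back to a 128 cap and generates sizes.
# Both fields set: sorted batch_sizes must equal the generated series,
# otherwise the ValueError above fires.
print(gen(16, True))                      # [1, 2, 4, 8, 16]
print([1, 2, 4, 8, 16] == gen(16, True))  # True  -> accepted
print([1, 2, 3] == gen(16, True))         # False -> ValueError in the validator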
