 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import


-class CudaGraphConfig(BaseModel):
+class StrictBaseModel(BaseModel):
+    """
+    A base model that forbids arbitrary fields.
+    """
+
+    class Config:
+        extra = "forbid"  # globally forbid arbitrary fields
+
+
+class CudaGraphConfig(StrictBaseModel):
     """
     Configuration for CUDA graphs.
     """
@@ -88,8 +97,40 @@ def validate_cuda_graph_max_batch_size(cls, v):
                 "cuda_graph_config.max_batch_size must be non-negative")
         return v

+    @staticmethod
+    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
+                                         enable_padding: bool) -> List[int]:
+        """Generate a list of batch sizes for CUDA graphs.
+
+        Args:
+            max_batch_size: Maximum batch size to generate up to
+            enable_padding: Whether padding is enabled, which affects the batch size distribution
+
+        Returns:
+            List of batch sizes to create CUDA graphs for
+        """
+        if enable_padding:
+            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+        else:
+            batch_sizes = list(range(1, 32)) + [32, 64, 128]
+
+        # Add powers of 2 up to max_batch_size
+        batch_sizes += [
+            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
+        ]
+
+        # Filter and sort batch sizes
+        batch_sizes = sorted(
+            [size for size in batch_sizes if size <= max_batch_size])
+
+        # Add max_batch_size if not already included
+        if max_batch_size != batch_sizes[-1]:
+            batch_sizes.append(max_batch_size)
+
+        return batch_sizes
+

-class MoeConfig(BaseModel):
+class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
     """
@@ -194,7 +235,7 @@ def to_mapping(self) -> Mapping:
                        auto_parallel=self.auto_parallel)


-class CalibConfig(BaseModel):
+class CalibConfig(StrictBaseModel):
     """
     Calibration configuration.
     """
@@ -246,7 +287,7 @@ class _ModelFormatKind(Enum):
     TLLM_ENGINE = 2


-class DecodingBaseConfig(BaseModel):
+class DecodingBaseConfig(StrictBaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None
@@ -267,6 +308,7 @@ def from_dict(cls, data: dict):
         config_class = config_classes.get(decoding_type)
         if config_class is None:
             raise ValueError(f"Invalid decoding type: {decoding_type}")
+        data.pop("decoding_type")

         return config_class(**data)
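The added `pop` is a direct consequence of the `extra = "forbid"` change: the discriminator key would otherwise be forwarded to a config class that has no such field. A hedged sketch of the flow (the decoding type string and mapping contents are illustrative, not from the diff):

```python
data = {"decoding_type": "MTP", "max_draft_len": 4}
decoding_type = data.get("decoding_type")
config_class = config_classes.get(decoding_type)  # e.g. an MTP config class
# With StrictBaseModel, config_class(**data) would now reject the unknown
# "decoding_type" key, so it must be dropped before construction:
data.pop("decoding_type")
config = config_class(**data)  # constructs cleanly
```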
@@ -465,7 +507,7 @@ def mirror_pybind_fields(pybind_class):
     """

     def decorator(cls):
-        assert issubclass(cls, BaseModel)
+        assert issubclass(cls, StrictBaseModel)
         # Get all non-private fields from the C++ class
         cpp_fields = PybindMirror.get_pybind_variable_fields(pybind_class)
         python_fields = set(cls.model_fields.keys())
@@ -566,7 +608,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
-class DynamicBatchConfig(BaseModel, PybindMirror):
+class DynamicBatchConfig(StrictBaseModel, PybindMirror):
     """Dynamic batch configuration.

     Controls how batch size and token limits are dynamically adjusted at runtime.
@@ -592,7 +634,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_SchedulerConfig)
-class SchedulerConfig(BaseModel, PybindMirror):
+class SchedulerConfig(StrictBaseModel, PybindMirror):
     capacity_scheduler_policy: CapacitySchedulerPolicy = Field(
         default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
         description="The capacity scheduler policy to use")
@@ -614,7 +656,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_PeftCacheConfig)
-class PeftCacheConfig(BaseModel, PybindMirror):
+class PeftCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the PEFT cache.
     """
@@ -742,7 +784,7 @@ def supports_backend(self, backend: str) -> bool:


 @PybindMirror.mirror_pybind_fields(_KvCacheConfig)
-class KvCacheConfig(BaseModel, PybindMirror):
+class KvCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the KV cache.
     """
@@ -825,7 +867,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
-class ExtendedRuntimePerfKnobConfig(BaseModel, PybindMirror):
+class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for extended runtime performance knobs.
     """
@@ -856,7 +898,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig)
-class CacheTransceiverConfig(BaseModel, PybindMirror):
+class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the cache transceiver.
     """
@@ -916,7 +958,7 @@ def model_name(self) -> Union[str, Path]:
         return self.model if isinstance(self.model, str) else None


-class BaseLlmArgs(BaseModel):
+class BaseLlmArgs(StrictBaseModel):
     """
     Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.
     """
@@ -1305,7 +1347,8 @@ def init_build_config(self):
         """
         Creating a default BuildConfig if none is provided
         """
-        if self.build_config is None:
+        build_config = getattr(self, "build_config", None)
+        if build_config is None:
             kwargs = {}
             if self.max_batch_size:
                 kwargs["max_batch_size"] = self.max_batch_size
@@ -1318,10 +1361,10 @@ def init_build_config(self):
             if self.max_input_len:
                 kwargs["max_input_len"] = self.max_input_len
             self.build_config = BuildConfig(**kwargs)
-
-        assert isinstance(
-            self.build_config, BuildConfig
-        ), f"build_config is not initialized: {self.build_config}"
+        else:
+            assert isinstance(
+                build_config,
+                BuildConfig), f"build_config is not initialized: {build_config}"
         return self

     @model_validator(mode="after")
@@ -1758,7 +1801,7 @@ class LoadFormat(Enum):
     DUMMY = 1


-class TorchCompileConfig(BaseModel):
+class TorchCompileConfig(StrictBaseModel):
     """
     Configuration for torch.compile.
     """
@@ -1972,38 +2015,6 @@ def validate_checkpoint_format(self):

         return self

-    @staticmethod
-    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
-                                         enable_padding: bool) -> List[int]:
-        """Generate a list of batch sizes for CUDA graphs.
-
-        Args:
-            max_batch_size: Maximum batch size to generate up to
-            enable_padding: Whether padding is enabled, which affects the batch size distribution
-
-        Returns:
-            List of batch sizes to create CUDA graphs for
-        """
-        if enable_padding:
-            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
-        else:
-            batch_sizes = list(range(1, 32)) + [32, 64, 128]
-
-        # Add powers of 2 up to max_batch_size
-        batch_sizes += [
-            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
-        ]
-
-        # Filter and sort batch sizes
-        batch_sizes = sorted(
-            [size for size in batch_sizes if size <= max_batch_size])
-
-        # Add max_batch_size if not already included
-        if max_batch_size != batch_sizes[-1]:
-            batch_sizes.append(max_batch_size)
-
-        return batch_sizes
-
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
         from .._torch import MoeLoadBalancerConfig
@@ -2040,7 +2051,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
         if config.batch_sizes:
             config.batch_sizes = sorted(config.batch_sizes)
             if config.max_batch_size != 0:
-                if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
+                if config.batch_sizes != CudaGraphConfig._generate_cuda_graph_batch_sizes(
                         config.max_batch_size, config.enable_padding):
                     raise ValueError(
                         "Please don't set both cuda_graph_config.batch_sizes "
@@ -2052,7 +2063,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
                 config.max_batch_size = max(config.batch_sizes)
         else:
             max_batch_size = config.max_batch_size or 128
-            generated_sizes = self._generate_cuda_graph_batch_sizes(
+            generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                 max_batch_size, config.enable_padding)
             config.batch_sizes = generated_sizes
             config.max_batch_size = max_batch_size
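Net effect of the two call-site changes: the generator is now addressed through `CudaGraphConfig`, where it was moved, rather than through `self` on `TorchLlmArgs`. A minimal sketch of the default path, assuming a fresh config with neither `batch_sizes` nor `max_batch_size` set (field defaults inferred from the validator, not stated in the diff):

```python
config = CudaGraphConfig()  # hypothetical: no batch_sizes, max_batch_size unset
# validate_cuda_graph_config then falls back to 128 and fills in the schedule
# via the staticmethod relocated onto CudaGraphConfig:
max_batch_size = config.max_batch_size or 128
config.batch_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
    max_batch_size, config.enable_padding)
config.max_batch_size = max_batch_size
```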