                                              QuantizationStrategy,
                                              QuantizationType)
 from compressed_tensors.transform import TransformConfig
-from pydantic import BaseModel
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -268,7 +267,8 @@ def _check_scheme_supported(self,
         else:
             return False
 
-    def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
+    def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs,
+                        input_quant: QuantizationArgs):
 
         if weight_quant is None or input_quant is None:
             return False
@@ -288,8 +288,8 @@ def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
         return (is_tensor_group_quant and is_float_type and is_4_bits
                 and is_group_size_16 and is_symmetric)
 
-    def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
-                         input_quant: BaseModel):
+    def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs,
+                         input_quant: QuantizationArgs):
 
         is_weight_only = weight_quant is not None and input_quant is None
         is_tensor_group_quant = (
@@ -303,8 +303,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
         return (is_weight_only and is_tensor_group_quant and is_float_type
                 and is_4_bits and is_group_size_16 and is_symmetric)
 
-    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
-                               input_quant: BaseModel) -> bool:
+    def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
             weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -317,8 +317,8 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_tensor and weight_quant.symmetric and is_static
 
-    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
-                               input_quant: BaseModel) -> bool:
+    def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
             weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -331,8 +331,8 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
 
-    def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
-                                   input_quant: BaseModel) -> bool:
+    def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs,
+                                   input_quant: QuantizationArgs) -> bool:
         is_weight_4_bits = weight_quant.num_bits == 4
         is_activation_8_bits = input_quant.num_bits == 8
         weight_strategy = (
@@ -347,8 +347,8 @@ def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
         return (is_weight_4_bits and is_activation_8_bits and is_token
                 and weight_quant.symmetric and is_dynamic)
 
-    def _is_fp8_w8a8(self, weight_quant: BaseModel,
-                     input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8(self, weight_quant: QuantizationArgs,
+                     input_quant: QuantizationArgs) -> bool:
         # Confirm weights and activations quantized.
         if weight_quant is None or input_quant is None:
             return False
@@ -358,11 +358,12 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
                              and input_quant.type == QuantizationType.FLOAT)
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
-            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
+            QuantizationStrategy.BLOCK
         ])
         if not (is_floating_point and is_symmetric_weight and is_static_weight
-                and is_per_tensor_or_channel_weight):
+                and is_tensor_or_channel_or_block_weight):
             return False
 
         # Dynamic quantization is always supported if weights supported.
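The hunk above widens the set of FP8 weight strategies that _is_fp8_w8a8 accepts from per-tensor and per-channel to also include block-wise quantization. Below is a minimal sketch of that membership check, assuming only the QuantizationStrategy enum this file already imports from compressed-tensors; the snippet is illustrative and not part of the diff.

    # Illustrative only: which weight strategies the updated check lets through.
    from compressed_tensors.quantization import QuantizationStrategy

    accepted_fp8_weight_strategies = [
        QuantizationStrategy.TENSOR,
        QuantizationStrategy.CHANNEL,
        QuantizationStrategy.BLOCK,  # newly accepted by this change
    ]

    for strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
                     QuantizationStrategy.BLOCK, QuantizationStrategy.GROUP):
        verdict = ("accepted" if strategy in accepted_fp8_weight_strategies
                   else "rejected")
        print(f"{strategy.value}: {verdict}")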
@@ -375,8 +376,8 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
             input_quant.strategy == QuantizationStrategy.TENSOR)
         return is_symmetric_activation and is_per_tensor_activation
 
-    def _is_fp8_w4a8(self, weight_quant: BaseModel,
-                     input_quant: BaseModel) -> bool:
+    def _is_fp8_w4a8(self, weight_quant: QuantizationArgs,
+                     input_quant: QuantizationArgs) -> bool:
         if not weight_quant or not input_quant:
             return False
         is_weight_4_bits = weight_quant.num_bits == 4
@@ -392,24 +393,24 @@ def _is_fp8_w4a8(self, weight_quant: BaseModel,
         return (is_weight_4_bits and is_activation_8_bits and is_token
                 and is_symmetric and is_dynamic)
 
-    def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel,
-                          input_quant: BaseModel) -> bool:
+    def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs,
+                          input_quant: QuantizationArgs) -> bool:
         return (self._check_scheme_supported(90, error=False, match_exact=True)
                 and self._is_fp8_w4a8(weight_quant, input_quant))
 
-    def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
-                          input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs,
+                          input_quant: QuantizationArgs) -> bool:
         return (self._check_scheme_supported(90, error=False, match_exact=True)
                 and self._is_fp8_w8a8(weight_quant, input_quant))
 
-    def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel,
-                           input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs,
+                           input_quant: QuantizationArgs) -> bool:
         return (self._check_scheme_supported(
             100, error=False, match_exact=True)
                 and self._is_fp8_w8a8(weight_quant, input_quant))
 
-    def _is_fp8_w8a16(self, weight_quant: BaseModel,
-                      input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a16(self, weight_quant: QuantizationArgs,
+                      input_quant: QuantizationArgs) -> bool:
         # Confirm weights quantized.
         if weight_quant is None:
             return False
@@ -421,18 +422,19 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel,
         # Confirm weight scheme is supported.
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
-            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
+            QuantizationStrategy.BLOCK
         ])
         if not (is_symmetric_weight and is_static_weight  # noqa: SIM103
-                and is_per_tensor_or_channel_weight):
+                and is_tensor_or_channel_or_block_weight):
             return False
 
         # All conditions satisfied.
         return True
 
-    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
-                                input_quant: BaseModel) -> bool:
+    def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs,
+                                input_quant: QuantizationArgs) -> bool:
         input_quant_none = input_quant is None
         is_channel_group = (
             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
@@ -443,8 +445,8 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
 
     def _get_scheme_from_parts(
             self,
-            weight_quant: BaseModel,
-            input_quant: BaseModel,
+            weight_quant: QuantizationArgs,
+            input_quant: QuantizationArgs,
             format: Optional[str] = None) -> "CompressedTensorsScheme":
 
         # use the per-layer format if defined, otherwise, use global format
@@ -496,7 +498,7 @@ def _get_scheme_from_parts(
                 CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
             if is_fp8_w8a8_supported:
                 return CompressedTensorsW8A8Fp8(
-                    strategy=weight_quant.strategy,
+                    weight_quant=weight_quant,
                     is_static_input_scheme=(input_quant
                                             and not input_quant.dynamic))
             else:
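The final hunk hands the full weight QuantizationArgs to CompressedTensorsW8A8Fp8 instead of only the strategy string, presumably so a block-quantized scheme can also see the block shape. A rough sketch of that design difference follows, using hypothetical stand-in classes; neither class below is the vLLM or compressed-tensors implementation.

    # Hypothetical stand-ins; only the shape of the call-site change is shown.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeWeightArgs:                       # stand-in for QuantizationArgs
        strategy: str
        block_structure: Optional[list] = None  # e.g. [128, 128] for block-wise

    @dataclass
    class FakeW8A8Fp8Scheme:             # stand-in for CompressedTensorsW8A8Fp8
        weight_quant: FakeWeightArgs     # the full args, not just a strategy name
        is_static_input_scheme: bool

        def block_shape(self) -> Optional[list]:
            # With the whole args object available, a block-wise kernel path can
            # read the tile shape directly instead of inferring it elsewhere.
            return self.weight_quant.block_structure

    scheme = FakeW8A8Fp8Scheme(
        weight_quant=FakeWeightArgs(strategy="block", block_structure=[128, 128]),
        is_static_input_scheme=True)
    print(scheme.block_shape())  # [128, 128]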