Commit fbd6523

Refactor dense FP8 tensor/channel/block utils and add CT FP8 block (vllm-project#21404)
1 parent 470484a commit fbd6523

5 files changed: +442 −318 lines


vllm/model_executor/layers/linear.py

Lines changed: 7 additions & 7 deletions
@@ -805,12 +805,10 @@ def weight_loader_v2(self,
             assert loaded_shard_id < len(self.output_sizes)
 
             if isinstance(param, BlockQuantScaleParameter):
-                from vllm.model_executor.layers.quantization.fp8 import (
-                    Fp8LinearMethod, Fp8MoEMethod)
                 assert self.quant_method is not None
-                assert isinstance(self.quant_method,
-                                  (Fp8LinearMethod, Fp8MoEMethod))
-                weight_block_size = self.quant_method.quant_config.weight_block_size
+                # Assume the weight block size has been set by quant method
+                assert hasattr(self, "weight_block_size")
+                weight_block_size = self.weight_block_size
                 assert weight_block_size is not None
                 block_n, _ = weight_block_size[0], weight_block_size[1]
                 shard_offset = (
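
Note on the hunk above: the loader now depends only on a `weight_block_size` attribute being present on the layer, instead of importing and isinstance-checking concrete FP8 method classes. A minimal sketch of that contract follows; the class and method names here are hypothetical, not the actual vllm code:

import torch

class HypotheticalBlockFp8Method:
    """Hypothetical stand-in for a quant method that owns a block size."""

    def __init__(self, weight_block_size: list[int]) -> None:
        self.weight_block_size = weight_block_size  # e.g. [128, 128]

    def create_weights(self, layer: torch.nn.Module) -> None:
        # The loader only checks hasattr(layer, "weight_block_size"),
        # so any quant method can opt in by setting the attribute here.
        layer.weight_block_size = self.weight_block_size

The second hunk below applies the same pattern to the other loader path.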
@@ -989,8 +987,10 @@ def weight_loader_v2(self,
         # Note(simon): This is needed for Qwen3's fp8 quantization.
         if isinstance(param, BlockQuantScaleParameter):
             assert self.quant_method is not None
-            assert hasattr(self.quant_method, "quant_config")
-            weight_block_size = self.quant_method.quant_config.weight_block_size
+            # Assume the weight block size has been set by quant method
+            assert hasattr(self, "weight_block_size")
+            weight_block_size = self.weight_block_size
+            assert weight_block_size is not None
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n
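
The ceiling divisions above convert shard offsets and sizes given in output rows into block-scale units: block quantization stores one scale per block_n × block_k weight tile, so a shard boundary that is not block-aligned must round up. A runnable sketch of the same arithmetic, with illustrative numbers:

def ceil_div(a: int, b: int) -> int:
    """Round-up integer division, as in (x + block_n - 1) // block_n."""
    return (a + b - 1) // b

weight_block_size = [128, 128]  # [block_n, block_k]
block_n = weight_block_size[0]

# Shard boundaries expressed in output rows (illustrative values):
shard_offset, shard_size = 4096, 11000
print(ceil_div(shard_offset, block_n))  # 32 (4096 is block-aligned)
print(ceil_div(shard_size, block_n))    # 86 (11000 rows round up to 86 blocks)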

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 35 additions & 33 deletions
@@ -12,7 +12,6 @@
                                              QuantizationStrategy,
                                              QuantizationType)
 from compressed_tensors.transform import TransformConfig
-from pydantic import BaseModel
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -268,7 +267,8 @@ def _check_scheme_supported(self,
         else:
             return False
 
-    def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
+    def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs,
+                        input_quant: QuantizationArgs):
 
         if weight_quant is None or input_quant is None:
             return False
@@ -288,8 +288,8 @@ def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
         return (is_tensor_group_quant and is_float_type and is_4_bits
                 and is_group_size_16 and is_symmetric)
 
-    def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
-                         input_quant: BaseModel):
+    def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs,
+                         input_quant: QuantizationArgs):
 
         is_weight_only = weight_quant is not None and input_quant is None
         is_tensor_group_quant = (
@@ -303,8 +303,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
         return (is_weight_only and is_tensor_group_quant and is_float_type
                 and is_4_bits and is_group_size_16 and is_symmetric)
 
-    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
-                               input_quant: BaseModel) -> bool:
+    def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
             weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -317,8 +317,8 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_tensor and weight_quant.symmetric and is_static
 
-    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
-                               input_quant: BaseModel) -> bool:
+    def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
             weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -331,8 +331,8 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
 
-    def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
-                                   input_quant: BaseModel) -> bool:
+    def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs,
+                                   input_quant: QuantizationArgs) -> bool:
         is_weight_4_bits = weight_quant.num_bits == 4
         is_activation_8_bits = input_quant.num_bits == 8
         weight_strategy = (
@@ -347,8 +347,8 @@ def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
         return (is_weight_4_bits and is_activation_8_bits and is_token
                 and weight_quant.symmetric and is_dynamic)
 
-    def _is_fp8_w8a8(self, weight_quant: BaseModel,
-                     input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8(self, weight_quant: QuantizationArgs,
+                     input_quant: QuantizationArgs) -> bool:
         # Confirm weights and activations quantized.
         if weight_quant is None or input_quant is None:
             return False
@@ -358,11 +358,12 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
                 and input_quant.type == QuantizationType.FLOAT)
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
-            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
+            QuantizationStrategy.BLOCK
         ])
         if not (is_floating_point and is_symmetric_weight and is_static_weight
-                and is_per_tensor_or_channel_weight):
+                and is_tensor_or_channel_or_block_weight):
             return False
 
         # Dynamic quantization is always supported if weights supported.
@@ -375,8 +376,8 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
             input_quant.strategy == QuantizationStrategy.TENSOR)
         return is_symmetric_activation and is_per_tensor_activation
 
-    def _is_fp8_w4a8(self, weight_quant: BaseModel,
-                     input_quant: BaseModel) -> bool:
+    def _is_fp8_w4a8(self, weight_quant: QuantizationArgs,
+                     input_quant: QuantizationArgs) -> bool:
         if not weight_quant or not input_quant:
             return False
         is_weight_4_bits = weight_quant.num_bits == 4
@@ -392,24 +393,24 @@ def _is_fp8_w4a8(self, weight_quant: BaseModel,
         return (is_weight_4_bits and is_activation_8_bits and is_token
                 and is_symmetric and is_dynamic)
 
-    def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel,
-                          input_quant: BaseModel) -> bool:
+    def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs,
+                          input_quant: QuantizationArgs) -> bool:
         return (self._check_scheme_supported(90, error=False, match_exact=True)
                 and self._is_fp8_w4a8(weight_quant, input_quant))
 
-    def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
-                          input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs,
+                          input_quant: QuantizationArgs) -> bool:
         return (self._check_scheme_supported(90, error=False, match_exact=True)
                 and self._is_fp8_w8a8(weight_quant, input_quant))
 
-    def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel,
-                           input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs,
+                           input_quant: QuantizationArgs) -> bool:
         return (self._check_scheme_supported(
             100, error=False, match_exact=True)
                 and self._is_fp8_w8a8(weight_quant, input_quant))
 
-    def _is_fp8_w8a16(self, weight_quant: BaseModel,
-                      input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a16(self, weight_quant: QuantizationArgs,
+                      input_quant: QuantizationArgs) -> bool:
         # Confirm weights quantized.
         if weight_quant is None:
             return False
@@ -421,18 +422,19 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel,
         # Confirm weight scheme is supported.
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
-            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
+            QuantizationStrategy.BLOCK
         ])
         if not (is_symmetric_weight and is_static_weight  # noqa: SIM103
-                and is_per_tensor_or_channel_weight):
+                and is_tensor_or_channel_or_block_weight):
             return False
 
         # All conditions satisfied.
         return True
 
-    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
-                                input_quant: BaseModel) -> bool:
+    def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs,
+                                input_quant: QuantizationArgs) -> bool:
         input_quant_none = input_quant is None
         is_channel_group = (
             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
@@ -443,8 +445,8 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
 
     def _get_scheme_from_parts(
             self,
-            weight_quant: BaseModel,
-            input_quant: BaseModel,
+            weight_quant: QuantizationArgs,
+            input_quant: QuantizationArgs,
             format: Optional[str] = None) -> "CompressedTensorsScheme":
 
         # use the per-layer format if defined, otherwise, use global format
@@ -496,7 +498,7 @@ def _get_scheme_from_parts(
             CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
         if is_fp8_w8a8_supported:
             return CompressedTensorsW8A8Fp8(
-                strategy=weight_quant.strategy,
+                weight_quant=weight_quant,
                 is_static_input_scheme=(input_quant
                                         and not input_quant.dynamic))
         else:
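
Taken together, these hunks retype the detection predicates from pydantic's BaseModel to compressed-tensors' QuantizationArgs and let FP8 W8A8 admit QuantizationStrategy.BLOCK weights. A self-contained sketch of the detection logic follows; it uses a local stand-in for QuantizationArgs (so it runs without the library) with field names that mirror the real class, and only loosely mirrors _is_fp8_w8a8 above:

from dataclasses import dataclass
from enum import Enum
from typing import Optional

class Strategy(str, Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    BLOCK = "block"

@dataclass
class Args:
    """Local stand-in for compressed_tensors.quantization.QuantizationArgs."""
    num_bits: int
    type: str            # "float" or "int"
    strategy: Strategy
    symmetric: bool = True
    dynamic: bool = False

def is_fp8_w8a8(weight_quant: Optional[Args],
                input_quant: Optional[Args]) -> bool:
    # Float weights/activations, symmetric static weights, and now
    # tensor-, channel-, OR block-wise weight scales.
    if weight_quant is None or input_quant is None:
        return False
    is_float = weight_quant.type == "float" and input_quant.type == "float"
    weight_ok = (weight_quant.symmetric and not weight_quant.dynamic
                 and weight_quant.strategy in
                 (Strategy.TENSOR, Strategy.CHANNEL, Strategy.BLOCK))
    if not (is_float and weight_ok):
        return False
    # Dynamic activation quantization is always supported once the weights
    # pass; static activations must be symmetric and per-tensor.
    if input_quant.dynamic:
        return True
    return input_quant.symmetric and input_quant.strategy == Strategy.TENSOR

# Block-wise FP8 weights (e.g. 128x128 tiles) with dynamic activations
# are now recognized as fp8 w8a8:
w = Args(num_bits=8, type="float", strategy=Strategy.BLOCK)
a = Args(num_bits=8, type="float", strategy=Strategy.TENSOR, dynamic=True)
assert is_fp8_w8a8(w, a)

Relatedly, the last hunk passes the full weight_quant args into CompressedTensorsW8A8Fp8 instead of only the strategy, presumably so the scheme can also read the block structure when the strategy is BLOCK.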
