
Commit 967e26a

BlackSamorez, SunMarc, and MekkCyber authored and committed
FP-Quant support (huggingface#38696)
* quartet
* quartet qat -> quartet
* format
* bf16 backward
* interfaces
* forward_method
* quartet -> fp_quant
* style
* List -> list
* list typing
* fixed format and annotations
* test_fp_quant
* docstrings and default dtypes
* better docstring and removed noop checks
* docs
* pseudoquantization support to test on non-blackwell
* pseudoquant
* Pseudoquant docs
* Update docs/source/en/quantization/fp_quant.md
  Co-authored-by: Marc Sun <[email protected]>
* Update docs/source/en/quantization/fp_quant.md
* Update docs/source/en/quantization/fp_quant.md
* Update src/transformers/utils/quantization_config.py
  Co-authored-by: Mohamed Mekkouri <[email protected]>
* Update tests/quantization/fp_quant_integration/test_fp_quant.py
  Co-authored-by: Mohamed Mekkouri <[email protected]>
* Update tests/quantization/fp_quant_integration/test_fp_quant.py
  Co-authored-by: Marc Sun <[email protected]>
* small test fixes
* dockerfile update
* spec link
* removed `_process_model_after_weight_loading`
* toctree

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
1 parent 1a40d4a commit 967e26a

File tree

15 files changed: +629 -0 lines changed

docker/transformers-quantization-latest-gpu/Dockerfile

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@ RUN git clone https://github.com/NetEase-FuXi/EETQ.git && cd EETQ/ && git submod
 # RUN python3 -m pip install --no-cache-dir flute-kernel==0.4.1
 # RUN python3 -m pip install --no-cache-dir git+https://github.com/Dao-AILab/fast-hadamard-transform.git
 
+# Add fp-quant for quantization testing
+RUN python3 -m pip install --no-cache-dir "fp-quant>=0.1.6"
+
 # Add compressed-tensors for quantization testing
 RUN python3 -m pip install --no-cache-dir compressed-tensors

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -179,6 +179,8 @@
     title: FBGEMM
   - local: quantization/finegrained_fp8
     title: Fine-grained FP8
+  - local: quantization/fp_quant
+    title: FP-Quant
   - local: gguf
     title: GGUF
   - local: quantization/gptq

docs/source/en/main_classes/quantization.md

Lines changed: 4 additions & 0 deletions
@@ -93,6 +93,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
 [[autodoc]] QuarkConfig
 
+## FPQuantConfig
+
+[[autodoc]] FPQuantConfig
+
 ## AutoRoundConfig
 
 [[autodoc]] AutoRoundConfig
docs/source/en/quantization/fp_quant.md

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# FP-Quant

[FP-Quant](https://github.com/IST-DASLab/FP-Quant) is a family of quantization algorithms tailored for the Blackwell generation of Nvidia GPUs. The goal is to allow for efficient post-training quantization (PTQ) and quantization-aware training (QAT) of LLMs in the [MXFP4 and NVFP4 data-types](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).

Currently, only PTQ with MXFP4 is supported. Models can either be quantized on the fly with `quantization_config=FPQuantConfig()`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig
import torch

model = AutoModelForCausalLM.from_pretrained(
    "qwen/Qwen3-8B",
    quantization_config=FPQuantConfig(),
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
```

or pre-processed with GPTQ for better quality (see the [FP Format Quantization Harness](https://github.com/IST-DASLab/FP-Quant)).

A **Blackwell-generation GPU is required** to run the kernels. Runtime support for FP-Quant is implemented through the [QuTLASS](https://github.com/IST-DASLab/qutlass) library and a lightweight PyTorch interface library, [`fp_quant`](https://github.com/IST-DASLab/FP-Quant/tree/master/inference_lib). We recommend installing the former **from source** and the latter with `pip install fp_quant`.

Users **without a Blackwell-generation GPU** can use the method with `quantization_config=FPQuantConfig(pseudoquant=True)` without having to install [QuTLASS](https://github.com/IST-DASLab/qutlass). This provides no speedups but fully emulates the effect of quantization.

> [!TIP]
> Find models pre-quantized with FP-Quant in the official ISTA-DASLab [collection](https://huggingface.co/collections/ISTA-DASLab/fp-quant-6877c186103a21d3a02568ee).

## torch.compile

FP-Quant is fully compatible with [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig

model = AutoModelForCausalLM.from_pretrained(
    "qwen/Qwen3-8B",
    quantization_config=FPQuantConfig(),
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)
```

## Speedups

FP-Quant currently performs best for very large batch size processing.

See the [QuTLASS README](https://github.com/IST-DASLab/qutlass/blob/main/README.md) for speedups.
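
For orientation (not part of the diff above): a minimal end-to-end sketch of the on-the-fly quantization path the new doc page describes. It assumes the same `qwen/Qwen3-8B` checkpoint, a single CUDA device, and the standard `generate` API; on non-Blackwell hardware the pseudo-quantization flag can be substituted as described above.

```python
# Sketch: quantize on the fly with FP-Quant and run generation.
# Assumes a Blackwell GPU with QuTLASS and fp_quant installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig

model_id = "qwen/Qwen3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=FPQuantConfig(),  # MXFP4 PTQ path described in the doc page
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("The Blackwell architecture introduces", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```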

docs/source/en/quantization/overview.md

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ Use the Space below to help you pick a quantization method depending on your har
 | [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 | 🟢 | 🟡 | 🔴 | 🟡 | 🟢 | 4/8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes |
 | [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1/8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
 | [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
+| [FP-Quant](./fp_quant) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 4 | 🔴 | 🟢 | 🟢 | https://github.com/IST-DASLab/FP-Quant |
 | [GGUF / GGML (llama.cpp)](../gguf) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1/8 | 🔴 | [See Notes](../gguf) | [See Notes](../gguf) | https://github.com/ggerganov/llama.cpp |
 | [GPTQModel](./gptq) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/ModelCloud/GPTQModel |
 | [AutoGPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -275,6 +275,7 @@
     "HqqConfig",
     "QuantoConfig",
     "QuarkConfig",
+    "FPQuantConfig",
     "SpQRConfig",
     "TorchAoConfig",
     "VptqConfig",
@@ -961,6 +962,7 @@
     EetqConfig,
     FbgemmFp8Config,
     FineGrainedFP8Config,
+    FPQuantConfig,
     GPTQConfig,
     HiggsConfig,
     HqqConfig,
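
As a quick sanity check (not part of the diff): the top-level export added above makes the new config importable directly from `transformers` once a build containing this commit is installed.

```python
# Assumes a transformers build that includes this commit.
from transformers import FPQuantConfig

config = FPQuantConfig()  # per the new doc page, only MXFP4 PTQ is currently supported
print(type(config).__name__)  # FPQuantConfig
```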
src/transformers/integrations/fp_quant.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"FP-Quant integration file"

from ..utils import (
    is_fp_quant_available,
)


if is_fp_quant_available():
    from fp_quant import FPQuantConfig as FPQuantLinearConfig
    from fp_quant import FPQuantDtype

from transformers.utils.quantization_config import FPQuantConfig


def adapt_fp_quant_config(config: FPQuantConfig):
    if config.forward_dtype == "mxfp4":
        forward_dtype = FPQuantDtype.MXFP4
    else:
        raise ValueError(f"Unsupported forward dtype: {config.forward_dtype}")

    if config.backward_dtype == "bf16":
        backward_dtype = FPQuantDtype.BF16
    else:
        raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}")

    return FPQuantLinearConfig(
        forward_dtype=forward_dtype,
        forward_method=config.forward_method,
        backward_dtype=backward_dtype,
        store_master_weights=config.store_master_weights,
        hadamard_group_size=config.hadamard_group_size,
        pseudoquantization=config.pseudoquantization,
        modules_to_not_convert=config.modules_to_not_convert,
    )
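
Illustrative only, and not part of the commit: a hedged sketch of how this adapter might be exercised. It assumes `fp_quant` is installed and that the Transformers-side `FPQuantConfig` defaults to `forward_dtype="mxfp4"` and `backward_dtype="bf16"` (the only values the adapter accepts); the `pseudoquantization` spelling follows this integration code rather than the `pseudoquant` shorthand used in the doc snippet.

```python
# Hedged sketch: translate the Transformers FPQuantConfig into the fp_quant
# library's linear-layer config. Assumes fp_quant is installed and that the
# default forward/backward dtypes are "mxfp4"/"bf16".
from transformers import FPQuantConfig
from transformers.integrations.fp_quant import adapt_fp_quant_config

hf_config = FPQuantConfig(pseudoquantization=True)  # Triton emulation, no QuTLASS required
linear_config = adapt_fp_quant_config(hf_config)

# linear_config is the fp_quant FPQuantConfig (aliased FPQuantLinearConfig above),
# carrying the dtypes, Hadamard group size, and modules_to_not_convert.
print(linear_config)
```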

src/transformers/quantizers/auto.py

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,7 @@
     EetqConfig,
     FbgemmFp8Config,
     FineGrainedFP8Config,
+    FPQuantConfig,
     GPTQConfig,
     HiggsConfig,
     HqqConfig,
@@ -49,6 +50,7 @@
 from .quantizer_eetq import EetqHfQuantizer
 from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
 from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
+from .quantizer_fp_quant import FPQuantHfQuantizer
 from .quantizer_gptq import GptqHfQuantizer
 from .quantizer_higgs import HiggsHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
@@ -67,6 +69,7 @@
     "aqlm": AqlmHfQuantizer,
     "quanto": QuantoHfQuantizer,
     "quark": QuarkHfQuantizer,
+    "fp_quant": FPQuantHfQuantizer,
     "eetq": EetqHfQuantizer,
     "higgs": HiggsHfQuantizer,
     "hqq": HqqHfQuantizer,
@@ -89,6 +92,7 @@
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
     "quark": QuarkConfig,
+    "fp_quant": FPQuantConfig,
     "hqq": HqqConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "fbgemm_fp8": FbgemmFp8Config,
src/transformers/quantizers/quantizer_fp_quant.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class FPQuantHfQuantizer(HfQuantizer):
    """
    Quantizer for the FP-Quant method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
    """

    requires_calibration = False
    requires_parameters_quantization = True
    is_qat_trainable = False
    required_packages = ["fp_quant"]

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, device_map, **kwargs):
        if not torch.cuda.is_available():
            raise NotImplementedError(
                "FPQuant quantization is only supported on GPU. Please use a different quantizer."
            )

        if not is_qutlass_available() and not self.quantization_config.pseudoquantization:
            raise ImportError(
                "Using `fp_quant` with real quantization requires a **Blackwell GPU** and qutlass: `git clone https://github.com/IST-DASLab/qutlass.git && cd qutlass && pip install --no-build-isolation .`. You can use `FPQuantConfig(pseudoquantization=True, ...)` to use Triton-based pseudo-quantization. It doesn't provide any speedups but emulates the quantization behavior of the real quantization."
            )

        if self.quantization_config.pseudoquantization:
            logger.warning(
                "Using pseudo-quantization for FP-Quant. This doesn't provide any speedups but emulates the quantization behavior of the real quantization."
            )

        if not is_fp_quant_available():
            raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")

        if device_map is None:
            raise ValueError(
                "You are attempting to load a FPQuant model without setting device_map."
                " Please set device_map comprised of 'cuda' devices."
            )
        elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
            raise ValueError(
                "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
                " This is not supported. Please remove the CPU or disk device from the device_map."
            )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.")
            torch_dtype = torch.bfloat16
        elif torch_dtype != torch.bfloat16:
            raise ValueError(
                f"Invalid `torch_dtype` {torch_dtype}. fp_quant quantization only supports `torch_dtype=torch.bfloat16`."
            )

        return torch_dtype

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: dict[str, Any],
        unexpected_keys: Optional[list[str]] = None,
    ):
        module, _ = get_module_from_name(model, param_name)

        # The module holds either:
        # * `weight` when `store_master_weights=True`
        # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
        # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`

        if param_name.endswith(".qweight"):
            # Loading a real quantized checkpoint without master weights
            module.qweight = torch.nn.Parameter(
                param_value.to(target_device),
                requires_grad=False,
            )
            module.weight = None
            module.dqweight = None
            return

        if param_name.endswith(".dqweight"):
            # Loading a pseudo-quantized checkpoint without master weights
            module.dqweight = torch.nn.Parameter(param_value.to(target_device))
            module.weight = None
            module.qweight = None
            module.scales = None
            return

        # Loading master weights or an unquantized checkpoint
        module.weight = torch.nn.Parameter(param_value.to(target_device))
        # Let pre-forward handle the quantization and set None where necessary
        module.pre_forward()

        if unexpected_keys is not None and param_name in unexpected_keys:
            unexpected_keys.remove(param_name)

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        from fp_quant import replace_with_fp_quant_linear

        from ..integrations.fp_quant import adapt_fp_quant_config

        replace_with_fp_quant_linear(
            model,
            fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        from fp_quant import FPQuantLinear

        fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}

        def should_exclude(key: str) -> bool:
            if key.endswith(".weight") or key.endswith(".bias"):
                return False
            full_key = f"{prefix}.{key}"
            return any(name in key or name in full_key for name in fp_quant_names)

        return [key for key in missing_keys if not should_exclude(key)]

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    def is_serializable(self, safe_serialization=None):
        return True

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: dict[str, Any],
        **kwargs,
    ) -> bool:
        from fp_quant import FPQuantLinear

        module, tensor_name = get_module_from_name(model, param_name)
        if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]:
            # Only quantize weights of FPQuantLinear modules that are not already quantized
            return True
        else:
            return False
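
To summarize the loading paths in `create_quantized_param` above (not part of the commit): the quantizer dispatches on the parameter-name suffix to pick between the three storage layouts an `FPQuantLinear` checkpoint can use. A self-contained sketch of that dispatch, with a hypothetical `FakeLinear` standing in for the real module:

```python
# Illustrative sketch of the checkpoint-layout dispatch performed by
# create_quantized_param. FakeLinear is a stand-in for fp_quant's
# FPQuantLinear; the real module also quantizes on the fly in pre_forward().
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeLinear:
    weight: Optional[torch.Tensor] = None    # master weights (store_master_weights=True)
    qweight: Optional[torch.Tensor] = None   # real MXFP4 checkpoint
    dqweight: Optional[torch.Tensor] = None  # pseudo-quantized checkpoint
    scales: Optional[torch.Tensor] = None


def load_param(module: FakeLinear, param_name: str, value: torch.Tensor) -> None:
    if param_name.endswith(".qweight"):
        # Real quantized checkpoint without master weights
        module.qweight, module.weight, module.dqweight = value, None, None
    elif param_name.endswith(".dqweight"):
        # Pseudo-quantized checkpoint without master weights
        module.dqweight, module.weight, module.qweight, module.scales = value, None, None, None
    else:
        # Master weights or an unquantized checkpoint
        module.weight = value


m = FakeLinear()
load_param(m, "model.layers.0.mlp.up_proj.qweight", torch.zeros(8, dtype=torch.uint8))
print(m.qweight is not None, m.weight)  # True None
```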
