|
| 1 | +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from typing import TYPE_CHECKING, Any, Optional |
| 15 | + |
| 16 | +from .base import HfQuantizer |
| 17 | +from .quantizers_utils import get_module_from_name |
| 18 | + |
| 19 | + |
| 20 | +if TYPE_CHECKING: |
| 21 | + from ..modeling_utils import PreTrainedModel |
| 22 | + |
| 23 | +from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging |
| 24 | +from ..utils.quantization_config import QuantizationConfigMixin |
| 25 | + |
| 26 | + |
| 27 | +if is_torch_available(): |
| 28 | + import torch |
| 29 | + |
| 30 | +logger = logging.get_logger(__name__) |
| 31 | + |
| 32 | + |
| 33 | +class FPQuantHfQuantizer(HfQuantizer): |
| 34 | + """ |
| 35 | + Quantizer for the FP-Quant method. Enables the loading of prequantized models and in-flight quantization of full-precision models. |
| 36 | + """ |
| 37 | + |
| 38 | + requires_calibration = False |
| 39 | + requires_parameters_quantization = True |
| 40 | + is_qat_trainable = False |
| 41 | + required_packages = ["fp_quant"] |
| 42 | + |
| 43 | + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): |
| 44 | + super().__init__(quantization_config, **kwargs) |
| 45 | + self.quantization_config = quantization_config |
| 46 | + |
| 47 | + def validate_environment(self, device_map, **kwargs): |
| 48 | + if not torch.cuda.is_available(): |
| 49 | + raise NotImplementedError( |
| 50 | + "FPQuant quantization is only supported on GPU. Please use a different quantizer." |
| 51 | + ) |
| 52 | + |
| 53 | + if not is_qutlass_available() and not self.quantization_config.pseudoquantization: |
| 54 | + raise ImportError( |
| 55 | + "Using `fp_quant` with real quantization requires a **Blackwell GPU** and qutlass: `git clone https://github.com/IST-DASLab/qutlass.git && cd qutlass && pip install --no-build-isolation .`. You can use `FPQuantConfig(pseudoquantization=True, ...)` to use Triton-based pseudo-quantization. It doesn't provide any speedups but emulates the quantization behavior of the real quantization." |
| 56 | + ) |
| 57 | + |
| 58 | + if self.quantization_config.pseudoquantization: |
| 59 | + logger.warning( |
| 60 | + "Using pseudo-quantization for FP-Quant. This doesn't provide any speedups but emulates the quantization behavior of the real quantization." |
| 61 | + ) |
| 62 | + |
| 63 | + if not is_fp_quant_available(): |
| 64 | + raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`") |
| 65 | + |
| 66 | + if device_map is None: |
| 67 | + raise ValueError( |
| 68 | + "You are attempting to load a FPQuant model without setting device_map." |
| 69 | + " Please set device_map comprised of 'cuda' devices." |
| 70 | + ) |
| 71 | + elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()): |
| 72 | + raise ValueError( |
| 73 | + "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device." |
| 74 | + " This is not supported. Please remove the CPU or disk device from the device_map." |
| 75 | + ) |
| 76 | + |
| 77 | + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": |
| 78 | + if torch_dtype is None: |
| 79 | + logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.bfloat16` for qutlass compatibility.") |
| 80 | + torch_dtype = torch.bfloat16 |
| 81 | + elif torch_dtype != torch.bfloat16: |
| 82 | + raise ValueError( |
| 83 | + f"Invalid `torch_dtype` {torch_dtype}. fp_quant quantization only supports `torch_dtype=torch.bfloat16`." |
| 84 | + ) |
| 85 | + |
| 86 | + return torch_dtype |
| 87 | + |
| 88 | + def create_quantized_param( |
| 89 | + self, |
| 90 | + model: "PreTrainedModel", |
| 91 | + param_value: "torch.Tensor", |
| 92 | + param_name: str, |
| 93 | + target_device: "torch.device", |
| 94 | + state_dict: dict[str, Any], |
| 95 | + unexpected_keys: Optional[list[str]] = None, |
| 96 | + ): |
| 97 | + module, _ = get_module_from_name(model, param_name) |
| 98 | + |
| 99 | + # The module holds either: |
| 100 | + # * `weight` when `store_master_weights=True` |
| 101 | + # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False` |
| 102 | + # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True` |
| 103 | + |
| 104 | + if param_name.endswith(".qweight"): |
| 105 | + # Loading a real quantized checkpoint without master weights |
| 106 | + module.qweight = torch.nn.Parameter( |
| 107 | + param_value.to(target_device), |
| 108 | + requires_grad=False, |
| 109 | + ) |
| 110 | + module.weight = None |
| 111 | + module.dqweight = None |
| 112 | + return |
| 113 | + |
| 114 | + if param_name.endswith(".dqweight"): |
| 115 | + # Loading a pseudo-quantized checkpoint without master weights |
| 116 | + module.dqweight = torch.nn.Parameter(param_value.to(target_device)) |
| 117 | + module.weight = None |
| 118 | + module.qweight = None |
| 119 | + module.scales = None |
| 120 | + return |
| 121 | + |
| 122 | + # Loading master weights or an unquantized checkpoint |
| 123 | + module.weight = torch.nn.Parameter(param_value.to(target_device)) |
| 124 | + # Let pre-forward handle the quantization and set None where necessary |
| 125 | + module.pre_forward() |
| 126 | + |
| 127 | + if unexpected_keys is not None and param_name in unexpected_keys: |
| 128 | + unexpected_keys.remove(param_name) |
| 129 | + |
| 130 | + def _process_model_before_weight_loading( |
| 131 | + self, |
| 132 | + model: "PreTrainedModel", |
| 133 | + **kwargs, |
| 134 | + ): |
| 135 | + from fp_quant import replace_with_fp_quant_linear |
| 136 | + |
| 137 | + from ..integrations.fp_quant import adapt_fp_quant_config |
| 138 | + |
| 139 | + replace_with_fp_quant_linear( |
| 140 | + model, |
| 141 | + fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config), |
| 142 | + ) |
| 143 | + model.config.quantization_config = self.quantization_config |
| 144 | + |
| 145 | + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): |
| 146 | + return model |
| 147 | + |
| 148 | + def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]: |
| 149 | + from fp_quant import FPQuantLinear |
| 150 | + |
| 151 | + fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)} |
| 152 | + |
| 153 | + def should_exclude(key: str) -> bool: |
| 154 | + if key.endswith(".weight") or key.endswith(".bias"): |
| 155 | + return False |
| 156 | + full_key = f"{prefix}.{key}" |
| 157 | + return any(name in key or name in full_key for name in fp_quant_names) |
| 158 | + |
| 159 | + return [key for key in missing_keys if not should_exclude(key)] |
| 160 | + |
| 161 | + @property |
| 162 | + def is_trainable(self, model: Optional["PreTrainedModel"] = None): |
| 163 | + return False |
| 164 | + |
| 165 | + def is_serializable(self, safe_serialization=None): |
| 166 | + return True |
| 167 | + |
| 168 | + def check_quantized_param( |
| 169 | + self, |
| 170 | + model: "PreTrainedModel", |
| 171 | + param_value: "torch.Tensor", |
| 172 | + param_name: str, |
| 173 | + state_dict: dict[str, Any], |
| 174 | + **kwargs, |
| 175 | + ) -> bool: |
| 176 | + from fp_quant import FPQuantLinear |
| 177 | + |
| 178 | + module, tensor_name = get_module_from_name(model, param_name) |
| 179 | + if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]: |
| 180 | + # Only quantize weights of FPQuantLinear modules that are not already quantized |
| 181 | + return True |
| 182 | + else: |
| 183 | + return False |
0 commit comments