# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional
+from warnings import warn

from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API

    int8_dynamic_activation_int4_weight,
    quantize_,
)
-from torchao.quantization.prototype.qat import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATQuantizer,
-)
-from torchao.quantization.prototype.qat._module_swap_api import (
-    disable_4w_fake_quant_module_swap,
-    disable_8da4w_fake_quant_module_swap,
-    enable_4w_fake_quant_module_swap,
-    enable_8da4w_fake_quant_module_swap,
-    Int4WeightOnlyQATQuantizerModuleSwap,
-    Int8DynActInt4WeightQATQuantizerModuleSwap,
-)
+
+try:
+    # torchao 0.7+
+    from torchao.quantization.qat import (
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
+    from torchao.quantization.qat.linear import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+    )
+except ImportError:
+    # torchao 0.6 and before
+    from torchao.quantization.prototype.qat import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
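# A hedged usage sketch, assuming only the names imported above and torchao's
# documented prepare/convert QAT flow (groupsize=256 is an illustrative choice,
# not something fixed by this change): the try/except keeps one stable import
# surface, so callers use the QAT quantizers the same way on torchao 0.6 and 0.7+.
#
#     quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
#     model = quantizer.prepare(model)   # insert fake quantization for QAT
#     # ... fine-tune as usual ...
#     model = quantizer.convert(model)   # swap in real quantized ops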


__all__ = [
_quantizer_mode_to_enable_fake_quant = {}


-# ========================================================
-# int8 dynamic activations + int4 weight tensor subclass |
-# ========================================================
+# ========================================
+# int8 dynamic activations + int4 weight |
+# ========================================


class Int8DynActInt4WeightQuantizer:
@@ -106,15 +114,15 @@ def quantize(self, model):
_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant


-# =============
-# module swap |
-# =============
+# ====================== #
+# Backward compatibility #
+# ====================== #

-# Note: QAT tensor subclass implementation in torchao only works
-# with FSDP2 today. For other distribution strategies like DDP and
-# FSDP1, users will need to fall back to the old module swap flow.

# int4 weight-only
+Int4WeightOnlyQATQuantizerModuleSwap = Int4WeightOnlyQATQuantizer
+disable_4w_fake_quant_module_swap = disable_4w_fake_quant
+enable_4w_fake_quant_module_swap = enable_4w_fake_quant
_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
_quantizer_mode_to_disable_fake_quant[
    "4w-qat-module-swap"
@@ -124,6 +132,9 @@ def quantize(self, model):
] = enable_4w_fake_quant_module_swap
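# The assignments above keep the old module-swap names importable as plain
# aliases of the current quantizers, so existing configs and imports continue
# to work; the 8da4w block below follows the same pattern. For example:
#
#     assert Int4WeightOnlyQATQuantizerModuleSwap is Int4WeightOnlyQATQuantizer
#     assert enable_4w_fake_quant_module_swap is enable_4w_fake_quant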

# int8 dynamic activations + int4 weight
+Int8DynActInt4WeightQATQuantizerModuleSwap = Int8DynActInt4WeightQATQuantizer
+disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
+enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
_quantizer_mode_to_disable_fake_quant[
    "8da4w-qat-module-swap"
@@ -142,15 +153,20 @@ def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
    Currently supported:

    - :class:`~torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer`: "8da4w" (requires ``torch>=2.3.0``)
-    - :class:`~torchao.quantization.prototype.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)
+    - :class:`~torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)

    Args:
        quantizer (Optional[Callable]): A callable object that implements the `quantize` method.

    Returns:
        Optional[str]: The quantization mode.
    """
-    return _quantizer_to_mode.get(type(quantizer), None)
+    mode = _quantizer_to_mode.get(type(quantizer), None)
+    if mode is not None and "module-swap" in mode:
+        warn(
+            "*QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead"
+        )
+    return mode
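# A minimal usage sketch for get_quantizer_mode, based on the mappings and
# docstring above (constructor defaults assumed):
#
#     get_quantizer_mode(Int8DynActInt4WeightQuantizer())  # -> "8da4w"
#     get_quantizer_mode(None)                              # -> None
#     # The *ModuleSwap aliases still resolve to their "-module-swap" modes,
#     # but looking them up now emits the deprecation warning above.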


def _get_disable_fake_quant(quantizer_mode: str) -> Callable: