
Commit 1bbd749

Update imports after QAT was moved out of prototype (#1883)

1 parent 894fdb8


torchtune/training/quantization.py

Lines changed: 46 additions & 28 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from typing import Callable, Optional
+from warnings import warn
 
 try:
     # torchao 0.7+
@@ -18,22 +19,29 @@
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-from torchao.quantization.prototype.qat import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATQuantizer,
-)
-from torchao.quantization.prototype.qat._module_swap_api import (
-    disable_4w_fake_quant_module_swap,
-    disable_8da4w_fake_quant_module_swap,
-    enable_4w_fake_quant_module_swap,
-    enable_8da4w_fake_quant_module_swap,
-    Int4WeightOnlyQATQuantizerModuleSwap,
-    Int8DynActInt4WeightQATQuantizerModuleSwap,
-)
+
+try:
+    # torchao 0.7+
+    from torchao.quantization.qat import (
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
+    from torchao.quantization.qat.linear import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+    )
+except ImportError:
+    # torchao 0.6 and before
+    from torchao.quantization.prototype.qat import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
 
 
 __all__ = [
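
For reference, a minimal sketch of the version-gating idiom this hunk adopts; the import paths are the ones shown in the diff, while the surrounding usage is illustrative only:

# Try the torchao 0.7+ location first and fall back to the pre-0.7
# prototype path, so one torchtune build supports both torchao versions.
try:
    # torchao 0.7+: QAT has graduated out of prototype
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    # torchao 0.6 and before
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Either way, callers construct the quantizer under a single name.
quantizer = Int8DynActInt4WeightQATQuantizer()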
@@ -52,9 +60,9 @@
 _quantizer_mode_to_enable_fake_quant = {}
 
 
-# ========================================================
-# int8 dynamic activations + int4 weight tensor subclass |
-# ========================================================
+# ========================================
+# int8 dynamic activations + int4 weight |
+# ========================================
 
 
 class Int8DynActInt4WeightQuantizer:
@@ -106,15 +114,15 @@ def quantize(self, model):
 _quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
 
 
-# =============
-# module swap |
-# =============
+# ====================== #
+# Backward compatibility #
+# ====================== #
 
-# Note: QAT tensor subclass implementation in torchao only works
-# with FSDP2 today. For other distribution strategies like DDP and
-# FSDP1, users will need to fall back to the old module swap flow.
 
 # int4 weight-only
+Int4WeightOnlyQATQuantizerModuleSwap = Int4WeightOnlyQATQuantizer
+disable_4w_fake_quant_module_swap = disable_4w_fake_quant
+enable_4w_fake_quant_module_swap = enable_4w_fake_quant
 _quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "4w-qat-module-swap"
@@ -124,6 +132,9 @@ def quantize(self, model):
 ] = enable_4w_fake_quant_module_swap
 
 # int8 dynamic activations + int4 weight
+Int8DynActInt4WeightQATQuantizerModuleSwap = Int8DynActInt4WeightQATQuantizer
+disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
+enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"
@@ -141,16 +152,23 @@ def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
 
     Currently supported:
 
-    - :class:`~torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer`: "8da4w" (requires ``torch>=2.3.0``)
-    - :class:`~torchao.quantization.prototype.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)
+    - :class:`~torchtune.training.quantization.Int8DynActInt4WeightQuantizer`: "8da4w"
+    - :class:`~torchtune.training.quantization.Int4WeightOnlyQuantizer`: "4w"
+    - :class:`~torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat"
+    - :class:`~torchao.quantization.qat.Int4WeightOnlyQATQuantizer`: "4w-qat"
 
     Args:
         quantizer (Optional[Callable]): A callable object that implements the `quantize` method.
 
     Returns:
         Optional[str]: The quantization mode.
     """
-    return _quantizer_to_mode.get(type(quantizer), None)
+    mode = _quantizer_to_mode.get(type(quantizer), None)
+    if mode is not None and "module-swap" in mode:
+        warn(
+            "*QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead"
+        )
+    return mode
 
 
 def _get_disable_fake_quant(quantizer_mode: str) -> Callable:
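
A self-contained sketch of the lookup-then-warn behavior added to get_quantizer_mode; the quantizer class and mode table here are fabricated stand-ins, not torchtune's:

from typing import Callable, Optional
from warnings import warn

class FakeModuleSwapQuantizer:
    """Hypothetical quantizer registered under a deprecated mode string."""
    def quantize(self, model):
        return model

_quantizer_to_mode = {FakeModuleSwapQuantizer: "8da4w-qat-module-swap"}

def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
    mode = _quantizer_to_mode.get(type(quantizer), None)
    if mode is not None and "module-swap" in mode:
        warn(
            "*QuantizerModuleSwap is deprecated. "
            "Please use the version without 'ModuleSwap' instead"
        )
    return mode

print(get_quantizer_mode(FakeModuleSwapQuantizer()))  # "8da4w-qat-module-swap", plus a UserWarning
print(get_quantizer_mode(None))                       # None, and no warning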
