
Commit fc52849

Update imports after QAT was moved out of prototype

Summary: pytorch/ao#1091 moved QAT out of prototype in torchao. This is a BC-breaking change, so torchtune also needs to update its QAT imports. Additionally, after pytorch/ao#987 we decided that QAT in torchao will use module swaps to insert fake quantize modules, so there is no longer any need for a separate module swap quantizer. This commit therefore removes the `*ModuleSwapQuantizer` option, keeping the old names as deprecated aliases.

Test Plan: `pytest -m integration_test tests/recipes/test_qat_distributed.py` should pass.

1 parent ca37c59 commit fc52849
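For downstream users, the practical effect is that the QAT quantizers now live under `torchao.quantization.qat` instead of `torchao.quantization.prototype.qat`. A minimal sketch of the intended flow, assuming torchao 0.7+ and its quantizer `prepare`/`convert` interface (the toy model and the elided fine-tuning step are placeholders, not part of this commit):

```python
import torch

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy model for illustration only; any model with nn.Linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(256, 256))

quantizer = Int8DynActInt4WeightQATQuantizer()
model = quantizer.prepare(model)  # module swap inserts fake quantization
# ... fine-tune here: forward passes now simulate int8/int4 quantization ...
model = quantizer.convert(model)  # swap in the actually quantized modules
```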

torchtune/training/quantization.py

Lines changed: 43 additions & 27 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from typing import Callable, Optional
+from warnings import warn
 
 from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API
 
@@ -18,22 +19,29 @@
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-from torchao.quantization.prototype.qat import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATQuantizer,
-)
-from torchao.quantization.prototype.qat._module_swap_api import (
-    disable_4w_fake_quant_module_swap,
-    disable_8da4w_fake_quant_module_swap,
-    enable_4w_fake_quant_module_swap,
-    enable_8da4w_fake_quant_module_swap,
-    Int4WeightOnlyQATQuantizerModuleSwap,
-    Int8DynActInt4WeightQATQuantizerModuleSwap,
-)
+
+try:
+    # torchao 0.7+
+    from torchao.quantization.qat import (
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
+    from torchao.quantization.qat.linear import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+    )
+except ImportError:
+    # torchao 0.6 and before
+    from torchao.quantization.prototype.qat import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
 
 
 __all__ = [
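The enable/disable helpers imported above operate on a single module, so callers toggle fake quantization across a whole model with `nn.Module.apply`, e.g. to evaluate without quantization noise mid-training. A hedged sketch under the same import guard (the toy model is illustrative and the surrounding training loop is omitted):

```python
import torch

try:  # torchao 0.7+
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
    from torchao.quantization.qat.linear import (
        disable_8da4w_fake_quant,
        enable_8da4w_fake_quant,
    )
except ImportError:  # torchao 0.6 and before
    from torchao.quantization.prototype.qat import (
        disable_8da4w_fake_quant,
        enable_8da4w_fake_quant,
        Int8DynActInt4WeightQATQuantizer,
    )

# Toy model for illustration only.
model = Int8DynActInt4WeightQATQuantizer().prepare(
    torch.nn.Sequential(torch.nn.Linear(256, 256))
)
model.apply(disable_8da4w_fake_quant)  # e.g. run evaluation without fake quant
model.apply(enable_8da4w_fake_quant)  # re-enable before resuming training
```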
@@ -52,9 +60,9 @@
 _quantizer_mode_to_enable_fake_quant = {}
 
 
-# ========================================================
-# int8 dynamic activations + int4 weight tensor subclass |
-# ========================================================
+# ========================================
+# int8 dynamic activations + int4 weight |
+# ========================================
 
 
 class Int8DynActInt4WeightQuantizer:
@@ -106,15 +114,15 @@ def quantize(self, model):
 _quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
 
 
-# =============
-# module swap |
-# =============
+# ====================== #
+# Backward compatibility #
+# ====================== #
 
-# Note: QAT tensor subclass implementation in torchao only works
-# with FSDP2 today. For other distribution strategies like DDP and
-# FSDP1, users will need to fall back to the old module swap flow.
 
 # int4 weight-only
+Int4WeightOnlyQATQuantizerModuleSwap = Int4WeightOnlyQATQuantizer
+disable_4w_fake_quant_module_swap = disable_4w_fake_quant
+enable_4w_fake_quant_module_swap = enable_4w_fake_quant
 _quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "4w-qat-module-swap"
@@ -124,6 +132,9 @@ def quantize(self, model):
 ] = enable_4w_fake_quant_module_swap
 
 # int8 dynamic activations + int4 weight
+Int8DynActInt4WeightQATQuantizerModuleSwap = Int8DynActInt4WeightQATQuantizer
+disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
+enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"
@@ -142,15 +153,20 @@ def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
     Currently supported:
 
     - :class:`~torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer`: "8da4w" (requires ``torch>=2.3.0``)
-    - :class:`~torchao.quantization.prototype.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)
+    - :class:`~torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)
 
     Args:
         quantizer (Optional[Callable]): A callable object that implements the `quantize` method.
 
     Returns:
         Optional[str]: The quantization mode.
     """
-    return _quantizer_to_mode.get(type(quantizer), None)
+    mode = _quantizer_to_mode.get(type(quantizer), None)
+    if mode is not None and "module-swap" in mode:
+        warn(
+            "*QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead"
+        )
+    return mode
 
 
 def _get_disable_fake_quant(quantizer_mode: str) -> Callable:
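With this change, `get_quantizer_mode` still resolves the deprecated module-swap modes but now warns. An illustrative sketch of how the warning would surface (the mode string and message are taken from the diff above; the no-argument constructor relies on the quantizer's defaults):

```python
import warnings

from torchtune.training.quantization import (
    get_quantizer_mode,
    Int8DynActInt4WeightQATQuantizerModuleSwap,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mode = get_quantizer_mode(Int8DynActInt4WeightQATQuantizerModuleSwap())

assert mode == "8da4w-qat-module-swap"
assert any("deprecated" in str(w.message) for w in caught)
```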
