
Commit a899da2

Add int4 weight-only QAT flow targeting tinygemm kernel (#1570)
1 parent b846407 commit a899da2

File tree

1 file changed (+58, -5 lines)

torchtune/training/quantization.py

Lines changed: 58 additions & 5 deletions
@@ -6,23 +6,38 @@
 
 from typing import Callable, Optional
 
-from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.dtypes import TensorCoreTiledLayoutType
+from torchao.quantization import (
+    int4_weight_only,
+    int8_dynamic_activation_int4_weight,
+    quantize_,
+)
 from torchao.quantization.prototype.qat import (
+    disable_4w_fake_quant,
     disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
     enable_8da4w_fake_quant,
+    Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
 from torchao.quantization.prototype.qat._module_swap_api import (
+    disable_4w_fake_quant_module_swap,
     disable_8da4w_fake_quant_module_swap,
+    enable_4w_fake_quant_module_swap,
     enable_8da4w_fake_quant_module_swap,
+    Int4WeightOnlyQATQuantizerModuleSwap,
     Int8DynActInt4WeightQATQuantizerModuleSwap,
 )
 
 
 __all__ = [
     "get_quantizer_mode",
+    "Int4WeightOnlyQuantizer",
+    "Int4WeightOnlyQATQuantizer",
+    "Int4WeightOnlyQATQuantizerModuleSwap",
     "Int8DynActInt4WeightQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "Int8DynActInt4WeightQATQuantizerModuleSwap",
 ]
 
 
@@ -57,14 +72,52 @@ def quantize(self, model):
 _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
 
 
-# ====================================================
-# int8 dynamic activations + int4 weight module swap |
-# ====================================================
+# ==================
+# int4 weight only |
+# ==================
+
+
+class Int4WeightOnlyQuantizer:
+    """
+    Quantizer for applying int4 per group weight only quantization
+    to linear layers in the model using the efficient tinygemm kernel.
+    """
+
+    def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+
+    def quantize(self, model):
+        layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles)
+        quantize_fn = int4_weight_only(self.groupsize, layout_type)
+        quantize_(model, quantize_fn)
+        return model
+
+
+_quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
+_quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
+_quantizer_mode_to_disable_fake_quant["4w-qat"] = disable_4w_fake_quant
+_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
+
+
+# =============
+# module swap |
+# =============
 
 # Note: QAT tensor subclass implementation in torchao only works
 # with FSDP2 today. For other distribution strategies like DDP and
 # FSDP1, users will need to fall back to the old module swap flow.
-__all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
+
+# int4 weight-only
+_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
+_quantizer_mode_to_disable_fake_quant[
+    "4w-qat-module-swap"
+] = disable_4w_fake_quant_module_swap
+_quantizer_mode_to_enable_fake_quant[
+    "4w-qat-module-swap"
+] = enable_4w_fake_quant_module_swap
+
+# int8 dynamic activations + int4 weight
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"

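For orientation, here is a minimal sketch of how the new post-training path added by this commit might be invoked. The toy two-layer model, its dimensions, and the device/dtype setup are illustrative assumptions rather than part of the commit; the tinygemm kernel generally expects bfloat16 weights on CUDA, and linear in_features must be divisible by the group size.

import torch
from torchtune.training.quantization import Int4WeightOnlyQuantizer

# Hypothetical stand-in model: any model whose nn.Linear layers have
# in_features divisible by groupsize (and by inner_k_tiles * 16).
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.Linear(512, 512),
).to(device="cuda", dtype=torch.bfloat16)

# Post-training int4 weight-only quantization through the tinygemm
# layout, the same path Int4WeightOnlyQuantizer.quantize takes above.
quantizer = Int4WeightOnlyQuantizer(groupsize=128, inner_k_tiles=8)
model = quantizer.quantize(model)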
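A companion sketch of the QAT flow, assuming torchao's usual prepare/convert two-step quantizer API and the model.apply(...) pattern that torchtune's QAT recipe uses for the existing 8da4w toggles; the fine-tuning loop itself is elided.

import torch
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    Int4WeightOnlyQATQuantizer,
)

# Same hypothetical stand-in model as above.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.Linear(512, 512),
).to(device="cuda", dtype=torch.bfloat16)

# Step 1: insert fake-quantized linears so training sees int4 rounding error.
qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=128, inner_k_tiles=8)
model = qat_quantizer.prepare(model)

# ... fine-tune as usual; fake quantization can be toggled mid-run with
# model.apply(enable_4w_fake_quant) / model.apply(disable_4w_fake_quant),
# e.g. to evaluate without fake quant, mirroring the 8da4w recipe ...
model.apply(disable_4w_fake_quant)

# Step 2: convert the fake-quantized model into real int4 tinygemm weights.
model = qat_quantizer.convert(model)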
Comments (0)