@@ -6,23 +6,38 @@
 
 from typing import Callable, Optional
 
-from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.dtypes import TensorCoreTiledLayoutType
+from torchao.quantization import (
+    int4_weight_only,
+    int8_dynamic_activation_int4_weight,
+    quantize_,
+)
 from torchao.quantization.prototype.qat import (
+    disable_4w_fake_quant,
     disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
     enable_8da4w_fake_quant,
+    Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
 from torchao.quantization.prototype.qat._module_swap_api import (
+    disable_4w_fake_quant_module_swap,
     disable_8da4w_fake_quant_module_swap,
+    enable_4w_fake_quant_module_swap,
     enable_8da4w_fake_quant_module_swap,
+    Int4WeightOnlyQATQuantizerModuleSwap,
     Int8DynActInt4WeightQATQuantizerModuleSwap,
 )
 
 
 __all__ = [
     "get_quantizer_mode",
+    "Int4WeightOnlyQuantizer",
+    "Int4WeightOnlyQATQuantizer",
+    "Int4WeightOnlyQATQuantizerModuleSwap",
     "Int8DynActInt4WeightQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "Int8DynActInt4WeightQATQuantizerModuleSwap",
 ]
 
 
@@ -57,14 +72,52 @@ def quantize(self, model): |
 _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
 
 
-# ====================================================
-# int8 dynamic activations + int4 weight module swap |
-# ====================================================
+# ==================
+# int4 weight only |
+# ==================
+
+
+class Int4WeightOnlyQuantizer:
+    """
+    Quantizer for applying int4 per-group, weight-only quantization
+    to linear layers in the model using the efficient tinygemm kernel.
+    """
+
+    def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+
+    def quantize(self, model):
+        layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles)
+        quantize_fn = int4_weight_only(self.groupsize, layout_type)
+        quantize_(model, quantize_fn)
+        return model
+
+
+_quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
+_quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
+_quantizer_mode_to_disable_fake_quant["4w-qat"] = disable_4w_fake_quant
+_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
+
+
+# =============
+# module swap |
+# =============
 
 # Note: QAT tensor subclass implementation in torchao only works
 # with FSDP2 today. For other distribution strategies like DDP and
 # FSDP1, users will need to fall back to the old module swap flow.
-__all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
+
+# int4 weight-only
+_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
+_quantizer_mode_to_disable_fake_quant[
+    "4w-qat-module-swap"
+] = disable_4w_fake_quant_module_swap
+_quantizer_mode_to_enable_fake_quant[
+    "4w-qat-module-swap"
+] = enable_4w_fake_quant_module_swap
+
+# int8 dynamic activations + int4 weight
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"
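
For context, a minimal usage sketch of the new int4 weight-only PTQ path. The toy model and the CUDA/bf16 setup are assumptions for illustration (torchao's tinygemm kernel generally expects bf16 weights on GPU), and get_quantizer_mode's exact signature is assumed from its presence in __all__; only Int4WeightOnlyQuantizer and its arguments come from this patch:

import torch
import torch.nn as nn
# Assumes Int4WeightOnlyQuantizer and get_quantizer_mode are imported
# from the module this diff edits.

# Hypothetical stand-in model; any module containing nn.Linear layers works.
# tinygemm int4 kernels expect bf16 weights on CUDA (assumption noted above).
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
model = model.to(device="cuda", dtype=torch.bfloat16)

# From the patch: wraps torchao's int4_weight_only(...) with the
# tensor-core tiled layout and applies it in place via quantize_(...).
quantizer = Int4WeightOnlyQuantizer(groupsize=128, inner_k_tiles=8)
model = quantizer.quantize(model)

# The registry added above maps this quantizer type to the "4w" mode string,
# presumably what get_quantizer_mode(quantizer) reports.
print(get_quantizer_mode(quantizer))  # expected: "4w"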
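
Similarly, a sketch of the QAT flow the new "4w-qat" hooks support, per the FSDP2 note in the hunk above. The prepare/convert method names follow torchao's two-step QAT quantizer API, the constructor argument is assumed to mirror the PTQ quantizer, and the training steps are placeholders:

# Tensor-subclass QAT flow; per the note above, this only works with FSDP2.
# Under DDP or FSDP1, fall back to Int4WeightOnlyQATQuantizerModuleSwap and
# the *_4w_fake_quant_module_swap variants ("4w-qat-module-swap" mode).
qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)
model = qat_quantizer.prepare(model)  # attach int4 fake quantization

model.apply(disable_4w_fake_quant)    # e.g. warm up without fake quant
# ... run initial fine-tuning steps ...
model.apply(enable_4w_fake_quant)     # then train quantization-aware
# ... run remaining fine-tuning steps ...

model = qat_quantizer.convert(model)  # materialize real int4 weights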