Skip to content

Commit 7f80e6f

Browse files
committed
Update logic to disable quantizers
Signed-off-by: ajrasane <[email protected]>
1 parent e4b374a commit 7f80e6f

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

examples/llm_ptq/example_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,21 @@ def build_quant_cfg(
174174
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
175175
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
176176

177+
# Qwen3 specific quantizer disabling patterns (thinker.model.layers only)
178+
if "qkv_disabled" in qformat:
179+
quant_cfg = copy.deepcopy(quant_cfg) # Don't modify global config
180+
for proj in ["q_proj", "k_proj", "v_proj"]:
181+
quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
182+
"enable": False
183+
}
184+
if "qkvo_disabled" in qformat:
185+
if "qkv_disabled" not in qformat: # Avoid double deepcopy
186+
quant_cfg = copy.deepcopy(quant_cfg)
187+
for proj in ["o_proj"]:
188+
quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
189+
"enable": False
190+
}
191+
177192
return quant_cfg
178193

179194

examples/llm_ptq/hf_ptq.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import argparse
17+
import contextlib
1718
import os
1819
import random
1920
import time
@@ -298,12 +299,8 @@ def main(args):
298299
use_seq_device_map=args.use_seq_device_map,
299300
attn_implementation=args.attn_implementation,
300301
)
301-
else:
302-
assert args.qformat in QUANT_CFG_CHOICES, (
303-
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
304-
)
305-
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
306302

303+
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
307304
# Qwen3 specific quantizer disabling patterns (thinker.model.layers only)
308305
if "qkv_disabled" in args.qformat:
309306
# Disable q_proj, k_proj, v_proj quantizers
@@ -325,6 +322,11 @@ def main(args):
325322
quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
326323
for i in range(total_layers - n_layers_to_disable, total_layers):
327324
quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
325+
else:
326+
assert args.qformat in QUANT_CFG_CHOICES, (
327+
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
328+
)
329+
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
328330

329331
if args.kv_cache_qformat != "none":
330332
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
@@ -357,6 +359,8 @@ def main(args):
357359
# since parameters are distributed. Force cuda:0 for input tensors.
358360
if device is None or str(device) in ("meta", "cpu"):
359361
device = "cuda"
362+
print(f"Overriding device to {device}")
363+
360364
processor = None
361365
tokenizer = None
362366

@@ -646,11 +650,6 @@ def main(args):
646650
print("Updating full_model with quantized language_model...")
647651
language_model_lineage[-2].language_model = model
648652

649-
# if args.verbose:
650-
# mtq.print_quant_summary(full_model)
651-
652-
import contextlib
653-
654653
if args.verbose:
655654
with open("./quant_summary.txt", "w") as f, contextlib.redirect_stdout(f):
656655
mtq.print_quant_summary(full_model)

0 commit comments

Comments
 (0)