Commit 13830e0

[MX] By default skip MX quantization on the output layer (meta-pytorch#1208)
1 parent 49d5cab commit 13830e0

1 file changed (+4, -3)

torchtitan/config_manager.py

Lines changed: 4 additions & 3 deletions
@@ -459,11 +459,12 @@ class MX:
     recipe_name: Literal["mxfp8"] = "mxfp8"
     """If specified, creates float8 config from recipe name"""
 
-    filter_fqns: list[str] = field(default_factory=list)
+    filter_fqns: list[str] = field(default_factory=lambda: ["output"])
     """
     Comma-separated list of fully qualified names of modules to skip applying mxfloat8 training to.
-    nn.Linear modules with any dim size not divisible by 16 are always skipped due to hardware requirements.
-    Example: --MXFloat8.filter_fqns "attention.wq,attention.wk,attention.wv,output"
+    nn.Linear modules with any dim size not divisible by 16 are also always skipped due to hardware requirements.
+    By default we always skip the output layer.
+    Example: --mx.filter_fqns "attention.wq,attention.wk,attention.wv,output"
     """
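For context, here is a minimal sketch of how an FQN-based filter like filter_fqns could be consulted when deciding which nn.Linear modules get MX (mxfp8) training. The helper should_convert_to_mx and the MX_BLOCK_SIZE constant are illustrative assumptions, not torchtitan's or torchao's actual conversion code; only the two skip rules come from the docstring above.

import torch.nn as nn

MX_BLOCK_SIZE = 16  # assumed block size, per the "divisible by 16" note in the docstring


def should_convert_to_mx(fqn: str, module: nn.Module, filter_fqns: list[str]) -> bool:
    """Return True if the module registered under `fqn` should be converted to mxfp8."""
    if not isinstance(module, nn.Linear):
        return False
    # Skip any module whose fully qualified name matches a configured filter, e.g. "output".
    if any(filtered and filtered in fqn for filtered in filter_fqns):
        return False
    # Skip linears with any dim not divisible by the block size (stated hardware requirement).
    if module.in_features % MX_BLOCK_SIZE or module.out_features % MX_BLOCK_SIZE:
        return False
    return True


# With the new default filter_fqns=["output"], the output projection is skipped,
# while attention projections with 16-divisible dims are still converted.
wq = nn.Linear(4096, 4096, bias=False)
print(should_convert_to_mx("layers.0.attention.wq", wq, ["output"]))  # True
print(should_convert_to_mx("output", wq, ["output"]))                 # False (filtered by FQN)
print(should_convert_to_mx("layers.0.mlp.w1", nn.Linear(4096, 100), ["output"]))  # False (100 % 16 != 0)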