Quantization-Aware Training via FP-Quant and QuTLASS #1
This PR serves as a demo of how to integrate Quantization-Aware Training (QAT) into an end-to-end training and evaluation pipeline.
Extra Dependencies
It adds the following dependencies on top to enable QAT:

- QuTLASS, whose installation requires `CUDA_HOME`.
- A `transformers` fork with FP-Quant backward support, plus `trl` and `liger_kernel` (for the `Llama-3.1-8B-Instruct` example). Run with `uv pip install git+https://github.com/BlackSamorez/transformers@fpquant_backwards trl liger_kernel && python -m torch.distributed.launch --nproc_per_node=8 transformers_distill.py` from the existing `uv` env.

Changes Summary
Quantization is configured in the newly-created `fpquant.py` file, where the following configuration is applied to the model:
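A sketch of what this configuration looks like, assuming the `FPQuantConfig`/`FPQuantDtype` API from the `fp_quant` package; apart from `store_master_weights`, which is referenced later in this description, the field names are assumptions and may differ from the actual `fpquant.py`:

```python
# Sketch only: assumes fp_quant's FPQuantConfig / FPQuantDtype API.
# Field names other than store_master_weights are best-effort guesses.
from fp_quant import FPQuantConfig, FPQuantDtype

qat_config = FPQuantConfig(
    forward_dtype=FPQuantDtype.MXFP4,   # weights/activations quantized to MXFP4 on the forward pass
    backward_dtype=FPQuantDtype.MXFP8,  # MXFP8 GEMMs on the backward pass
    hadamard_group_size=128,            # block-Hadamard transform over groups of 128 values
    store_master_weights=True,          # keep master weights (stored and updated in FP32) during training
)
```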
This configuration specifies that the model weights are trained in `MXFP4` with block-Hadamard transforms of size 128, which improve training stability. For better convergence, the master weights are still stored and updated in FP32. On the backward pass, `MXFP8` GEMMs are used to speed up training.

QAT Scheme
(Scheme inspired by "Pretraining Large Language Models with NVFP4")
On the forward pass, both weights and activations are normalized via the Hadamard transform in groups of 128 values and quantized to MXFP4 via the `fusedQuantizeMx` kernel call. The quantization is performed along the inner GEMM dimension (as required by MXFP GEMM tensor cores). The quantized weights and activations are then saved for the backward pass.

On the backward pass, weights and activations have to be re-quantized along a different dimension to allow for MXFP8 backward GEMMs. The dequantization from MXFP4, transposition, and re-quantization to MXFP8 are fused into the `mxfp4_transpose_mxfp8` kernel, requiring just one low-bitwidth global-memory round trip. The gradient is used in both backward GEMMs, but with a different product dimension. To minimize memory movement, we quantize the gradient in 32x32 blocks, so that the single resulting MXFP8 tensor (with two separate row/column-replicated scales) can be used in both GEMMs via `backward_bf16_square_double_mxfp8`. Since the memory layouts for the two backward GEMMs end up different, we have to provide two separate MXFP8 GEMM kernels.

The implementation is most readable in the FP-Quant autograd MXFP4:MXFP8 function.
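To make the data flow concrete, below is a schematic `torch.autograd.Function` mirroring the scheme. The quantization helpers are no-op placeholders standing in for the QuTLASS kernels named above (`fusedQuantizeMx`, `mxfp4_transpose_mxfp8`, `backward_bf16_square_double_mxfp8`), and the GEMMs are written as plain matmuls; it is a sketch of the scheme, not the actual FP-Quant code.

```python
import torch


def quantize_mxfp4_hadamard(t):
    # Placeholder: the real kernel (fusedQuantizeMx) applies a block-Hadamard
    # transform over groups of 128 values and quantizes to MXFP4 along the
    # inner GEMM dimension.
    return t


def requantize_mxfp8(t):
    # Placeholder: the real kernel (mxfp4_transpose_mxfp8) dequantizes MXFP4,
    # transposes, and re-quantizes to MXFP8 along the other dimension in a
    # single low-bitwidth global-memory round trip.
    return t


def quantize_grad_mxfp8_2d(g):
    # Placeholder: the real path quantizes the gradient once in 32x32 blocks
    # with row- and column-replicated scales, so a single MXFP8 tensor serves
    # both backward GEMMs (backward_bf16_square_double_mxfp8).
    return g


class MXFP4QATLinearFn(torch.autograd.Function):
    # usage: y = MXFP4QATLinearFn.apply(x, w)  # x: (M, K) activations, w: (N, K) weights
    @staticmethod
    def forward(ctx, x, w):
        xq = quantize_mxfp4_hadamard(x)
        wq = quantize_mxfp4_hadamard(w)
        ctx.save_for_backward(xq, wq)  # the quantized tensors are what gets saved
        return xq @ wq.T               # y: (M, N); an MXFP4 GEMM in the real kernel

    @staticmethod
    def backward(ctx, grad_out):
        xq, wq = ctx.saved_tensors
        x8 = requantize_mxfp8(xq)              # X re-quantized along the batch dim M
        w8 = requantize_mxfp8(wq)              # W re-quantized along out_features N
        g8 = quantize_grad_mxfp8_2d(grad_out)  # one MXFP8 gradient for both GEMMs
        grad_x = g8 @ w8                       # dL/dX = dY @ W, contracts over N
        grad_w = g8.T @ x8                     # dL/dW = dY^T @ X, contracts over M
        return grad_x, grad_w
```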
This configuration is applied to all `nn.Linear` layers (except the final `lm_head`) after model initialization and loading. When loading for evaluation or inference, we quantize the master weights once at load time (with `store_master_weights=False`), so that they don't have to be re-quantized on every forward pass as they are during training.

These changes allow for a smaller model memory footprint, faster training and inference, and a larger training batch size.
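A minimal sketch of that module swap, written for illustration; `quantize_linear` is a hypothetical stand-in for the replacement utilities actually used in `fpquant.py`:

```python
# Illustrative sketch: apply QAT to every nn.Linear except the final lm_head.
import torch.nn as nn


def quantize_linear(linear: nn.Linear, store_master_weights: bool) -> nn.Module:
    # Placeholder: the real code swaps in an MXFP4:MXFP8 QAT linear layer.
    # store_master_weights=True during training; False for eval/inference,
    # where the master weights are quantized once at load time.
    return linear


def apply_qat(model: nn.Module, training: bool = True) -> nn.Module:
    for parent in list(model.modules()):
        for name, child in list(parent.named_children()):
            if isinstance(child, nn.Linear) and name != "lm_head":
                setattr(parent, name, quantize_linear(child, store_master_weights=training))
    return model
```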
QAT Setups
d32 below $500 training
For training the `d32` model variant (~2B parameters on ~40B tokens), the total runtime on an 8xB200 GPU server came down to 20 hrs. This is 3 hrs faster than the BF16 baseline on the same machine, and 13 hrs faster than the BF16 baseline on a Hopper (H100) machine. End-to-end, this brings the total dollar cost to slightly below $500, assuming a price of $3/h per B200 (20 hrs × 8 GPUs × $3/h = $480), which can be obtained for instance at vast.ai or datacrunch.io.

The resulting trained model can be executed in MXFP4, theoretically allowing for up to 8x, 4x, and 6x speedups over BF16 on RTX 5090, B200, and B300 respectively (from FLOPs specs). In practice, for the d32 model, we observe a 2.3x increase in maximum throughput on a 5090 with our kernels.
One should note that our research on FP4 QAT, consolidated in the Quartet paper (to appear at NeurIPS 2025), indicates that the speedups improve significantly as model size grows, and that the quality gap shrinks as both model size and training duration increase.
Llama-3.1-8B-Instruct
Presented in `transformers_distill.py`, this recipe showcases how QAT can be used to restore the quality of originally high-precision models after quantization. Via the transformers integration, it applies the same MXFP4:MXFP8 QAT scheme to `Llama-3.1-8B`, recovering more than half of the lost performance (as measured by various benchmarks) after training on just ~100M tokens, while training 30% faster than BF16 pseudo-quantization QAT.

You can read more about FP4 for post-training quantization in our latest paper on the matter. A collection of QAT-tuned FP4 models is available on the Hugging Face Hub.
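For orientation, a hedged sketch of how such a recipe might be wired up; it assumes the `FPQuantConfig` quantization config exposed by the transformers FP-Quant integration (via the fork above), and the exact arguments and trainer used in `transformers_distill.py` may differ:

```python
# Hedged sketch: load Llama-3.1-8B-Instruct with FP-Quant QAT enabled and a
# BF16 teacher for distillation. Assumes the transformers FP-Quant integration
# (FPQuantConfig); exact arguments in transformers_distill.py may differ.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"

student = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=FPQuantConfig(store_master_weights=True),  # QAT: keep master weights
)
teacher = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The student is then trained on ~100M tokens to match the teacher's outputs,
# e.g. with a trl distillation-style trainer and liger_kernel fused ops.
```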