@BlackSamorez commented on Oct 27, 2025

This PR serves as a demo of how to integrate Quantization-Aware Training (QAT) into an end-to-end training and evaluation pipeline.

Extra Dependencies

It adds the following dependencies on top of the existing ones to enable QAT:

  • QuTLASS: a CUDA kernel library that provides MXFP4 and MXFP8 GEMM and quantization-related kernels for both forward and backward passes. Requires CUDA_HOME to be set.
  • FP-Quant: a wrapper over QuTLASS to simplify model loading and kernel dispatch.
  • (Optional) transformers/accelerate: to demonstrate how FP-Quant enables QAT of pre-trained models from the HF Hub (Llama-3.1-8B-Instruct example). Run uv pip install git+https://github.com/BlackSamorez/transformers@fpquant_backwards trl liger_kernel && python -m torch.distributed.launch --nproc_per_node=8 transformers_distill.py from the existing uv env.

Changes Summary

Quantization is configured in the newly-created fpquant.py file, where the following configuration is applied to the model:

FPQuantConfig(
    forward_dtype=FPQuantDtype.MXFP4,
    forward_method="abs_max",
    hadamard_group_size=128,
    backward_dtype=FPQuantDtype.MXFP8,
    store_master_weights=store_master_weights,
    modules_to_not_convert=["lm_head"],
)

This configuration specifies that the model weights are trained in MXFP4 with block-Hadamard transforms of size 128, which improve training stability. For better convergence, the master weights are still stored and updated in FP32. On the backward pass, MXFP8 GEMMs are used to speed up training.
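For intuition only, here is a rough plain-PyTorch emulation of what this configuration does to a tensor on the forward pass: a block-Hadamard rotation in groups of 128 values, followed by abs-max quantization to the FP4 (E2M1) grid with a shared power-of-two scale per 32-value MX block. This is a sketch, not the QuTLASS path; the real kernels are fused and never materialize dequantized tensors like this, and the exact scale-selection rule of the "abs_max" method may differ slightly.

import torch

def hadamard(n: int) -> torch.Tensor:
    """Sylvester construction of an n x n Hadamard matrix (n must be a power of two)."""
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / n ** 0.5  # orthonormal: h @ h.T == I

# Representable magnitudes of the FP4 (E2M1) format used by MXFP4.
FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quant_mxfp4(x: torch.Tensor, hadamard_group_size: int = 128, block: int = 32) -> torch.Tensor:
    """Quantize-dequantize the last dim of x, mimicking the forward-pass numerics."""
    h = hadamard(hadamard_group_size).to(x.dtype)
    xr = x.reshape(-1, hadamard_group_size) @ h                  # block-Hadamard rotation
    xb = xr.reshape(-1, block)                                   # 32-value MX blocks share one scale
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = 2.0 ** (torch.floor(torch.log2(amax)) - 2)           # power-of-two (E8M0-style) scale
    q = xb / scale                                               # abs-max scaling into the FP4 range
    nearest = FP4_GRID[(q.abs().unsqueeze(-1) - FP4_GRID).abs().argmin(dim=-1)]
    deq = nearest * q.sign() * scale                             # round to nearest FP4 value, rescale
    return (deq.reshape(-1, hadamard_group_size) @ h.T).reshape_as(x)  # rotate back

x = torch.randn(4, 256)
print((fake_quant_mxfp4(x) - x).abs().mean())                    # small quantization error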

QAT Scheme

[Diagram: nanochat QAT scheme (nanochat_qat.drawio)]

(Scheme inspired by "Pretraining Large Language Models with NVFP4")

On the forward pass, both weights and activations are normalized via the Hadamard transform in groups of 128 values and quantized to MXFP4 via the fusedQuantizeMx kernel call. The quantization is performed along the inner GEMM dimension (as required by the MXFP GEMM tensor cores). The quantized weights and activations are then saved for the backward pass.

On the backward pass, weights and activations have to be re-quantized along a different dimension to allow for MXFP8 backward GEMMs. The dequantization from MXFP4, transposition, and quantization to MXFP8 are fused and require just one low-bitwidth global memory round trip in the mxfp4_transpose_mxfp8 kernel. The gradient is used in both backward GEMMs, but with a different product dimension. To minimize memory movement, we quantize the gradient in 32x32 blocks, which allows the single resulting MXFP8 tensor (with two separate row/column-replicated scales) to be used in both GEMMs with backward_bf16_square_double_mxfp8. Since the memory layouts for the two backward GEMMs end up different, we have to provide two separate MXFP8 GEMM kernels.

The implementation is most readable in the FP-Quant autograd MXFP4:MXFP8 function.
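For orientation, a heavily simplified, self-contained sketch of that autograd structure is below. The fake_quant helper is an identity stand-in for the QuTLASS quantization kernels named above, so this runs without QuTLASS; it is not the actual FP-Quant code.

import torch

def fake_quant(t: torch.Tensor) -> torch.Tensor:
    # Identity stand-in for the real MXFP4/MXFP8 quantization kernels, so the sketch runs
    # without QuTLASS; see the emulation above for what the forward quantization does numerically.
    return t

class QATLinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # fusedQuantizeMx: Hadamard-rotate and quantize x and W to MXFP4 along the inner GEMM dim.
        xq = fake_quant(x)
        wq = fake_quant(weight)
        ctx.save_for_backward(xq, wq)      # the quantized tensors are what gets saved
        return xq @ wq.t()                 # MXFP4 tensor-core GEMM in the real kernel

    @staticmethod
    def backward(ctx, grad_out):
        xq, wq = ctx.saved_tensors
        # mxfp4_transpose_mxfp8: re-quantize the saved MXFP4 tensors to MXFP8 along the other dim.
        # The gradient is quantized once in 32x32 blocks and reused by both backward GEMMs
        # (backward_bf16_square_double_mxfp8 in QuTLASS).
        gq = fake_quant(grad_out)
        grad_x = gq @ fake_quant(wq)       # dL/dx
        grad_w = gq.t() @ fake_quant(xq)   # dL/dW
        return grad_x, grad_w

# Usage: y = QATLinearFn.apply(x, weight), with x of shape (batch, in) and weight of shape (out, in).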

This configuration is applied to all nn.Linear layers (except the final lm_head) after model initialization and loading. When loading for evaluation or inference, we quantize the master weights once at load time (with store_master_weights=False) so that they don't have to be re-quantized on every forward pass, as is necessary during training.
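As a small illustration of the two modes (the import is an assumption about what the new fpquant.py pulls from the fp_quant package, and make_fpquant_config is a hypothetical helper, not a function from this PR):

from fp_quant import FPQuantConfig, FPQuantDtype  # assumed import, mirroring fpquant.py

def make_fpquant_config(training: bool) -> FPQuantConfig:
    # Training keeps FP32 master weights (re-quantized every forward pass);
    # evaluation/inference quantizes the weights once at load time instead.
    return FPQuantConfig(
        forward_dtype=FPQuantDtype.MXFP4,
        forward_method="abs_max",
        hadamard_group_size=128,
        backward_dtype=FPQuantDtype.MXFP8,
        store_master_weights=training,
        modules_to_not_convert=["lm_head"],
    )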

These changes reduce the model's memory footprint, speed up both training and inference, and allow for a larger training batch size.

QAT Setups

d32: training below $500

For training the d32 model variant (~2B parameters on ~40B tokens), the total runtime on an 8xB200 GPU server came down to 20 hrs. This is 3 hrs faster than the BF16 baseline on the same machine, and 13 hrs faster than the BF16 baseline on a Hopper (H100) machine. End-to-end, this brings the total dollar cost to slightly below $500 (8 GPUs × 20 h × $3/GPU-hour ≈ $480), assuming a price of $3/h per B200, which can be obtained, for instance, at vast.ai or datacrunch.io.

The resulting trained model can be executed in MXFP4, theoretically allowing for up to 8x, 4x, and 6x speedups over BF16 on the RTX 5090, B200, and B300, respectively (based on FLOPs specs). In practice, for the d32 model, we observe a 2.3x increase in maximum throughput on the 5090 with our kernels.

One should note that our research on FP4 QAT, consolidated in the Quartet paper (to appear at NeurIPS 2025), indicates that the speedups improve significantly as model size increases, and that the quality gap shrinks as both model size and training duration increase.


Llama-3.1-8B-Instruct

Presented in transformers_distill.py, this recipe showcases how QAT can be used to restore the quality of originally high-precision models after quantization. Via the transformers integration, it applies the same MXFP4:MXFP8 QAT scheme to Llama-3.1-8B, recovering more than half of the lost performance (as measured by various benchmarks) after training on just ~100M tokens, while training 30% faster than BF16 pseudo-quantization QAT.
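A heavily abridged sketch of what that recipe might look like is below. The transformers-side entry point (an FPQuantConfig passed as quantization_config) and its argument names are assumptions based on the fork linked above; transformers_distill.py is the authoritative recipe, and the fine-tuning loop is omitted entirely.

# Sketch only: assumes the forked transformers exposes FP-Quant via the usual
# quantization_config argument; all names below are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"
qat_config = FPQuantConfig(
    forward_dtype="mxfp4",              # assumed string form of FPQuantDtype.MXFP4
    backward_dtype="mxfp8",
    hadamard_group_size=128,
    store_master_weights=True,          # keep FP32 master weights while fine-tuning
    modules_to_not_convert=["lm_head"],
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=qat_config,
)
# Fine-tune / distill on ~100M tokens (e.g. with trl), then reload with
# store_master_weights=False for MXFP4 inference.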


You can read more about FP4 for post-training quantization in our latest paper on the matter. A collection of QAT-tuned FP4 models is available on Hugging Face Hub.

def forward(self, x):
    x = self.c_fc(x)
-   x = F.relu(x).square()
+   x = F.relu(x).square().to(torch.bfloat16)
@BlackSamorez (author) commented on the diff above:
Hopefully the cast to BF16 gets fused with the activation function, so we don't have to cast explicitly for the quantized GEMM kernels again.
