Conversation

**RuibinCheung** commented on Mar 27, 2026
- Implement the MXFP8 HIP kernel.
- Remove the MXFP4 and MXFP8 Triton kernels.
- Fix a bug in the MXFP4 quant kernel when 2D block mode is enabled.
Force-pushed from 49b05ed to acfb0dd.
Pull request overview
This PR migrates MXFP8 quantization from Triton to a new HIP/CUDA kernel exposed via the PyTorch C++ extension, while removing the previous Triton MXFP4/MXFP8 kernels and adjusting FP8 GEMM dtype selection logic.
Changes:
- Add C++/HIP implementations and bindings for `quantize_mxfp8` and `quantize_mxfp8_dual` (with shuffle support), and route Python MXFP8 quantization to these ops.
- Remove the Triton MXFP4/MXFP8 kernel source files.
- Update tests and FP8 GEMM helpers to use centralized FP8 dtype selection and MX block-size constants.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/pytorch/ops/test_quantization.py | Updates MXFP8 test coverage and adds a shuffle-scale test. |
| primus_turbo/triton/quantization/quantization_mxfp8.py | Removes Triton MXFP8 quant/dequant kernels. |
| primus_turbo/triton/quantization/quantization_mxfp4.py | Removes Triton MXFP4 quant/dequant kernels. |
| primus_turbo/pytorch/ops/grouped_gemm_fp8.py | Centralizes FP8 dtype selection (adds HYBRID support). |
| primus_turbo/pytorch/ops/gemm_fp8.py | Centralizes FP8 dtype selection (adds HYBRID support). |
| primus_turbo/pytorch/kernels/quantization/quantization_impl.py | Switches MXFP8 quantization to C++ ops; still references Triton dequant kernels. |
| primus_turbo/pytorch/core/low_precision.py | Introduces MXFP4/MXFP8 block-size constants used by configs/tests. |
| csrc/pytorch/quantization/quantization_meta.cpp | Adds meta implementations for MXFP8 quant ops. |
| csrc/pytorch/quantization/quantization.cpp | Implements MXFP8 quant ops and dispatch into new kernels. |
| csrc/pytorch/extensions.h | Declares new MXFP8 quant APIs. |
| csrc/pytorch/bindings_pytorch.cpp | Registers new MXFP8 ops with CUDA/Meta impls. |
| csrc/kernels/quantization/quantization_mxfp8.cu | Adds the new MXFP8 HIP/CUDA quantization kernels (single + dual). |
| csrc/kernels/quantization/quantization_mxfp4.cu | Fixes/cleans up shuffle + reduction utilities usage. |
| csrc/include/primus_turbo/shuffle.h | Adds MXFP8 shuffle layout constants. |
| csrc/include/primus_turbo/quantization.h | Adds MXFP8 constants/types + function declarations. |
| csrc/include/primus_turbo/device/utils.cuh | Consolidates low-level bitcast + fp16/bf16 unpack helpers. |
| csrc/include/primus_turbo/device/shuffle.cuh | Generalizes shuffled-index computation for FP4/FP8. |
| csrc/include/primus_turbo/device/reduce.cuh | Adds AMD wavefront DPP-based max reductions used by kernels. |
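The "centralized FP8 dtype selection" and MX block-size constants referenced in the table can be sketched as follows. This is an illustrative pure-Python sketch, not the PR's actual code: the names `Float8Format` and `select_fp8_dtypes`, and the string dtype values, are assumptions (the real helpers in `primus_turbo` would work with `torch` float8 dtypes), and the HYBRID convention of E4M3 forward / E5M2 backward follows common FP8 training practice rather than anything stated in this review.

```python
from enum import Enum

# Hypothetical names for illustration only; the PR centralizes an
# equivalent selection in primus_turbo's FP8 GEMM helpers.
MXFP4_BLOCK_SIZE = 32
MXFP8_BLOCK_SIZE = 32

class Float8Format(Enum):
    E4M3 = "float8_e4m3"
    E5M2 = "float8_e5m2"
    # HYBRID: assumed to mean E4M3 for the forward pass (weights/activations)
    # and E5M2 for the backward pass (gradients), as in common FP8 recipes.
    HYBRID = "hybrid"

def select_fp8_dtypes(fmt: Float8Format) -> tuple:
    """Return (forward_dtype, backward_dtype) names for an FP8 GEMM."""
    if fmt is Float8Format.HYBRID:
        return ("float8_e4m3", "float8_e5m2")
    return (fmt.value, fmt.value)
```

Keeping this logic in one place is what lets both `gemm_fp8.py` and `grouped_gemm_fp8.py` pick up HYBRID support without duplicating the mapping.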
```python
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    assert x.is_contiguous(), "The x tensor must be contiguous."
    assert x.dim() == 2, "The x must be 2D tensor."
    # NOTE: quantize fp4 kernel use the ISA which only available on cdna4.
```
The NOTE comment refers to the “fp4 kernel” and CDNA4 ISA, but this function is quantize_mxfp8_impl and uses the MXFP8 backend. Please update the comment to accurately describe the MXFP8 kernel/support check to avoid confusion during maintenance.
```diff
- # NOTE: quantize fp4 kernel use the ISA which only available on cdna4.
+ # NOTE: MXFP8 quantization kernels rely on ISA features only available on CDNA4;
+ # check_mxfp8_support() verifies that this hardware support is present.
```
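The support check the suggested comment refers to could look like the sketch below. This is a hypothetical host-side sketch, not the PR's implementation: the signature of `check_mxfp8_support` and the exact architecture list are assumptions (CDNA4 is assumed to correspond to the `gfx950` GCN target; the real check may live in C++ and query the device directly).

```python
def check_mxfp8_support(gcn_arch: str) -> tuple:
    """Return (supported, reason) for the MXFP8 quantization kernels.

    Hypothetical sketch: the HIP MXFP8 kernels use ISA instructions that
    are assumed to exist only on CDNA4 (gfx950) hardware.
    """
    supported_archs = ("gfx950",)  # assumed CDNA4 target
    if any(gcn_arch.startswith(a) for a in supported_archs):
        return True, ""
    return False, f"MXFP8 quantization requires CDNA4, got {gcn_arch}"
```

Gating at the Python entry point keeps the failure mode an explicit error rather than a kernel launch on unsupported hardware.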
```python
    x_2d,
    dest_dtype,
    granularity=granularity,
    block_size=MXFP4_BLOCK_SIZE,
```
In the MXFP8 shuffle test, `quantize_fp8_with_trans(..., block_size=MXFP4_BLOCK_SIZE)` is confusing/misleading since this is an FP8 path. Please use `MXFP8_BLOCK_SIZE` here to match intent and avoid accidental divergence if the FP4/FP8 block sizes ever differ.
```diff
- block_size=MXFP4_BLOCK_SIZE,
+ block_size=MXFP8_BLOCK_SIZE,
```
```diff
  from primus_turbo.triton.quantization.quantization_mxfp4 import dequantize_mxfp4_kernel
- from primus_turbo.triton.quantization.quantization_mxfp8 import (
-     dequantize_mxfp8_kernel,
-     quantize_mxfp8_kernel,
- )
+ from primus_turbo.triton.quantization.quantization_mxfp8 import dequantize_mxfp8_kernel
```
`primus_turbo.triton.quantization.quantization_mxfp4` / `quantization_mxfp8` no longer exist (the directory now only contains `quant_blockwise.py`), so these imports will raise `ModuleNotFoundError` at import time and break the MXFP4/MXFP8 dequantization paths that still reference the Triton kernels. Either restore the Triton dequantize kernels/modules or migrate `dequantize_mxfp4_impl`/`dequantize_mxfp8_impl` to the new C++/HIP backend and remove these imports.
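If the dequantization path is migrated off Triton as suggested, a pure-Python reference can serve as a test oracle for the new backend. The sketch below is illustrative only and not part of the PR: function names are hypothetical, and it assumes the OCP MX convention of FP8 E4M3 (e4m3fn) element codes plus one E8M0 power-of-two scale per 32-element block.

```python
MXFP8_BLOCK_SIZE = 32
E4M3_EXP_BIAS = 7  # exponent bias of the FP8 e4m3fn format

def decode_e4m3(byte: int) -> float:
    """Decode one FP8 e4m3fn code (0..255) to a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0x0F
    mant = byte & 0x07
    if exp == 0x0F and mant == 0x07:      # e4m3fn has NaN but no infinities
        return float("nan")
    if exp == 0:                          # subnormal: mant/8 * 2^(1-bias)
        return sign * (mant / 8.0) * 2.0 ** (1 - E4M3_EXP_BIAS)
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - E4M3_EXP_BIAS)

def dequantize_mxfp8_block(codes, e8m0_scale: int):
    """Dequantize one MX block: each element times the shared 2^(e-127) scale."""
    scale = 2.0 ** (e8m0_scale - 127)     # E8M0 scale, bias 127
    return [decode_e4m3(c) * scale for c in codes]
```

Comparing the C++/HIP op's output against such a reference on small random inputs is a cheap way to catch decode or scale-layout regressions during the migration.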