
feat: refine mxfp8 quant kernel#261

Open
RuibinCheung wants to merge 2 commits into main from feat/zhangrb/refine_mxfp8_quant

Conversation

@RuibinCheung
Contributor

  • Implement the MXFP8 HIP kernel.
  • Remove the MXFP4 and MXFP8 Triton kernels.
  • Fix a bug in the MXFP4 quant kernel when the 2D block path is enabled.

Copilot AI review requested due to automatic review settings March 27, 2026 04:14
@RuibinCheung force-pushed the feat/zhangrb/refine_mxfp8_quant branch from 49b05ed to acfb0dd on March 27, 2026 04:59
Contributor

Copilot AI left a comment


Pull request overview

This PR migrates MXFP8 quantization from Triton to a new HIP/CUDA kernel exposed via the PyTorch C++ extension, while removing the previous Triton MXFP4/MXFP8 kernels and adjusting FP8 GEMM dtype selection logic.

Changes:

  • Add C++/HIP implementations + bindings for quantize_mxfp8 and quantize_mxfp8_dual (with shuffle support) and route Python MXFP8 quantization to these ops.
  • Remove Triton MXFP4/MXFP8 kernel source files.
  • Update tests and FP8 GEMM helpers to use centralized FP8 dtype selection and MX block-size constants.
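For context on what the new kernels compute, here is a minimal host-side sketch of MX-style blockwise FP8 quantization: each block of 32 values shares one power-of-two (E8M0) scale, chosen so the scaled values fit the FP8 E4M3 range. Function and constant names are illustrative only and are not the PR's actual API; the device kernel would additionally round each scaled value to an FP8 encoding.

```python
import math

# Illustrative constants (the PR centralizes similar ones in low_precision.py)
MXFP8_BLOCK_SIZE = 32
F8E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def quantize_mxfp8_ref(x):
    """Hypothetical reference: one power-of-two (E8M0) scale per 32-element block."""
    assert len(x) % MXFP8_BLOCK_SIZE == 0
    q, scales = [], []
    for i in range(0, len(x), MXFP8_BLOCK_SIZE):
        block = x[i:i + MXFP8_BLOCK_SIZE]
        amax = max(abs(v) for v in block)
        # E8M0 scale: smallest power of two such that amax / scale <= F8E4M3_MAX
        exp = math.ceil(math.log2(amax / F8E4M3_MAX)) if amax > 0 else 0
        scale = 2.0 ** exp
        # on device each scaled value would also be rounded to FP8 E4M3
        q.append([v / scale for v in block])
        scales.append(scale)
    return q, scales

x = [math.sin(0.1 * i) for i in range(64)]
q, scales = quantize_mxfp8_ref(x)
```

Because the scale is a pure power of two, dividing and re-multiplying is exact in float; the only information loss in the real kernel comes from the FP8 rounding step omitted here.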

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
tests/pytorch/ops/test_quantization.py Updates MXFP8 test coverage and adds a shuffle-scale test.
primus_turbo/triton/quantization/quantization_mxfp8.py Removes Triton MXFP8 quant/dequant kernels.
primus_turbo/triton/quantization/quantization_mxfp4.py Removes Triton MXFP4 quant/dequant kernels.
primus_turbo/pytorch/ops/grouped_gemm_fp8.py Centralizes FP8 dtype selection (adds HYBRID support).
primus_turbo/pytorch/ops/gemm_fp8.py Centralizes FP8 dtype selection (adds HYBRID support).
primus_turbo/pytorch/kernels/quantization/quantization_impl.py Switches MXFP8 quantization to C++ ops; still references Triton dequant kernels.
primus_turbo/pytorch/core/low_precision.py Introduces MXFP4/MXFP8 block-size constants used by configs/tests.
csrc/pytorch/quantization/quantization_meta.cpp Adds meta implementations for MXFP8 quant ops.
csrc/pytorch/quantization/quantization.cpp Implements MXFP8 quant ops and dispatch into new kernels.
csrc/pytorch/extensions.h Declares new MXFP8 quant APIs.
csrc/pytorch/bindings_pytorch.cpp Registers new MXFP8 ops with CUDA/Meta impls.
csrc/kernels/quantization/quantization_mxfp8.cu Adds the new MXFP8 HIP/CUDA quantization kernels (single + dual).
csrc/kernels/quantization/quantization_mxfp4.cu Fixes/cleans up shuffle + reduction utilities usage.
csrc/include/primus_turbo/shuffle.h Adds MXFP8 shuffle layout constants.
csrc/include/primus_turbo/quantization.h Adds MXFP8 constants/types + function declarations.
csrc/include/primus_turbo/device/utils.cuh Consolidates low-level bitcast + fp16/bf16 unpack helpers.
csrc/include/primus_turbo/device/shuffle.cuh Generalizes shuffled-index computation for FP4/FP8.
csrc/include/primus_turbo/device/reduce.cuh Adds AMD wavefront DPP-based max reductions used by kernels.
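The DPP-based max reductions mentioned for reduce.cuh follow the standard butterfly pattern: each lane repeatedly combines its value with an XOR-partner lane until every lane holds the wavefront maximum. A small Python emulation of that pattern, assuming a 64-lane wavefront (lane count and function name are illustrative, not the file's actual API):

```python
import random

WAVE_SIZE = 64  # AMD CDNA wavefront width

def wave_reduce_max(lanes):
    """Butterfly (XOR-partner) max reduction: after log2(WAVE_SIZE) steps,
    every lane holds the wavefront max, mirroring what DPP/swizzle-based
    reductions do in hardware."""
    vals = list(lanes)
    assert len(vals) == WAVE_SIZE
    offset = WAVE_SIZE // 2
    while offset > 0:
        # each lane combines with the lane whose id differs in one bit,
        # analogous to __shfl_xor or DPP row/bank swaps on AMD GPUs
        vals = [max(vals[i], vals[i ^ offset]) for i in range(WAVE_SIZE)]
        offset //= 2
    return vals

random.seed(0)
lanes = [random.uniform(-10.0, 10.0) for _ in range(WAVE_SIZE)]
out = wave_reduce_max(lanes)
```

This is the reduction shape the MXFP8/MXFP4 kernels need to compute the per-block amax before deriving the shared E8M0 scale.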


) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
assert x.is_contiguous(), "The x tensor must be contiguous."
assert x.dim() == 2, "The x must be 2D tensor."
# NOTE: quantize fp4 kernel use the ISA which only available on cdna4.

Copilot AI Mar 27, 2026


The NOTE comment refers to the “fp4 kernel” and CDNA4 ISA, but this function is quantize_mxfp8_impl and uses the MXFP8 backend. Please update the comment to accurately describe the MXFP8 kernel/support check to avoid confusion during maintenance.

Suggested change
# NOTE: quantize fp4 kernel use the ISA which only available on cdna4.
# NOTE: MXFP8 quantization kernels rely on ISA features only available on CDNA4;
# check_mxfp8_support() verifies that this hardware support is present.

x_2d,
dest_dtype,
granularity=granularity,
block_size=MXFP4_BLOCK_SIZE,

Copilot AI Mar 27, 2026


In the MXFP8 shuffle test, quantize_fp8_with_trans(..., block_size=MXFP4_BLOCK_SIZE) is confusing/misleading since this is an FP8 path. Please use MXFP8_BLOCK_SIZE here to match intent and avoid accidental divergence if FP4/FP8 block sizes ever differ.

Suggested change
block_size=MXFP4_BLOCK_SIZE,
block_size=MXFP8_BLOCK_SIZE,

Comment on lines 23 to +24
from primus_turbo.triton.quantization.quantization_mxfp4 import dequantize_mxfp4_kernel
from primus_turbo.triton.quantization.quantization_mxfp8 import (
dequantize_mxfp8_kernel,
quantize_mxfp8_kernel,
)
from primus_turbo.triton.quantization.quantization_mxfp8 import dequantize_mxfp8_kernel

Copilot AI Mar 27, 2026


primus_turbo.triton.quantization.quantization_mxfp4 / quantization_mxfp8 no longer exist (the directory now only contains quant_blockwise.py), so these imports will raise ModuleNotFoundError at import time and break MXFP4/MXFP8 dequantization paths that still reference the Triton kernels. Either restore the Triton dequantize kernels/modules or migrate dequantize_mxfp4_impl/dequantize_mxfp8_impl to the new C++/HIP backend and remove these imports.
