Commit d473fc2

elvischenv authored and cboss6 committed
[Perf] Use NVIDIA hardware-accelerated instruction for float to fp8_e4m3 quantization (vllm-project#24757)
Signed-off-by: elvischenv <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
1 parent 1e5d632 commit d473fc2
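For context, the "hardware-accelerated instruction" in the title is the float-to-FP8 cvt exposed through <cuda_fp8.h>. Below is a minimal standalone sketch of the intrinsic that the new vec_conversion specialization wraps; this is not code from the commit, and the kernel name is illustrative:

// Quantize a float buffer to FP8 E4M3 storage bytes. On SM 8.9+ the
// intrinsic should lower to the hardware cvt instruction; on older
// architectures the header falls back to a software conversion.
#include <cuda_fp8.h>

__global__ void quantize_to_e4m3(const float* in, __nv_fp8_storage_t* out,
                                 int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // __NV_SATFINITE saturates out-of-range inputs to the largest finite
    // E4M3 magnitude (448) instead of producing NaN.
    out[i] = __nv_cvt_float_to_fp8(in[i], __NV_SATFINITE, __NV_E4M3);
  }
}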

File tree

csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/nvidia/quant_utils.cuh

2 files changed: +22 additions, -5 deletions


csrc/quantization/fp8/common.cuh

Lines changed: 6 additions & 2 deletions
@@ -5,7 +5,9 @@
 
 #include <cmath>
 
-#ifdef USE_ROCM
+#ifndef USE_ROCM
+  #include "nvidia/quant_utils.cuh"
+#else
   #include "amd/quant_utils.cuh"
 #endif
 

@@ -48,7 +50,9 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
   float r =
       fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
 #ifndef USE_ROCM
-  return static_cast<fp8_type>(r);
+  // Use hardware cvt instruction for fp8 on nvidia
+  // Currently only support fp8_type = c10::Float8_e4m3fn
+  return fp8::vec_conversion<fp8_type, float>(r);
 #else
   // Use hardware cvt instruction for fp8 on rocm
   return fp8::cvt_c10<fp8_type>(r);
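
Taken together, the NVIDIA branch of scaled_fp8_conversion now clamps the scaled value to the finite E4M3 range and hands it to the hardware converter instead of static_cast. A hedged sketch of that pipeline, assuming x above is the scaled input; scaled_quant_e4m3 and kE4M3Max are illustrative names, not identifiers from common.cuh:

#include <cuda_fp8.h>

__device__ __forceinline__ __nv_fp8_storage_t scaled_quant_e4m3(
    float val, float inverted_scale) {
  float x = val * inverted_scale;
  // quant_type_max_v<c10::Float8_e4m3fn> is 448, the largest finite value.
  constexpr float kE4M3Max = 448.0f;
  float r = fmaxf(-kE4M3Max, fminf(x, kE4M3Max));
  // Old path: return static_cast<fp8_type>(r);  (software rounding)
  // New path: fp8::vec_conversion<fp8_type, float>(r), effectively:
  return __nv_cvt_float_to_fp8(r, __NV_SATFINITE, __NV_E4M3);
}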

csrc/quantization/fp8/nvidia/quant_utils.cuh

Lines changed: 16 additions & 3 deletions
@@ -12,13 +12,26 @@ namespace vllm {
 namespace fp8 {
 #ifdef ENABLE_FP8
 
-#if 0 // Disable the following code to reduce the binary size.
 template <typename Tout, typename Tin>
-__inline__ __device__ Tout
-vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
+__inline__ __device__ Tout vec_conversion(
+    const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) {
   return x;
 }
 
+// float -> c10::Float8_e4m3fn
+template <>
+__inline__ __device__ c10::Float8_e4m3fn
+vec_conversion<c10::Float8_e4m3fn, float>(
+    const float& a, const __nv_fp8_interpretation_t fp8_type) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  return static_cast<c10::Float8_e4m3fn>(a);
+#else
+  return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type),
+                            c10::Float8_e4m3fn::from_bits());
+#endif
+}
+
+#if 0 // Disable the following code to reduce the binary size.
 // fp8 -> half
 template <>
 __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
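
The from_bits() overload is what keeps the fast path cheap: the intrinsic already returns a correctly rounded E4M3 byte, so the wrapper only reinterprets it rather than converting twice. A hedged sketch contrasting the two constructions, assuming PyTorch's c10 headers; the function names are illustrative:

#include <cstdint>

#include <c10/util/Float8_e4m3fn.h>

// Numeric conversion: rounds the float to the nearest representable
// E4M3 value in software.
c10::Float8_e4m3fn from_value(float a) {
  return static_cast<c10::Float8_e4m3fn>(a);
}

// Bit reinterpretation: adopts an already-encoded E4M3 byte verbatim,
// which is how the hardware cvt result is wrapped above.
c10::Float8_e4m3fn from_raw_byte(uint8_t bits) {
  return c10::Float8_e4m3fn(bits, c10::Float8_e4m3fn::from_bits());
}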
