@@ -31,34 +31,33 @@ def is_deep_gemm_supported() -> bool:
 
 
 @functools.cache
-def is_blackwell_deep_gemm_e8m0_used() -> bool:
+def is_deep_gemm_e8m0_used() -> bool:
     """Return ``True`` if vLLM is configured to use DeepGEMM
-    E8M0 scale on a Blackwell-class GPU.
+    E8M0 scale on a Hopper or Blackwell-class GPU.
     """
     if not is_deep_gemm_supported():
-        logger.debug_once(
+        logger.info_once(
             "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
         return False
 
-    if not envs.VLLM_USE_DEEP_GEMM_E8M0:
-        logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
-        return False
-
     _lazy_init()
 
     if _fp8_gemm_nt_impl is None:
-        logger.debug_once(
-            "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
+        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
         return False
 
-    enabled = (current_platform.is_cuda()
-               and current_platform.has_device_capability(100))
-    if enabled:
-        logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
-    else:
-        logger.debug_once(
-            "DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
-    return enabled
+    if current_platform.is_device_capability(100) and \
+            envs.VLLM_USE_DEEP_GEMM_E8M0:
+        logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
+        return True
+
+    if current_platform.is_device_capability(90) and \
+            envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
+        logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.")
+        return True
+
+    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
+    return False
 
 
 def _missing(*_: Any, **__: Any) -> NoReturn:
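
For readers skimming the hunk above, here is a minimal standalone sketch of the decision order the renamed is_deep_gemm_e8m0_used() now follows. The helper and its boolean parameters are hypothetical stand-ins for vLLM's envs flags and current_platform capability queries, used only so the snippet runs on its own:

# Illustrative sketch only (not vLLM source): the order of checks applied by
# the new is_deep_gemm_e8m0_used(), with plain values standing in for
# envs.* flags and current_platform capability queries.
def e8m0_used_sketch(deep_gemm_supported: bool, kernel_found: bool,
                     device_capability: int, e8m0_flag: bool,
                     e8m0_hopper_flag: bool) -> bool:
    if not deep_gemm_supported:
        return False  # "DeepGEMM not supported on this system."
    if not kernel_found:
        return False  # "_fp8_gemm_nt_impl not found"
    if device_capability == 100 and e8m0_flag:
        return True   # Blackwell path, gated by VLLM_USE_DEEP_GEMM_E8M0
    if device_capability == 90 and e8m0_hopper_flag:
        return True   # Hopper path, gated by VLLM_USE_DEEP_GEMM_E8M0_HOPPER
    return False      # every other configuration

# Blackwell (SM100) with its flag set -> enabled; Hopper (SM90) without its
# dedicated flag -> disabled.
assert e8m0_used_sketch(True, True, 100, True, False)
assert not e8m0_used_sketch(True, True, 90, True, False)
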
@@ -124,30 +123,26 @@ def fp8_gemm_nt(*args, **kwargs):
     _lazy_init()
     if _fp8_gemm_nt_impl is None:
         return _missing(*args, **kwargs)
-    return _fp8_gemm_nt_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+    return _fp8_gemm_nt_impl(*args,
+                             disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
+                             **kwargs)
 
 
 def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
     _lazy_init()
     if _grouped_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+    return _grouped_impl(*args,
+                         disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
+                         **kwargs)
 
 
 def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     _lazy_init()
     if _grouped_masked_impl is None:
         return _missing(*args, **kwargs)
     return _grouped_masked_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
 
 
 def _ceil_to_ue8m0(x: torch.Tensor):
@@ -211,7 +206,7 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",
     "per_block_cast_to_fp8",
-    "is_blackwell_deep_gemm_e8m0_used",
+    "is_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
     "should_use_deepgemm_for_fp8_linear",
 ]
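
The wrapper changes in the second hunk all follow the same pattern: derive disable_ue8m0_cast once from the renamed predicate and inject it while forwarding every other argument unchanged. Below is a small self-contained illustration of that pattern; _impl and wrapper are hypothetical stand-ins, not vLLM or DeepGEMM APIs:

# Illustration only: mimics how fp8_gemm_nt and friends forward their arguments.
def _impl(*args, disable_ue8m0_cast=False, **kwargs):
    # Stand-in for the real DeepGEMM kernel entry point.
    return {"args": args, "disable_ue8m0_cast": disable_ue8m0_cast, **kwargs}

def wrapper(*args, **kwargs):
    e8m0_used = False  # stand-in for is_deep_gemm_e8m0_used()
    # The flag is injected here, so callers normally should not pass it
    # themselves; wrapper(x, disable_ue8m0_cast=True) would raise
    # "got multiple values for keyword argument".
    return _impl(*args, disable_ue8m0_cast=not e8m0_used, **kwargs)

print(wrapper(1, 2, out=None))
# {'args': (1, 2), 'disable_ue8m0_cast': True, 'out': None}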