except ImportError:
    float4_sf_dtype = None

+# TODO: put the ENUMs in the same place and import it
+FORMAT_FP8 = 0
+FORMAT_NVFP4 = 1
+

def modelopt_fp4_scale_to_cutlass_fp4_scale(modelopt_scale: torch.Tensor) -> torch.Tensor:
    """Converts the modelopt FP4 per-block weight scale to the cutlass format (padded and swizzled)."""
@@ -81,6 +85,7 @@ def create(quant_type_or_node: Union[str, Node], is_bmm: bool = False):
        quantization_impl_map = {
            "": None,
            "FP8": FP8BMMQuantizationImpl,
+           "NVFP4": None,  # BMM NVFP4 is not supported yet
        }
    else:
        quantization_impl_map = {
@@ -160,6 +165,18 @@ def shard_load_hook(
    def fuse_linear_weights(weights, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        pass

+    @staticmethod
+    def custom_op():
+        """Unified custom kernel entry-point for quantized linear."""
+        return torch.ops.auto_deploy.custom_quant_linear
+
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """Default: no extra kwargs. Each impl overrides to pass the right inputs/scales/zps/format."""
+        return {}
+

class FP8QuantizationImpl(QuantizationImpl):
    @staticmethod
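With `custom_op()` and `build_custom_kwargs_for_linear()` on the base class, a graph transform can stay format-agnostic: it asks the impl for the target op and for the per-format kwargs, and only the impl knows which scales, zero-points, and format ID to pass. A rough sketch of that call pattern, assuming the usual torch.fx rewrite loop (the helper name and node plumbing are illustrative, not taken from this diff):

```python
# Sketch only: `impl` is a QuantizationImpl subclass, `gm` a torch.fx.GraphModule,
# `node` the linear call being replaced, and `scale_getattrs` the get_attr nodes
# for the registered scale buffers.
def _replace_with_custom_quant_linear(gm, node, impl, scale_getattrs):
    x, weight, bias = node.args[:3]
    kwargs = impl.build_custom_kwargs_for_linear(scale_getattrs)
    with gm.graph.inserting_after(node):
        new_node = gm.graph.call_function(
            impl.custom_op(),  # torch.ops.auto_deploy.custom_quant_linear
            args=(x, weight, bias),
            kwargs=kwargs,     # scales / zps / format_type chosen by the impl
        )
    node.replace_all_uses_with(new_node)
    gm.graph.erase_node(node)
```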
@@ -180,6 +197,20 @@ def scale_names() -> List[str]:
    def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
        return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)}

+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        # FP8 custom op contract:
+        #   input_scale=[tensor], weight_scale=[tensor], input_zp=[], weight_zp=[], format_type=FORMAT_FP8
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_FP8,
+        )
+
    @staticmethod
    def load_hook(state_dict, prefix, *args, weight_name):
        if weight_name in state_dict:
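For reference, with the FP8 kwargs built above the rewritten node is expected to call the unified op roughly as shown below at runtime. Tensor shapes, the FP8 dtype, and the device are illustrative assumptions; only the argument layout comes from the contract comment:

```python
# Illustrative call shape only; real scales come from the checkpoint via the
# get_attr nodes in `scale_getattrs`.
import torch

x = torch.randn(4, 1024, dtype=torch.half, device="cuda")
w_fp8 = torch.randn(2048, 1024, device="cuda").to(torch.float8_e4m3fn)  # assumed FP8 dtype
bias = torch.randn(2048, dtype=torch.half, device="cuda")
s_in = torch.tensor(1.0, device="cuda")
s_w = torch.tensor(1.0, device="cuda")

out = torch.ops.auto_deploy.custom_quant_linear(
    x, w_fp8, bias,
    input_scale=[s_in],
    weight_scale=[s_w],
    input_zp=[],
    weight_zp=[],
    format_type=FORMAT_FP8,  # == 0
)
```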
@@ -264,6 +295,29 @@ def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
            "alpha": torch.tensor(1.0 / 6.0),
        }

+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """
+        Contract:
+        custom_quant_linear(
+            x, Wq, bias,
+            input_scale=[s_in2],
+            weight_scale=[weight_scale_cutlass_uint8, alpha_fused],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_NVFP4
+        )
+        """
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_NVFP4,
+        )
+
    @staticmethod
    def load_hook(state_dict, prefix, *args, weight_name):
        if weight_name in state_dict:
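The NVFP4 path differs from FP8 only in the `weight_scale` list, which carries two entries: the cutlass-formatted uint8 per-block scale tensor and the fused `alpha` scalar. A rough runtime sketch under assumed shapes (packed two-FP4-values-per-byte weight and block size 16 are assumptions for illustration, not asserted by this diff):

```python
# Illustrative only; these tensors stand in for what the NVFP4 load hooks produce.
import torch

n, k = 2048, 1024
x = torch.randn(4, k, dtype=torch.half, device="cuda")
w_fp4_packed = torch.randint(0, 256, (n, k // 2), dtype=torch.uint8, device="cuda")      # 2 FP4 values per byte (assumed)
w_block_scale = torch.randint(0, 256, (n * k // 16,), dtype=torch.uint8, device="cuda")  # swizzled block scales (block size 16 assumed)
alpha_fused = torch.tensor(1.0 / 6.0, device="cuda")  # default from default_scales()
s_in2 = torch.tensor(1.0, device="cuda")
bias = torch.randn(n, dtype=torch.half, device="cuda")

out = torch.ops.auto_deploy.custom_quant_linear(
    x, w_fp4_packed, bias,
    input_scale=[s_in2],
    weight_scale=[w_block_scale, alpha_fused],
    input_zp=[],
    weight_zp=[],
    format_type=FORMAT_NVFP4,  # == 1
)
```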