except ImportError:
    float4_sf_dtype = None

+# TODO: put the ENUMs in the same place and import it
+FORMAT_FP8 = 0
+FORMAT_NVFP4 = 1
+

def modelopt_fp4_scale_to_cutlass_fp4_scale(modelopt_scale: torch.Tensor) -> torch.Tensor:
    """Converts the modelopt FP4 per-block weight scale to the cutlass format (padded and swizzled)."""
@@ -160,6 +164,18 @@ def shard_load_hook(
    def fuse_linear_weights(weights, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        pass

+    @staticmethod
+    def custom_op():
+        """Unified custom kernel entry-point for quantized linear."""
+        return torch.ops.auto_deploy.custom_quant_linear
+
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """Default: no extra kwargs. Each impl overrides to pass the right inputs/scales/zps/format."""
+        return {}
+

class FP8QuantizationImpl(QuantizationImpl):
    @staticmethod
@@ -180,6 +196,20 @@ def scale_names() -> List[str]:
    def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
        return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)}

+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        # FP8 custom op contract:
+        # input_scale=[tensor], weight_scale=[tensor], input_zp=[], weight_zp=[], format_type=FORMAT_FP8
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_FP8,
+        )
+
    @staticmethod
    def load_hook(state_dict, prefix, *args, weight_name):
        if weight_name in state_dict:
@@ -264,6 +294,29 @@ def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
            "alpha": torch.tensor(1.0 / 6.0),
        }

+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """
+        Contract:
+            custom_quant_linear(
+                x, Wq, bias,
+                input_scale=[s_in2],
+                weight_scale=[weight_scale_cutlass_uint8, alpha_fused],
+                input_zp=[],
+                weight_zp=[],
+                format_type=FORMAT_NVFP4
+            )
+        """
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_NVFP4,
+        )
+
    @staticmethod
    def load_hook(state_dict, prefix, *args, weight_name):
        if weight_name in state_dict:
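
The hunks above only add the impl-side helpers; the call site that consumes them is not part of this diff. A minimal sketch of how a torch.fx transform could use custom_op() together with build_custom_kwargs_for_linear() follows; the function name rewrite_to_custom_quant_linear, the scale_attr_names mapping, and the node-rewrite details are illustrative assumptions, not code from this PR.

from typing import Dict

from torch.fx import GraphModule, Node


def rewrite_to_custom_quant_linear(
    gm: GraphModule,
    linear_node: Node,
    quant_impl,                         # e.g. FP8QuantizationImpl (assumed call site)
    scale_attr_names: Dict[str, str],   # scale name -> qualified attr name on gm (assumed)
) -> None:
    """Sketch: replace a quantized-linear node with a call to the unified custom op."""
    graph = gm.graph
    with graph.inserting_before(linear_node):
        # One get_attr Node per registered scale buffer; this is the `scale_getattrs`
        # dict that build_custom_kwargs_for_linear expects.
        scale_getattrs: Dict[str, Node] = {
            name: graph.get_attr(attr) for name, attr in scale_attr_names.items()
        }
        # Impl-specific kwargs: scale/zp lists plus format_type (FORMAT_FP8 / FORMAT_NVFP4).
        extra_kwargs = quant_impl.build_custom_kwargs_for_linear(scale_getattrs)
        new_node = graph.call_function(
            quant_impl.custom_op(),          # torch.ops.auto_deploy.custom_quant_linear
            args=linear_node.args,           # (x, Wq, bias)
            kwargs={**linear_node.kwargs, **extra_kwargs},
        )
    linear_node.replace_all_uses_with(new_node)
    graph.erase_node(linear_node)
    gm.recompile()

Keeping the scale/zp list layout and format_type uniform across impls is what allows a single custom_quant_linear entry point to dispatch on format_type rather than special-casing each quantization scheme at the graph-rewrite level.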