@@ -97,7 +97,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     template <typename TypeAct>
     std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> create_weight_quant_runner()
     {
-        if (isWInt8Quant())
+        if (isInt8Quant())
         {
             return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, uint8_t>>();
         }
@@ -109,10 +109,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
                 return std::make_unique<
                     kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, TypeAct, TypeAct>>();
             }
-            else
-            {
-                return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
-            }
 #endif
             return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
         }
@@ -123,16 +119,15 @@ class FusedMoeRunner : public torch::CustomClassHolder
     }

     FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
-        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_woq_group_scaling,
+        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_int8_woq_per_channel,
         bool use_mxfp8_act_scaling, bool use_fused_finalize)
     {
         mActivationDtype = activation_dtype;
         mWeightDtype = weight_dtype;
         mOutputDtype = output_dtype;
         mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
         mUseW4GroupScaling = use_w4_group_scaling;
-        mUseWoqPerChannel = use_woq_per_channel;
-        mUseWoqGroupScaling = use_woq_group_scaling;
+        mUseINT8WoqPerChannel = use_int8_woq_per_channel;
         mUseMxfp8ActScaling = use_mxfp8_act_scaling;
         mUseFusedFinalize = use_fused_finalize;
         mInnerDimMultiplier = 1;
@@ -167,7 +162,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
             mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
         }
-
         if (isNvfp4Quant())
         {
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
@@ -182,7 +176,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
             }
         }
-
         if (isWFP4A16Quant())
         {
             mInnerDimMultiplier = 2;
@@ -197,9 +190,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
             }
 #endif
         }
-
 #endif
-        if (isWeightOnlyQuant())
+        if (isIntWeightOnlyQuant())
         {
             if (isInt4Quant())
             {
@@ -316,8 +308,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");

-        if (mUseWoqPerChannel)
+        if (mUseINT8WoqPerChannel)
         {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
             TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
                 "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
         }
@@ -331,8 +325,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
         int64_t num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
         int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
-        if (mUseWoqPerChannel)
+        if (mUseINT8WoqPerChannel)
         {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
             hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
             inter_size = fc2_expert_weights.sizes()[1];
         }
@@ -614,8 +610,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
         int64_t const num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
         int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
-        if (mUseWoqPerChannel)
+        if (mUseINT8WoqPerChannel)
         {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
             hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
             inter_size = fc2_expert_weights.sizes()[1];
         }
@@ -701,8 +699,7 @@ class FusedMoeRunner : public torch::CustomClassHolder

     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4GroupScaling = false;
-    bool mUseWoqPerChannel = false;
-    bool mUseWoqGroupScaling = false;
+    bool mUseINT8WoqPerChannel = false;
     bool mUseMxfp8ActScaling = false;
     bool mUseFusedFinalize = true;

@@ -917,7 +914,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             TORCH_CHECK(false, "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm");
 #endif
         }
-
         else if (isNvfp4Quant())
         {
             TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
@@ -990,30 +986,31 @@ class FusedMoeRunner : public torch::CustomClassHolder
             return kernels::QuantParams::FP8BlockScaling(
                 static_cast<float const*>(fc1_scales.data_ptr()), static_cast<float const*>(fc2_scales.data_ptr()));
         }
-        else if (isWeightOnlyQuant())
+        else if (isWFP4A16Quant())
         {
             TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
-            if (mUseWoqPerChannel)
+            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");
+
+            auto& fc1_weight_scales = quant_scales.value()[0];
+            auto& fc2_weight_scales = quant_scales.value()[1];
+            int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size;
+            return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
+                static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr,
+                nullptr);
+        }
+        else if (isIntWeightOnlyQuant())
+        {
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            if (mUseINT8WoqPerChannel)
             {
-                TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for weight only quantization");
+                TORCH_CHECK(
+                    quant_scales.value().size() == 2, "Expecting 2 quant scales for INT8 weight only quantization");
                 auto& fc1_weight_scales = quant_scales.value()[0];
                 auto& fc2_weight_scales = quant_scales.value()[1];
                 return kernels::QuantParams::Int(static_cast<float const*>(fc1_weight_scales.data_ptr()),
                     static_cast<float const*>(fc2_weight_scales.data_ptr()));
             }
-            // TODO: support groupwise quantization for int8 weight only
-            else if (isWFP4A16Quant())
-            {
-                TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");
-
-                auto& fc1_weight_scales = quant_scales.value()[0];
-                auto& fc2_weight_scales = quant_scales.value()[1];
-                int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size;
-                return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
-                    static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr,
-                    nullptr);
-            }
-            else if (isInt4Quant() && mUseWoqGroupScaling)
+            else if (isInt4Quant() && mUseW4GroupScaling)
             {
                 TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");

@@ -1026,7 +1023,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
                 auto& fc1_alpha = quant_scales.value()[6];
                 auto& fc2_alpha = quant_scales.value()[7];
                 int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
-                return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
+                return kernels::QuantParams::GroupWise(group_size,
+                    static_cast<void const*>(fc1_weight_scales.data_ptr()),
                     static_cast<void const*>(fc2_weight_scales.data_ptr()),
                     static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
                     static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
@@ -1035,7 +1033,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
                     static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
                     static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
             }
-            else
+            else
             {
                 TORCH_CHECK(false, "Unsupported weight only quantization");
             }
@@ -1044,7 +1042,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
         {
             return kernels::QuantParams{};
         }
-
     }

     bool isFp8Quant() const
@@ -1064,7 +1061,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         return mUseW4GroupScaling && mWeightDtype == c10::ScalarType::Byte;
     }

-    bool isWInt8Quant() const
+    bool isInt8Quant() const
     {
         return mWeightDtype == c10::ScalarType::Char;
     }
@@ -1079,9 +1076,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
     }

-    bool isWeightOnlyQuant() const
+    bool isIntWeightOnlyQuant() const
     {
-        return isWInt8Quant() || isInt4Quant();
+        return isInt8Quant() || isInt4Quant();
     }

     bool isWMxfp4AFp8Quant() const
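
For readers of the diff: the renamed mUseINT8WoqPerChannel flag only changes how the expert weight shapes are read, as the comments added above state. The following standalone sketch is not part of the patch; it uses hypothetical names (Fc2Dims, interpretFc2Shape) to illustrate the dimension swap, assuming fc2_expert_weights is [num_experts, inter_size, hidden_size] for INT8 per-channel weight-only quantization and [num_experts, hidden_size, packed inter_size] otherwise.

// Hypothetical helper, not part of the patch: mirrors how hidden_size and
// inter_size are derived from fc2_expert_weights in the hunks above.
#include <array>
#include <cstdint>

struct Fc2Dims
{
    int64_t hidden_size;
    int64_t inter_size;
};

// dims = fc2_expert_weights.sizes() as {num_experts, dim1, dim2}.
// innerDimMultiplier plays the role of mInnerDimMultiplier (e.g. 1 for INT8,
// 2 for packed INT4 weights).
inline Fc2Dims interpretFc2Shape(
    std::array<int64_t, 3> const& dims, int64_t innerDimMultiplier, bool useInt8WoqPerChannel)
{
    if (useInt8WoqPerChannel)
    {
        // INT8 weight-only per-channel layout: [num_experts, inter_size, hidden_size]
        return {dims[2] * innerDimMultiplier, dims[1]};
    }
    // Default layout: [num_experts, hidden_size, packed inter_size]
    return {dims[1], dims[2] * innerDimMultiplier};
}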