@@ -94,15 +94,40 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
     };
 
+    template <typename TypeAct>
+    std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> create_weight_quant_runner()
+    {
+        if (isInt8Quant())
+        {
+            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, uint8_t>>();
+        }
+        else if (isInt4Quant())
+        {
+#ifdef ENABLE_FP8
+            if (mUseW4GroupScaling)
+            {
+                return std::make_unique<
+                    kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, TypeAct, TypeAct>>();
+            }
+#endif
+            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
+        }
+        else
+        {
+            C10_THROW_ERROR_FORMATTED(Error, "Unsupported weight quantization type");
+        }
+    }
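+    // A sketch of the runners the helper above produces (shown for half activations; the
+    // __nv_bfloat16 case is analogous):
+    //   isInt8Quant():                       CutlassMoeFCRunner<half, uint8_t>
+    //   isInt4Quant() && mUseW4GroupScaling: CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>
+    //   isInt4Quant() otherwise:             CutlassMoeFCRunner<half, cutlass::uint4b_t>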
+
     FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
-        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling,
-        bool use_fused_finalize)
+        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_int8_woq_per_channel,
+        bool use_mxfp8_act_scaling, bool use_fused_finalize)
     {
         mActivationDtype = activation_dtype;
         mWeightDtype = weight_dtype;
         mOutputDtype = output_dtype;
         mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
         mUseW4GroupScaling = use_w4_group_scaling;
+        mUseINT8WoqPerChannel = use_int8_woq_per_channel;
         mUseMxfp8ActScaling = use_mxfp8_act_scaling;
         mUseFusedFinalize = use_fused_finalize;
         mInnerDimMultiplier = 1;
@@ -137,7 +162,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
             mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
         }
-
         if (isNvfp4Quant())
         {
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
@@ -152,7 +176,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
            default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
            }
        }
-
        if (isWFP4A16Quant())
        {
            mInnerDimMultiplier = 2;
@@ -167,45 +190,19 @@ class FusedMoeRunner : public torch::CustomClassHolder
            }
#endif
        }
-
#endif
-        if (isInt4Quant())
+        if (isIntWeightOnlyQuant())
        {
-            mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
-            if (mActivationDtype == c10::ScalarType::Half)
+            if (isInt4Quant())
            {
-#ifdef ENABLE_FP8
-                if (mUseW4GroupScaling)
-                {
-                    mKernelRunner
-                        = std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>();
-                }
-                else
-                {
-                    mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
-                }
-#else
-                mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
-#endif
+                mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
            }
-#ifdef ENABLE_BF16
-            else if (mActivationDtype == c10::ScalarType::BFloat16)
+            switch (mActivationDtype)
            {
-#ifdef ENABLE_FP8
-                if (mUseW4GroupScaling)
-                {
-                    mKernelRunner = std::make_unique<
-                        kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>();
-                }
-                else
-                {
-                    mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
-                }
-#else
-                mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
-#endif
+            case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner<half>(); break;
+            case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break;
+            default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight");
            }
-#endif
        }
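+        // A worked example of the packing above (a sketch): with INT4 weights, mInnerDimMultiplier = 2
+        // because two 4-bit values share one stored int8 element, so a logical inner dimension K
+        // occupies K / 2 storage elements; INT8 weights are stored one per element, keeping the
+        // default multiplier of 1.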
        if (!mKernelRunner)
        {
@@ -310,13 +307,31 @@ class FusedMoeRunner : public torch::CustomClassHolder
        }
        TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
            "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
-        TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
-            "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: the weight shape for INT8 weight-only quantization is different, e.g.
+            // fc2_expert_weights: [num_experts, inter_size, hidden_size]
+            TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
+                "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+        }
+        else
+        {
+            TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+                "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+        }
 
        int experts_per_token = token_selected_experts.sizes()[1];
        int64_t num_rows = input.sizes()[0];
        int64_t hidden_size = fc2_expert_weights.sizes()[1];
        int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: the weight shape for INT8 weight-only quantization is different, e.g.
+            // fc2_expert_weights: [num_experts, inter_size, hidden_size]
+            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+            inter_size = fc2_expert_weights.sizes()[1];
+        }
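+        // Shape summary implied by the branches above (a sketch; mInnerDimMultiplier is 1 for INT8
+        // and 2 for INT4 weights):
+        //   default:   fc2_expert_weights = [num_experts, hidden_size, inter_size / mInnerDimMultiplier]
+        //   INT8 WOQ:  fc2_expert_weights = [num_experts, inter_size,  hidden_size / mInnerDimMultiplier]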
 
        if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant())
        {
@@ -593,8 +608,15 @@ class FusedMoeRunner : public torch::CustomClassHolder
        }
 
        int64_t const num_rows = input.sizes()[0];
-        int64_t const hidden_size = fc2_expert_weights.sizes()[1];
-        int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        int64_t hidden_size = fc2_expert_weights.sizes()[1];
+        int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: the weight shape for INT8 weight-only quantization is different, e.g.
+            // fc2_expert_weights: [num_experts, inter_size, hidden_size]
+            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+            inter_size = fc2_expert_weights.sizes()[1];
+        }
        int64_t const group_size_
            = isInt4Quant() ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size : -1;
        int64_t const group_size = isWFP4A16Quant()
@@ -677,6 +699,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
    bool mUseDeepSeekFP8BlockScaling = false;
    bool mUseW4GroupScaling = false;
+    bool mUseINT8WoqPerChannel = false;
    bool mUseMxfp8ActScaling = false;
    bool mUseFusedFinalize = true;
@@ -891,7 +914,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
        TORCH_CHECK(false, "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm");
#endif
        }
-
        else if (isNvfp4Quant())
        {
            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
@@ -966,8 +988,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
        }
        else if (isWFP4A16Quant())
        {
-            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
-            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 8 quant scales for W4A16 quantization");
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");
 
            auto& fc1_weight_scales = quant_scales.value()[0];
            auto& fc2_weight_scales = quant_scales.value()[1];
@@ -976,28 +998,45 @@ class FusedMoeRunner : public torch::CustomClassHolder
            static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr,
            nullptr);
        }
-        else if (isInt4Quant())
+        else if (isIntWeightOnlyQuant())
        {
-            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
-            TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");
-
-            auto& fc1_weight_scales = quant_scales.value()[0];
-            auto& fc2_weight_scales = quant_scales.value()[1];
-            auto& fc1_act_scales = quant_scales.value()[2];
-            auto& fc2_act_scales = quant_scales.value()[3];
-            auto& fc1_weight_zeros = quant_scales.value()[4];
-            auto& fc2_weight_zeros = quant_scales.value()[5];
-            auto& fc1_alpha = quant_scales.value()[6];
-            auto& fc2_alpha = quant_scales.value()[7];
-            int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
-            return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
-                static_cast<void const*>(fc2_weight_scales.data_ptr()),
-                static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
-                static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
-                static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
-                static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
-                static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
-                static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
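+            // The two packings accepted below (a sketch of the expected quant_scales layout):
+            //   INT8 per-channel: {fc1_weight_scales, fc2_weight_scales}          -> QuantParams::Int
+            //   W4A8 groupwise:   {fc1/fc2 weight scales, fc1/fc2 act scales,
+            //                      fc1/fc2 weight zeros, fc1/fc2 alpha}           -> QuantParams::GroupWise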
+            if (mUseINT8WoqPerChannel)
+            {
+                TORCH_CHECK(
+                    quant_scales.value().size() == 2, "Expecting 2 quant scales for INT8 weight only quantization");
+                auto& fc1_weight_scales = quant_scales.value()[0];
+                auto& fc2_weight_scales = quant_scales.value()[1];
+                return kernels::QuantParams::Int(static_cast<float const*>(fc1_weight_scales.data_ptr()),
+                    static_cast<float const*>(fc2_weight_scales.data_ptr()));
+            }
+            else if (isInt4Quant() && mUseW4GroupScaling)
+            {
+                TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");
+
+                auto& fc1_weight_scales = quant_scales.value()[0];
+                auto& fc2_weight_scales = quant_scales.value()[1];
+                auto& fc1_act_scales = quant_scales.value()[2];
+                auto& fc2_act_scales = quant_scales.value()[3];
+                auto& fc1_weight_zeros = quant_scales.value()[4];
+                auto& fc2_weight_zeros = quant_scales.value()[5];
+                auto& fc1_alpha = quant_scales.value()[6];
+                auto& fc2_alpha = quant_scales.value()[7];
+                int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
+                return kernels::QuantParams::GroupWise(group_size,
+                    static_cast<void const*>(fc1_weight_scales.data_ptr()),
+                    static_cast<void const*>(fc2_weight_scales.data_ptr()),
+                    static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
+                    static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
+                    static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
+                    static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
+                    static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
+                    static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
+            }
+            else
+            {
+                TORCH_CHECK(false, "Unsupported weight only quantization");
+            }
        }
        else
        {
@@ -1022,6 +1061,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
        return mUseW4GroupScaling && mWeightDtype == c10::ScalarType::Byte;
    }
 
+    bool isInt8Quant() const
+    {
+        return mWeightDtype == c10::ScalarType::Char;
+    }
+
    bool isInt4Quant() const
    {
        return mWeightDtype == c10::ScalarType::QUInt4x2;
@@ -1032,6 +1076,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
        return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
    }
 
+    bool isIntWeightOnlyQuant() const
+    {
+        return isInt8Quant() || isInt4Quant();
+    }
+
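+    // Dtype-to-mode mapping assumed by isInt8Quant()/isInt4Quant() above (QUInt4x2 packs two
+    // 4-bit values per byte):
+    //   mWeightDtype == c10::ScalarType::Char      -> INT8 weight-only
+    //   mWeightDtype == c10::ScalarType::QUInt4x2  -> INT4 weight-only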
    bool isWMxfp4AFp8Quant() const
    {
        return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long
@@ -1050,7 +1099,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
TORCH_LIBRARY(trtllm, m)
{
    m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
-        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool>())
+        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool, bool>())
        .def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
        .def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
        .def("run_moe", &torch_ext::FusedMoeRunner::runMoe)