
Commit cd51c39

Yuening-wa authored and dominicshanshan committed
[TRTLLM-5863][feat] Support MoE INT8 Weight-Only-Quantization in PyTorch Workflow (NVIDIA#6629)
Signed-off-by: Yuening Li <[email protected]>
1 parent 18c447e commit cd51c39

File tree

9 files changed: +544 -113 lines changed

cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 115 additions & 66 deletions
@@ -94,15 +94,40 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
     };
 
+    template <typename TypeAct>
+    std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> create_weight_quant_runner()
+    {
+        if (isInt8Quant())
+        {
+            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, uint8_t>>();
+        }
+        else if (isInt4Quant())
+        {
+#ifdef ENABLE_FP8
+            if (mUseW4GroupScaling)
+            {
+                return std::make_unique<
+                    kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, TypeAct, TypeAct>>();
+            }
+#endif
+            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
+        }
+        else
+        {
+            C10_THROW_ERROR_FORMATTED(Error, "Unsupported weight quantization type");
+        }
+    }
+
     FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
-        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling,
-        bool use_fused_finalize)
+        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_int8_woq_per_channel,
+        bool use_mxfp8_act_scaling, bool use_fused_finalize)
     {
         mActivationDtype = activation_dtype;
         mWeightDtype = weight_dtype;
         mOutputDtype = output_dtype;
         mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
         mUseW4GroupScaling = use_w4_group_scaling;
+        mUseINT8WoqPerChannel = use_int8_woq_per_channel;
         mUseMxfp8ActScaling = use_mxfp8_act_scaling;
         mUseFusedFinalize = use_fused_finalize;
         mInnerDimMultiplier = 1;
@@ -137,7 +162,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
             mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
         }
-
         if (isNvfp4Quant())
         {
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
@@ -152,7 +176,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
             }
         }
-
         if (isWFP4A16Quant())
         {
             mInnerDimMultiplier = 2;
@@ -167,45 +190,19 @@ class FusedMoeRunner : public torch::CustomClassHolder
             }
 #endif
         }
-
 #endif
-        if (isInt4Quant())
+        if (isIntWeightOnlyQuant())
         {
-            mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
-            if (mActivationDtype == c10::ScalarType::Half)
+            if (isInt4Quant())
             {
-#ifdef ENABLE_FP8
-                if (mUseW4GroupScaling)
-                {
-                    mKernelRunner
-                        = std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>();
-                }
-                else
-                {
-                    mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
-                }
-#else
-                mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
-#endif
+                mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
             }
-#ifdef ENABLE_BF16
-            else if (mActivationDtype == c10::ScalarType::BFloat16)
+            switch (mActivationDtype)
             {
-#ifdef ENABLE_FP8
-                if (mUseW4GroupScaling)
-                {
-                    mKernelRunner = std::make_unique<
-                        kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>();
-                }
-                else
-                {
-                    mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
-                }
-#else
-                mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
-#endif
+            case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner<half>(); break;
+            case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break;
+            default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight");
             }
-#endif
         }
         if (!mKernelRunner)
         {
@@ -310,13 +307,31 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
-        TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
-            "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
+            TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
+                "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+        }
+        else
+        {
+            TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+                "fc1_expert_weights inter size must be fc2_expert_weights inter size.");
+        }
 
         int experts_per_token = token_selected_experts.sizes()[1];
         int64_t num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
         int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
+            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+            inter_size = fc2_expert_weights.sizes()[1];
+        }
 
         if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant())
         {
@@ -593,8 +608,15 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
 
         int64_t const num_rows = input.sizes()[0];
-        int64_t const hidden_size = fc2_expert_weights.sizes()[1];
-        int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        int64_t hidden_size = fc2_expert_weights.sizes()[1];
+        int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        if (mUseINT8WoqPerChannel)
+        {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
+            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+            inter_size = fc2_expert_weights.sizes()[1];
+        }
         int64_t const group_size_
             = isInt4Quant() ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size : -1;
         int64_t const group_size = isWFP4A16Quant()
@@ -677,6 +699,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4GroupScaling = false;
+    bool mUseINT8WoqPerChannel = false;
     bool mUseMxfp8ActScaling = false;
     bool mUseFusedFinalize = true;
 
@@ -891,7 +914,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             TORCH_CHECK(false, "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm");
 #endif
         }
-
         else if (isNvfp4Quant())
         {
             TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
@@ -966,8 +988,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         else if (isWFP4A16Quant())
        {
-            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
-            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 8 quant scales for W4A16 quantization");
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");
 
             auto& fc1_weight_scales = quant_scales.value()[0];
             auto& fc2_weight_scales = quant_scales.value()[1];
@@ -976,28 +998,45 @@ class FusedMoeRunner : public torch::CustomClassHolder
                 static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr,
                 nullptr);
         }
-        else if (isInt4Quant())
+        else if (isIntWeightOnlyQuant())
         {
-            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
-            TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");
-
-            auto& fc1_weight_scales = quant_scales.value()[0];
-            auto& fc2_weight_scales = quant_scales.value()[1];
-            auto& fc1_act_scales = quant_scales.value()[2];
-            auto& fc2_act_scales = quant_scales.value()[3];
-            auto& fc1_weight_zeros = quant_scales.value()[4];
-            auto& fc2_weight_zeros = quant_scales.value()[5];
-            auto& fc1_alpha = quant_scales.value()[6];
-            auto& fc2_alpha = quant_scales.value()[7];
-            int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
-            return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
-                static_cast<void const*>(fc2_weight_scales.data_ptr()),
-                static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
-                static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
-                static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
-                static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
-                static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
-                static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            if (mUseINT8WoqPerChannel)
+            {
+                TORCH_CHECK(
+                    quant_scales.value().size() == 2, "Expecting 2 quant scales for INT8 weight only quantization");
+                auto& fc1_weight_scales = quant_scales.value()[0];
+                auto& fc2_weight_scales = quant_scales.value()[1];
+                return kernels::QuantParams::Int(static_cast<float const*>(fc1_weight_scales.data_ptr()),
+                    static_cast<float const*>(fc2_weight_scales.data_ptr()));
+            }
+            else if (isInt4Quant() && mUseW4GroupScaling)
+            {
+                TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");
+
+                auto& fc1_weight_scales = quant_scales.value()[0];
+                auto& fc2_weight_scales = quant_scales.value()[1];
+                auto& fc1_act_scales = quant_scales.value()[2];
+                auto& fc2_act_scales = quant_scales.value()[3];
+                auto& fc1_weight_zeros = quant_scales.value()[4];
+                auto& fc2_weight_zeros = quant_scales.value()[5];
+                auto& fc1_alpha = quant_scales.value()[6];
+                auto& fc2_alpha = quant_scales.value()[7];
+                int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
+                return kernels::QuantParams::GroupWise(group_size,
+                    static_cast<void const*>(fc1_weight_scales.data_ptr()),
+                    static_cast<void const*>(fc2_weight_scales.data_ptr()),
+                    static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
+                    static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
+                    static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
+                    static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
+                    static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
+                    static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
+            }
+            else
+            {
+                TORCH_CHECK(false, "Unsupported weight only quantization");
+            }
         }
         else
         {
@@ -1022,6 +1061,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
         return mUseW4GroupScaling && mWeightDtype == c10::ScalarType::Byte;
     }
 
+    bool isInt8Quant() const
+    {
+        return mWeightDtype == c10::ScalarType::Char;
+    }
+
     bool isInt4Quant() const
     {
         return mWeightDtype == c10::ScalarType::QUInt4x2;
@@ -1032,6 +1076,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
         return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
     }
 
+    bool isIntWeightOnlyQuant() const
+    {
+        return isInt8Quant() || isInt4Quant();
+    }
+
     bool isWMxfp4AFp8Quant() const
     {
         return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long
@@ -1050,7 +1099,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 TORCH_LIBRARY(trtllm, m)
 {
     m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
-        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool>())
+        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool, bool>())
         .def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
        .def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
        .def("run_moe", &torch_ext::FusedMoeRunner::runMoe)
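Reviewer note on the quant-scale contract introduced above: when use_int8_woq_per_channel is set, getQuantParams expects exactly two per-channel weight-scale tensors (fc1 and fc2) and routes them through kernels::QuantParams::Int, while the INT4 group-scaling path keeps its eight-entry layout. A minimal Python sketch of how a caller might assemble the two-entry quant_scales list for the INT8 path follows; the tensor shapes and sizes are illustrative assumptions, not taken from this diff.

import torch

# Illustrative sizes only (assumed): 8 experts, hidden_size 4096, inter_size 14336.
num_experts, hidden_size, inter_size = 8, 4096, 14336

# Assumed layout: one float32 scale per output channel of each expert GEMM;
# fc1 is the gated up-projection (2 * inter_size outputs), fc2 projects back to hidden_size.
fc1_weight_scales = torch.ones(num_experts, 2 * inter_size, dtype=torch.float32)
fc2_weight_scales = torch.ones(num_experts, hidden_size, dtype=torch.float32)

# INT8 weight-only per-channel path: exactly two entries, matching the
# "Expecting 2 quant scales for INT8 weight only quantization" check in moeOp.cpp.
quant_scales = [fc1_weight_scales, fc2_weight_scales]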

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 14 additions & 3 deletions
@@ -43,6 +43,7 @@ def __init__(
         cluster_rank: int,
         use_deepseek_fp8_block_scale: bool,
         use_w4_group_scaling: bool,
+        use_int8_woq_per_channel: bool,
         use_mxfp8_act_scaling: bool,
         min_latency_mode: bool,
         use_fused_finalize: bool,
@@ -61,20 +62,22 @@ def __init__(
         self.enable_alltoall = False
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4_group_scaling = use_w4_group_scaling
+        self.use_int8_woq_per_channel = use_int8_woq_per_channel
         self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
         self.min_latency_mode = min_latency_mode
         self.use_fused_finalize = use_fused_finalize
 
         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4_group_scaling,
-                        use_mxfp8_act_scaling)
+                        use_int8_woq_per_channel, use_mxfp8_act_scaling)
 
         if instance_key not in MoERunner.runner_dict:
             MoERunner.runner_dict[
                 instance_key] = torch.classes.trtllm.FusedMoeRunner(
                     x_dtype, weight_dtype, output_dtype,
                     use_deepseek_fp8_block_scale, use_w4_group_scaling,
-                    use_mxfp8_act_scaling, use_fused_finalize)
+                    use_int8_woq_per_channel, use_mxfp8_act_scaling,
+                    use_fused_finalize)
         self.fused_moe_runner = MoERunner.runner_dict[instance_key]
 
     def get_valid_tactics(
@@ -138,6 +141,7 @@ def fused_moe(
     enable_alltoall: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
     use_w4_group_scaling: bool = False,
+    use_int8_woq_per_channel: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     use_fused_finalize: bool = True,
@@ -174,6 +178,7 @@
         cluster_rank=cluster_rank,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4_group_scaling=use_w4_group_scaling,
+        use_int8_woq_per_channel=use_int8_woq_per_channel,
         use_mxfp8_act_scaling=use_mxfp8_act_scaling,
         min_latency_mode=min_latency_mode,
         use_fused_finalize=use_fused_finalize,
@@ -257,13 +262,19 @@ def _(
     enable_alltoall: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
     use_w4_group_scaling: bool = False,
+    use_int8_woq_per_channel: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     use_fused_finalize: bool = True,
     tune_max_num_tokens: int = 8192,
 ):
     seq_len = input.shape[0]
-    hidden_size = fc2_expert_weights.shape[1]
+    if use_int8_woq_per_channel:
+        # Note: The weight shape for INT8 weight only quantization is different, i.e.,
+        # fc2_expert_weights: [num_experts, inter_size, hidden_size]
+        hidden_size = fc2_expert_weights.shape[2]
+    else:
+        hidden_size = fc2_expert_weights.shape[1]
 
     if min_latency_mode:
         num_experts_on_rank = fc2_expert_weights.shape[0]