
Commit f1aa234

committed
solve conflicts and revise the code
Signed-off-by: Yuening Li <[email protected]>
1 parent e99c071 commit f1aa234

File tree: 7 files changed (+250, -463 lines)


cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 37 additions & 40 deletions
@@ -97,7 +97,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     template <typename TypeAct>
     std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> create_weight_quant_runner()
     {
-        if (isWInt8Quant())
+        if (isInt8Quant())
         {
             return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, uint8_t>>();
         }
@@ -109,10 +109,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             return std::make_unique<
                 kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, TypeAct, TypeAct>>();
         }
-        else
-        {
-            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
-        }
 #endif
         return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
     }
@@ -123,16 +119,15 @@ class FusedMoeRunner : public torch::CustomClassHolder
     }

     FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
-        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_woq_group_scaling,
+        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_int8_woq_per_channel,
         bool use_mxfp8_act_scaling, bool use_fused_finalize)
     {
         mActivationDtype = activation_dtype;
         mWeightDtype = weight_dtype;
         mOutputDtype = output_dtype;
         mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
         mUseW4GroupScaling = use_w4_group_scaling;
-        mUseWoqPerChannel = use_woq_per_channel;
-        mUseWoqGroupScaling = use_woq_group_scaling;
+        mUseINT8WoqPerChannel = use_int8_woq_per_channel;
         mUseMxfp8ActScaling = use_mxfp8_act_scaling;
         mUseFusedFinalize = use_fused_finalize;
         mInnerDimMultiplier = 1;
@@ -167,7 +162,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
             mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
         }
-
         if (isNvfp4Quant())
         {
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
@@ -182,7 +176,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
             }
         }
-
         if (isWFP4A16Quant())
         {
             mInnerDimMultiplier = 2;
@@ -197,9 +190,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
             }
 #endif
         }
-
 #endif
-        if (isWeightOnlyQuant())
+        if (isIntWeightOnlyQuant())
         {
             if (isInt4Quant())
             {
@@ -316,8 +308,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");

-        if (mUseWoqPerChannel)
+        if (mUseINT8WoqPerChannel)
         {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
             TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
                 "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
         }
@@ -331,8 +325,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
         int64_t num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
         int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
-        if (mUseWoqPerChannel)
+        if (mUseINT8WoqPerChannel)
         {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
             hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
             inter_size = fc2_expert_weights.sizes()[1];
         }
@@ -614,8 +610,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
         int64_t const num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
         int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
-        if (mUseWoqPerChannel)
+        if (mUseINT8WoqPerChannel)
         {
+            // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
+            // [num_experts, inter_size, hidden_size]
            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
            inter_size = fc2_expert_weights.sizes()[1];
         }
@@ -701,8 +699,7 @@ class FusedMoeRunner : public torch::CustomClassHolder

     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4GroupScaling = false;
-    bool mUseWoqPerChannel = false;
-    bool mUseWoqGroupScaling = false;
+    bool mUseINT8WoqPerChannel = false;
     bool mUseMxfp8ActScaling = false;
     bool mUseFusedFinalize = true;

@@ -917,7 +914,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             TORCH_CHECK(false, "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm");
 #endif
         }
-
         else if (isNvfp4Quant())
         {
             TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
@@ -990,30 +986,31 @@ class FusedMoeRunner : public torch::CustomClassHolder
            return kernels::QuantParams::FP8BlockScaling(
                static_cast<float const*>(fc1_scales.data_ptr()), static_cast<float const*>(fc2_scales.data_ptr()));
        }
-        else if (isWeightOnlyQuant())
+        else if (isWFP4A16Quant())
        {
            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
-            if (mUseWoqPerChannel)
+            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");
+
+            auto& fc1_weight_scales = quant_scales.value()[0];
+            auto& fc2_weight_scales = quant_scales.value()[1];
+            int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size;
+            return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
+                static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr,
+                nullptr);
+        }
+        else if (isIntWeightOnlyQuant())
+        {
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            if (mUseINT8WoqPerChannel)
            {
-                TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for weight only quantization");
+                TORCH_CHECK(
+                    quant_scales.value().size() == 2, "Expecting 2 quant scales for INT8 weight only quantization");
                auto& fc1_weight_scales = quant_scales.value()[0];
                auto& fc2_weight_scales = quant_scales.value()[1];
                return kernels::QuantParams::Int(static_cast<float const*>(fc1_weight_scales.data_ptr()),
                    static_cast<float const*>(fc2_weight_scales.data_ptr()));
            }
-            // TODO: support groupwise quantization for int8 weight only
-            else if (isWFP4A16Quant())
-            {
-                TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");
-
-                auto& fc1_weight_scales = quant_scales.value()[0];
-                auto& fc2_weight_scales = quant_scales.value()[1];
-                int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size;
-                return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
-                    static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr,
-                    nullptr);
-            }
-            else if (isInt4Quant() && mUseWoqGroupScaling)
+            else if (isInt4Quant() && mUseW4GroupScaling)
            {
                TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");

@@ -1026,7 +1023,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
                auto& fc1_alpha = quant_scales.value()[6];
                auto& fc2_alpha = quant_scales.value()[7];
                int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
-                return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
+                return kernels::QuantParams::GroupWise(group_size,
+                    static_cast<void const*>(fc1_weight_scales.data_ptr()),
                    static_cast<void const*>(fc2_weight_scales.data_ptr()),
                    static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
                    static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
@@ -1035,7 +1033,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
                    static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
                    static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
            }
-            else
+            else
            {
                TORCH_CHECK(false, "Unsupported weight only quantization");
            }
@@ -1044,7 +1042,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
        {
            return kernels::QuantParams{};
        }
-
    }

     bool isFp8Quant() const
@@ -1064,7 +1061,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         return mUseW4GroupScaling && mWeightDtype == c10::ScalarType::Byte;
     }

-    bool isWInt8Quant() const
+    bool isInt8Quant() const
     {
         return mWeightDtype == c10::ScalarType::Char;
     }
@@ -1079,9 +1076,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
     }

-    bool isWeightOnlyQuant() const
+    bool isIntWeightOnlyQuant() const
     {
-        return isWInt8Quant() || isInt4Quant();
+        return isInt8Quant() || isInt4Quant();
     }

     bool isWMxfp4AFp8Quant() const
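
The "weight shape" notes added above all describe the same convention: for INT8 weight-only per-channel quantization, fc2_expert_weights is stored as [num_experts, inter_size, hidden_size], whereas the default layout keeps hidden_size in dimension 1. As a rough illustration only (a minimal Python sketch; the helper name and example shapes are assumptions, not part of this commit), the dimension derivation in the C++ runner boils down to:

    # Illustrative sketch only: mirrors how moeOp.cpp derives (hidden_size, inter_size)
    # from fc2_expert_weights under the two layouts. Names here are hypothetical.
    def derive_moe_dims(fc2_expert_weights_shape, inner_dim_multiplier=1,
                        use_int8_woq_per_channel=False):
        num_experts, dim1, dim2 = fc2_expert_weights_shape
        if use_int8_woq_per_channel:
            # INT8 weight-only per-channel layout: [num_experts, inter_size, hidden_size]
            hidden_size = dim2 * inner_dim_multiplier
            inter_size = dim1
        else:
            # Default layout: hidden_size sits in dimension 1
            hidden_size = dim1
            inter_size = dim2 * inner_dim_multiplier
        return hidden_size, inter_size

    # Example: an INT8 per-channel fc2 weight of shape [8, 2048, 1024]
    # yields hidden_size = 1024 and inter_size = 2048.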

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 10 additions & 14 deletions
@@ -43,8 +43,7 @@ def __init__(
         cluster_rank: int,
         use_deepseek_fp8_block_scale: bool,
         use_w4_group_scaling: bool,
-        use_woq_per_channel: bool,
-        use_woq_group_scaling: bool,
+        use_int8_woq_per_channel: bool,
         use_mxfp8_act_scaling: bool,
         min_latency_mode: bool,
         use_fused_finalize: bool,
@@ -63,23 +62,21 @@ def __init__(
         self.enable_alltoall = False
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4_group_scaling = use_w4_group_scaling
-        self.use_woq_per_channel = use_woq_per_channel
-        self.use_woq_group_scaling = use_woq_group_scaling
+        self.use_int8_woq_per_channel = use_int8_woq_per_channel
         self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
         self.min_latency_mode = min_latency_mode
         self.use_fused_finalize = use_fused_finalize

         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4_group_scaling,
-                        use_woq_per_channel, use_woq_group_scaling,
-                        use_mxfp8_act_scaling)
+                        use_int8_woq_per_channel, use_mxfp8_act_scaling)

         if instance_key not in MoERunner.runner_dict:
             MoERunner.runner_dict[
                 instance_key] = torch.classes.trtllm.FusedMoeRunner(
                     x_dtype, weight_dtype, output_dtype,
                     use_deepseek_fp8_block_scale, use_w4_group_scaling,
-                    use_woq_per_channel, use_mxfp8_act_scaling,
+                    use_int8_woq_per_channel, use_mxfp8_act_scaling,
                     use_fused_finalize)
         self.fused_moe_runner = MoERunner.runner_dict[instance_key]

@@ -144,8 +141,7 @@ def fused_moe(
     enable_alltoall: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
     use_w4_group_scaling: bool = False,
-    use_woq_per_channel: bool = False,
-    use_woq_group_scaling: bool = False,
+    use_int8_woq_per_channel: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     use_fused_finalize: bool = True,
@@ -182,8 +178,7 @@ def fused_moe(
         cluster_rank=cluster_rank,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4_group_scaling=use_w4_group_scaling,
-        use_woq_per_channel=use_woq_per_channel,
-        use_woq_group_scaling=use_woq_group_scaling,
+        use_int8_woq_per_channel=use_int8_woq_per_channel,
         use_mxfp8_act_scaling=use_mxfp8_act_scaling,
         min_latency_mode=min_latency_mode,
         use_fused_finalize=use_fused_finalize,
@@ -267,15 +262,16 @@ def _(
     enable_alltoall: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
     use_w4_group_scaling: bool = False,
-    use_woq_per_channel: bool = False,
-    use_woq_group_scaling: bool = False,
+    use_int8_woq_per_channel: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     use_fused_finalize: bool = True,
     tune_max_num_tokens: int = 8192,
 ):
     seq_len = input.shape[0]
-    if use_woq_per_channel:
+    if use_int8_woq_per_channel:
+        # Note: The weight shape for INT8 weight only quantization is different, i.e.,
+        # fc2_expert_weights: [num_experts, inter_size, hidden_size]
         hidden_size = fc2_expert_weights.shape[2]
     else:
         hidden_size = fc2_expert_weights.shape[1]
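
For context, MoERunner keeps one FusedMoeRunner per configuration: instance_key (dtypes plus the quantization flags, now including use_int8_woq_per_channel in place of the two old woq flags) is looked up in runner_dict, and only a miss constructs a new runner. A stripped-down sketch of that caching pattern (simplified names and signature, purely illustrative, not the real class):

    # Simplified memoization sketch of the per-configuration runner cache.
    # make_runner stands in for torch.classes.trtllm.FusedMoeRunner.
    _runner_dict = {}

    def get_runner(make_runner, *config_flags):
        # config_flags plays the role of instance_key: dtypes plus flags such as
        # use_deepseek_fp8_block_scale, use_w4_group_scaling, use_int8_woq_per_channel.
        key = tuple(config_flags)
        if key not in _runner_dict:
            _runner_dict[key] = make_runner(*config_flags)
        return _runner_dict[key]

    # Identical flag tuples share one runner instance:
    r1 = get_runner(lambda *flags: object(), "bf16", "int8", True)
    r2 = get_runner(lambda *flags: object(), "bf16", "int8", True)
    assert r1 is r2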

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 12 additions & 25 deletions
@@ -16,10 +16,10 @@
 # isort: off
 from .quantization import (
     DeepSeekFP8BlockScalesFusedMoEMethod, FP8QDQFusedMoEMethod,
-    MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod, UnquantizedFusedMoEMethod, WeightOnlyFusedMoEMethod,
-
-    W4A8MXFP4FP8CutlassFusedMoEMethod, W4A8MXFP4MXFP8CutlassFusedMoEMethod,
-    WFP4A16FusedMoEMethod, WInt4AFP8FusedMoEMethod)
+    MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod, UnquantizedFusedMoEMethod,
+    INT8WoqPerChannelFusedMoEMethod, W4A8MXFP4FP8CutlassFusedMoEMethod,
+    W4A8MXFP4MXFP8CutlassFusedMoEMethod, WFP4A16FusedMoEMethod,
+    WInt4AFP8FusedMoEMethod)
 # isort: on
 from .routing import BaseMoeRoutingMethod

@@ -45,8 +45,6 @@ class CutlassFusedMoE(MoE):
         FusedMoE Op: dynamic quant + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (return one tensor)
     fp8 qdq, nvfp4:
         FusedMoE Op: scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (return one tensor)
-    weight only:
-        FusedMoE Op: ... to finish

     FusedMoE module:
         max-throughput mode:
@@ -185,15 +183,10 @@ def has_w4afp8(self):
         )

     @property
-    def has_woq_per_channel(self):
-        return self.quant_config.layer_quant_mode.is_weight_only(
+    def has_int8_woq_per_channel(self):
+        return self.quant_config.layer_quant_mode.is_int8_weight_only(
         ) and not self.quant_config.layer_quant_mode.has_per_group_scaling()

-    @property
-    def has_woq_per_group_scaling(self):
-        return self.quant_config.layer_quant_mode.is_weight_only(
-        ) and self.quant_config.layer_quant_mode.has_per_group_scaling()
-
     @cached_property
     def enable_alltoall(self):
         return (self.mapping.moe_ep_size > self.routing_method.experts_per_token
@@ -214,8 +207,8 @@ def _get_quant_method(self):
             elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
             ):
                 return WInt4AFP8FusedMoEMethod()
-            elif self.has_woq_per_channel:
-                return WeightOnlyFusedMoEMethod()
+            elif self.has_int8_woq_per_channel:
+                return INT8WoqPerChannelFusedMoEMethod()
             elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
                 return W4A8MXFP4FP8CutlassFusedMoEMethod()
             elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4():
@@ -273,8 +266,7 @@ def forward_chunk(
         # quantize inputs
         use_deepseek_fp8_block_scale = False
         use_w4_group_scaling = False
-        use_woq_per_channel = False
-        use_woq_group_scaling = False
+        use_int8_woq_per_channel = False
         use_mxfp8_act_scaling = False
         weight_dtype = self.w3_w1_weight.dtype
         x_sf = None
@@ -288,19 +280,15 @@ def forward_chunk(
             use_deepseek_fp8_block_scale = True
         elif self.has_w4afp8:
             use_w4_group_scaling = True
-            use_woq_group_scaling = True
             weight_dtype = torch.quint4x2
-        elif self.has_woq_per_channel:
-            use_woq_per_channel = True
-        elif self.has_woq_per_group_scaling:
-            use_woq_group_scaling = True
         elif self.has_w4a16_mxfp4:
             pad_size = self.hidden_size - x.shape[1]
             original_hidden_size = x.shape[1]
             x = torch.nn.functional.pad(x, (0, pad_size))
-
             use_w4_group_scaling = True
             weight_dtype = torch.uint8
+        elif self.has_int8_woq_per_channel:
+            use_int8_woq_per_channel = True
         elif self.has_nvfp4:
             if run_post_quant_allgather or self.enable_alltoall:
                 if isinstance(x, Fp4QuantizedTensor):
@@ -440,8 +428,7 @@ def forward_chunk(
             enable_alltoall=self.enable_alltoall,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4_group_scaling=use_w4_group_scaling,
-            use_woq_per_channel=use_woq_per_channel,
-            use_woq_group_scaling=use_woq_group_scaling,
+            use_int8_woq_per_channel=use_int8_woq_per_channel,
             use_mxfp8_act_scaling=use_mxfp8_act_scaling,
             min_latency_mode=False,
             use_fused_finalize=self.use_fused_finalize,