@@ -25,8 +25,9 @@ namespace torch_ext
25
25
{
26
26
namespace btg = batchedGemm::trtllm::gen;
27
27
using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType;
28
+ using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
28
29
29
- std::vector<torch::Tensor> fp4_block_scale_moe_runner (torch::Tensor const & routing_logits,
30
+ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner (torch::Tensor const & routing_logits,
30
31
torch::optional<torch::Tensor> const & routing_bias, torch::Tensor const & hidden_states,
31
32
torch::Tensor const & hidden_states_scale, torch::Tensor const & gemm1_weights,
32
33
torch::Tensor const & gemm1_weights_scale, torch::Tensor const & gemm2_weights,
@@ -35,14 +36,16 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
35
36
int64_t const num_experts, int64_t const top_k, std::optional<int64_t > const n_group,
36
37
std::optional<int64_t > const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
37
38
int64_t const local_num_experts, std::optional<double > const routed_scaling_factor, int64_t const tile_tokens_dim,
38
- int64_t const routing_method_type, bool const do_finalize)
39
+ int64_t const routing_method_type, bool const do_finalize, MoeRunnerType& moe_runner, int64_t const moeConfigIndex )
39
40
{
40
41
auto const sm = tensorrt_llm::common::getSMVersion ();
41
42
TORCH_CHECK (sm == 100 , " Only SM100 is supported by FP4 block scale MOE" );
42
43
TORCH_CHECK (routing_logits.scalar_type () == at::ScalarType::Float
43
44
|| routing_logits.scalar_type () == at::ScalarType::BFloat16,
44
45
" routing_logits must be float or bfloat16." );
45
46
TORCH_CHECK (routing_logits.dim () == 2 , " routing_logits must be 2D." );
47
+ TORCH_CHECK (routing_logits.sizes ()[0 ] == hidden_states.sizes ()[0 ],
48
+ " routing_logits and hidden_states must have the same number of tokens." );
46
49
TORCH_CHECK (routing_logits.sizes ()[1 ] == num_experts, " routing_logits has incorrect shape." );
47
50
if (routing_bias.has_value ())
48
51
{
@@ -261,13 +264,7 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
261
264
args.output2_scales_scalar = output2_scales_scalar.data_ptr <float >();
262
265
args.do_finalize = do_finalize;
263
266
264
- tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner moe_runner (
265
- args.mDtypeElt , args.mUseDeepSeekFp8 , tile_tokens_dim);
266
-
267
- auto const moeConfigIndex = moe_runner.getDefaultValidConfigIndex (
268
- args.top_k , args.hidden_size , args.intermediate_size , args.local_num_experts , args.num_tokens );
269
-
270
- auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes (args, moeConfigIndex);
267
+ auto const workspace_sizes = moe_runner.getWorkspaceSizeInBytes (args, moeConfigIndex);
271
268
272
269
at::Tensor workspace_fc1 = at::detail::empty_cuda (
273
270
{std::get<0 >(workspace_sizes)}, at::ScalarType::Char, hidden_states.device (), std::nullopt);
@@ -286,6 +283,63 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
286
283
return {output};
287
284
}
288
285
286
+ // Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow
287
+ // use with the torch workflow autotuner class.
288
+ class FP4BlockScaleMoeRunner : public torch ::CustomClassHolder
289
+ {
290
+ public:
291
+ explicit FP4BlockScaleMoeRunner (int64_t tileTokensDim)
292
+ : mTileTokensDim(tileTokensDim)
293
+ {
294
+ mRunner = std::make_unique<RunnerType>(mDtypeElt , mUseDeepSeekFp8 , mTileTokensDim );
295
+ }
296
+
297
+ [[nodiscard]] std::vector<torch::Tensor> run (torch::Tensor const & routing_logits,
298
+ torch::optional<torch::Tensor> const & routing_bias, torch::Tensor const & hidden_states,
299
+ torch::Tensor const & hidden_states_scale, torch::Tensor const & gemm1_weights,
300
+ torch::Tensor const & gemm1_weights_scale, torch::Tensor const & gemm2_weights,
301
+ torch::Tensor const & gemm2_weights_scale, torch::Tensor const & output1_scales_scalar,
302
+ torch::Tensor const & output1_scales_gate_scalar, torch::Tensor const & output2_scales_scalar,
303
+ int64_t const num_experts, int64_t const top_k, std::optional<int64_t > const n_group,
304
+ std::optional<int64_t > const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
305
+ int64_t const local_num_experts, std::optional<double > const routed_scaling_factor,
306
+ int64_t const routing_method_type, bool const do_finalize, int64_t moeConfigIndex)
307
+ {
308
+
309
+ // Autotuner has requested a default or 'fallback' config index
310
+ if (moeConfigIndex == -1 )
311
+ {
312
+ auto const num_tokens = hidden_states.sizes ()[0 ];
313
+
314
+ // 2x FP4 per byte element
315
+ auto const hidden_size = 2 * hidden_states.sizes ()[1 ];
316
+
317
+ moeConfigIndex = mRunner ->getDefaultValidConfigIndex (
318
+ top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
319
+ }
320
+
321
+ return run_fp4_block_scale_moe_runner (routing_logits, routing_bias, hidden_states, hidden_states_scale,
322
+ gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
323
+ output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
324
+ intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, mTileTokensDim ,
325
+ routing_method_type, do_finalize, *mRunner , moeConfigIndex);
326
+ }
327
+
328
+ [[nodiscard]] std::vector<int64_t > getValidConfigs (
329
+ int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const
330
+ {
331
+ return mRunner ->getValidConfigIndices (topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
332
+ }
333
+
334
+ private:
335
+ using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
336
+
337
+ std::unique_ptr<RunnerType> mRunner ;
338
+ btg::Dtype mDtypeElt {btg::Dtype::E2m1};
339
+ bool mUseDeepSeekFp8 {false };
340
+ int64_t mTileTokensDim ;
341
+ };
342
+
289
343
torch::Tensor shuffleMatrix (torch::Tensor matrix, torch::Tensor permuteIndices)
290
344
{
291
345
return torch::index_select (matrix, 0 , permuteIndices);
@@ -295,36 +349,10 @@ torch::Tensor shuffleMatrix(torch::Tensor matrix, torch::Tensor permuteIndices)
295
349
296
350
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    // Register the MoE runner as a Torch custom class instead of a free
    // operator so the autotuner can hold a runner instance across calls,
    // query its valid configs, and invoke it with a chosen config index.
    m.class_<torch_ext::FP4BlockScaleMoeRunner>("FP4BlockScaleMoERunner")
        .def(torch::init<int64_t>())
        .def("get_valid_configs", &torch_ext::FP4BlockScaleMoeRunner::getValidConfigs)
        .def("run_moe", &torch_ext::FP4BlockScaleMoeRunner::run);
}
329
357
330
358
// Accepts both CPU and CUDA tensors
0 commit comments