Skip to content

Commit aa849a3

Browse files
committed
Add autotune support for NVFP4 TRTLLM Gen MoE. Squash of several dev commits
Squashed development commits: add a new C++ wrapper runner and use it instead of the old entry point; switch to the new Python runner; add autotuning; ensure cache-key reuse; structure the tests so that all run with autotune by default, with one non-autotune case exercising fallback tactic selection. Signed-off-by: Dom Brown <[email protected]>
1 parent 471bf0b commit aa849a3

File tree

3 files changed

+658
-333
lines changed

3 files changed

+658
-333
lines changed

cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 67 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,9 @@ namespace torch_ext
2525
{
2626
namespace btg = batchedGemm::trtllm::gen;
2727
using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType;
28+
using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
2829

29-
std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routing_logits,
30+
std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& routing_logits,
3031
torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
3132
torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
3233
torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
@@ -35,14 +36,16 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
3536
int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
3637
std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
3738
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
38-
int64_t const routing_method_type, bool const do_finalize)
39+
int64_t const routing_method_type, bool const do_finalize, MoeRunnerType& moe_runner, int64_t const moeConfigIndex)
3940
{
4041
auto const sm = tensorrt_llm::common::getSMVersion();
4142
TORCH_CHECK(sm == 100, "Only SM100 is supported by FP4 block scale MOE");
4243
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float
4344
|| routing_logits.scalar_type() == at::ScalarType::BFloat16,
4445
"routing_logits must be float or bfloat16.");
4546
TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
47+
TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
48+
"routing_logits and hidden_states must have the same number of tokens.");
4649
TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape.");
4750
if (routing_bias.has_value())
4851
{
@@ -261,13 +264,7 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
261264
args.output2_scales_scalar = output2_scales_scalar.data_ptr<float>();
262265
args.do_finalize = do_finalize;
263266

264-
tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner moe_runner(
265-
args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim);
266-
267-
auto const moeConfigIndex = moe_runner.getDefaultValidConfigIndex(
268-
args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens);
269-
270-
auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
267+
auto const workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
271268

272269
at::Tensor workspace_fc1 = at::detail::empty_cuda(
273270
{std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
@@ -286,6 +283,63 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
286283
return {output};
287284
}
288285

286+
// Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow
287+
// use with the torch workflow autotuner class.
288+
class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
289+
{
290+
public:
291+
explicit FP4BlockScaleMoeRunner(int64_t tileTokensDim)
292+
: mTileTokensDim(tileTokensDim)
293+
{
294+
mRunner = std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, mTileTokensDim);
295+
}
296+
297+
[[nodiscard]] std::vector<torch::Tensor> run(torch::Tensor const& routing_logits,
298+
torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
299+
torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
300+
torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
301+
torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
302+
torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
303+
int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
304+
std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
305+
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
306+
int64_t const routing_method_type, bool const do_finalize, int64_t moeConfigIndex)
307+
{
308+
309+
// Autotuner has requested a default or 'fallback' config index
310+
if (moeConfigIndex == -1)
311+
{
312+
auto const num_tokens = hidden_states.sizes()[0];
313+
314+
// 2x FP4 per byte element
315+
auto const hidden_size = 2 * hidden_states.sizes()[1];
316+
317+
moeConfigIndex = mRunner->getDefaultValidConfigIndex(
318+
top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
319+
}
320+
321+
return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
322+
gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
323+
output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
324+
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, mTileTokensDim,
325+
routing_method_type, do_finalize, *mRunner, moeConfigIndex);
326+
}
327+
328+
[[nodiscard]] std::vector<int64_t> getValidConfigs(
329+
int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const
330+
{
331+
return mRunner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
332+
}
333+
334+
private:
335+
using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
336+
337+
std::unique_ptr<RunnerType> mRunner;
338+
btg::Dtype mDtypeElt{btg::Dtype::E2m1};
339+
bool mUseDeepSeekFp8{false};
340+
int64_t mTileTokensDim;
341+
};
342+
289343
torch::Tensor shuffleMatrix(torch::Tensor matrix, torch::Tensor permuteIndices)
290344
{
291345
return torch::index_select(matrix, 0, permuteIndices);
@@ -295,36 +349,10 @@ torch::Tensor shuffleMatrix(torch::Tensor matrix, torch::Tensor permuteIndices)
295349

296350
// Registers the autotunable MoE runner as a Torch custom class so Python can
// construct it (per tile size), query valid configs, and invoke the kernel.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.class_<torch_ext::FP4BlockScaleMoeRunner>("FP4BlockScaleMoERunner")
        .def(torch::init<int64_t>())
        .def("get_valid_configs", &torch_ext::FP4BlockScaleMoeRunner::getValidConfigs)
        .def("run_moe", &torch_ext::FP4BlockScaleMoeRunner::run);
}
329357

330358
// Accepts both CPU and CUDA tensors

0 commit comments

Comments
 (0)