Commit 2fa5a9f

Allocate MoE workspace only when necessary
Signed-off-by: Yilin Fan <[email protected]>
1 parent 9e02f6b commit 2fa5a9f

File tree

1 file changed (+28, −13 lines)


cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 28 additions & 13 deletions
@@ -308,8 +308,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         std::vector<int64_t> output_shape = {num_rows, hidden_size};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
         kernels::MoeMinLatencyParams min_latency_params{};
@@ -439,8 +439,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
 
@@ -577,6 +577,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4A8GroupScaling = false;
@@ -622,9 +623,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
             mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
         }
 
-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
-        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
     {
         size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
             experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@@ -633,15 +634,29 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
 
-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
 
-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+        bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+        // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+        if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
+        {
+            if (is_capturing)
+            {
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+            }
+            else
+            {
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                    workspace_info.workspace.numel(), total_workspace_size);
+            }
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
 
-        return info;
+        return workspace_info;
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
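
In effect, the change turns the per-call allocation into a grow-only workspace cache held on the runner: the buffer is reused across forward calls, reallocated only when the required size exceeds what is cached, and always reallocated while a CUDA graph is being captured so that the capture records an allocation the replay can legally touch. The standalone sketch below illustrates that pattern; it is not the FusedMoeRunner code. WorkspaceCache and get are illustrative names, and cudaStreamIsCapturing is called directly where the commit uses tensorrt_llm::common::isCapturing (assumed here to be a thin wrapper around it).

// Hypothetical, self-contained sketch of the caching pattern in this commit.
// Requires LibTorch and the CUDA runtime.
#include <cuda_runtime_api.h>
#include <torch/torch.h>

struct WorkspaceCache
{
    torch::Tensor workspace; // persists across calls; only ever grows

    // Return a CUDA byte buffer of at least `required_bytes`.
    torch::Tensor const& get(int64_t required_bytes, cudaStream_t stream)
    {
        // Detect whether `stream` is currently capturing a CUDA graph.
        cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
        cudaStreamIsCapturing(stream, &status);
        bool const is_capturing = (status == cudaStreamCaptureStatusActive);

        // During capture, always allocate so the recorded graph never replays
        // against memory that may later be freed or resized; otherwise reuse
        // the cached tensor unless it is too small.
        if (is_capturing || !workspace.defined() || workspace.numel() < required_bytes)
        {
            workspace = torch::empty({required_bytes},
                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
        }
        return workspace;
    }
};

Returning a reference to the cached tensor, rather than a fresh object by value as before, is what lets callers share the persistent buffer without copying it, which is why getWorkspaceInfo now returns WorkspaceInfo const&.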
