@@ -308,8 +308,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         std::vector<int64_t> output_shape = {num_rows, hidden_size};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
         kernels::MoeMinLatencyParams min_latency_params{};
@@ -439,8 +439,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
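
Both call sites above now pass the CUDA stream into getWorkspaceInfo so the allocator can detect whether that stream is currently being captured into a CUDA graph. The diff does not show the implementation of tensorrt_llm::common::isCapturing; a minimal sketch, assuming it simply wraps the CUDA runtime query, might look like this:

#include <cuda_runtime_api.h>

// Sketch only: reports whether `stream` is currently recording work into a
// CUDA graph. The real tensorrt_llm::common::isCapturing may differ, e.g. in
// how it handles errors from the runtime call.
inline bool isCapturing(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    cudaError_t const err = cudaStreamIsCapturing(stream, &status);
    return err == cudaSuccess && status == cudaStreamCaptureStatusActive;
}
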
@@ -577,6 +577,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4A8GroupScaling = false;
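
Note that workspace_info becomes a member of FusedMoeRunner rather than a local in getWorkspaceInfo, so the workspace tensor persists across forward calls and the getter can hand back a const reference instead of constructing a fresh WorkspaceInfo, and a fresh CUDA allocation, on every invocation.
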
@@ -622,9 +623,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }
 
-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
-        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
     {
         size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
             experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@@ -633,15 +634,29 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
 
-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
 
-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+        bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+        // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+        if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
+        {
+            if (is_capturing)
+            {
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+            }
+            else
+            {
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                    workspace_info.workspace.numel(), total_workspace_size);
+            }
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
 
-        return info;
+        return workspace_info;
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
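
Taken together, the new logic is a grow-only workspace cache with a CUDA-graph override: reuse the cached buffer when it is already large enough, grow it otherwise, and always allocate afresh while a graph is being captured. A condensed standalone sketch of the pattern, using illustrative names (CachedWorkspace, required_bytes) that are not part of the actual code:

#include <torch/torch.h>

// Sketch of the grow-only caching pattern, assuming a torch-managed buffer.
struct CachedWorkspace
{
    torch::Tensor buffer; // member state, so it outlives a single forward call

    torch::Tensor& get(int64_t required_bytes, bool is_capturing)
    {
        // During graph capture, reusing the cached buffer would bake a pointer
        // into the graph that a later, larger reallocation could free. With
        // PyTorch's caching allocator, a capture-time allocation comes from the
        // graph's private memory pool and so stays valid across replays.
        if (is_capturing || buffer.numel() < required_bytes)
        {
            buffer = torch::empty(
                {required_bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA));
        }
        return buffer;
    }
};

A reallocation between capture and replay would otherwise hand back memory the captured kernels still reference, which is exactly the illegal-memory-access failure the diff's comment calls out.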