41 changes: 28 additions & 13 deletions cpp/tensorrt_llm/thop/moeOp.cpp
@@ -308,8 +308,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
std::vector<int64_t> output_shape = {num_rows, hidden_size};
auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));

- WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-     static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+ WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+     static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);

auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
kernels::MoeMinLatencyParams min_latency_params{};
@@ -439,8 +439,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());

- WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-     static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+ WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+     static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);

auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);

@@ -577,6 +577,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
// e.g. 16 nvfp4 elements are packed into a single int64 element
int64_t mInnerDimMultiplier;
char* mProfileWorkspace = nullptr;
+ WorkspaceInfo workspace_info;

bool mUseDeepSeekFP8BlockScaling = false;
bool mUseW4A8GroupScaling = false;
@@ -622,9 +623,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
}

- WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+ WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
      int num_experts, int experts_per_token, ActivationType activation_type,
-     kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+     kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
{
size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@@ -633,15 +634,29 @@

std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};

- size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+ int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());

- WorkspaceInfo info{};
- info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-     torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
- info.src_to_dest_map
-     = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+ bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+ // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+ if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)

Review comment (Collaborator): When workspace_info.workspace.numel() < total_workspace_size and MOE kernels are running asynchronously in different streams, is it possible that 2 kernels from different streams access the same workspace at the same time? @jinyangyuan-nvidia @nv-yilinf

Review comment (Collaborator): I think historically we have assumed that we only ever have one stream running MOE in a few places. But looking at the chunked MOE logic this definitely is a problematic assumption. It's quite possible there are a few bugs with this.

+ {
+     if (is_capturing)
+     {
+         TLLM_LOG_DEBUG(
+             "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+     }
+     else
+     {
+         TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+             workspace_info.workspace.numel(), total_workspace_size);
+     }
+     workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+         torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+ }
+ workspace_info.src_to_dest_map
+     = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);

- return info;
+ return workspace_info;
}

kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
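
Note on the new is_capturing guard in getWorkspaceInfo: during CUDA graph capture, every buffer the captured kernels touch must still be valid at replay time, so the runner always allocates a fresh workspace instead of reusing the cached member. The diff only shows the call to tensorrt_llm::common::isCapturing(stream), not its implementation; the following is a minimal sketch assuming it is roughly a wrapper over the CUDA runtime's stream-capture query (the name isCapturingSketch and this exact behavior are assumptions, not the library's code):

#include <cuda_runtime_api.h>

// Sketch only: returns true when `stream` is currently recording a CUDA graph.
inline bool isCapturingSketch(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    cudaError_t const err = cudaStreamIsCapturing(stream, &status);
    // Pointers baked into a captured graph must outlive all replays, so a workspace
    // that may later be freed or resized outside the capture cannot be reused.
    return err == cudaSuccess && status == cudaStreamCaptureStatusActive;
}

This matches the comment in the diff ("avoid illegal memory access during replay"): a workspace allocated outside the capture could later be freed or grown, leaving the replayed graph pointing at stale memory.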
Expand Down
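To make the concern in the review thread concrete: workspace_info is now a single member of FusedMoeRunner, so two MOE launches on different streams would share one cached buffer, and a resize triggered by one stream releases memory that kernels already enqueued on another stream may still be reading. One hypothetical mitigation, not part of this PR and with invented names, is to key the cache by stream:

#include <cstdint>
#include <unordered_map>
#include <cuda_runtime_api.h>
#include <torch/torch.h>

// Hypothetical per-stream workspace cache; illustration only, not TensorRT-LLM code.
class PerStreamWorkspaceCache
{
public:
    // Returns a buffer of at least `required_bytes` that is private to `stream`,
    // growing it when needed. Same-stream reuse stays ordered by the stream itself.
    torch::Tensor& get(cudaStream_t stream, int64_t required_bytes)
    {
        torch::Tensor& ws = mCache[stream];
        if (!ws.defined() || ws.numel() < required_bytes)
        {
            ws = torch::empty({required_bytes},
                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
        }
        return ws;
    }

private:
    std::unordered_map<cudaStream_t, torch::Tensor> mCache;
};

Even with such a cache, graph capture would still need the always-allocate branch from this PR, and concurrent host threads would need locking around the map; the sketch only isolates the cross-stream reuse hazard raised in the question above.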