tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (80 additions, 9 deletions)
@@ -1,3 +1,4 @@
import math
from typing import Dict, List, Optional, Union

import torch
@@ -365,6 +366,10 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
    3. moe_finalize_scale_op: finalize the scale of the output tensor.
    """

    # Class-level pools that let PyTorch memory segments allocated during
    # graph capture be reused across calls.
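    # Keyed by cache_name ('workspace_0', 'workspace_1', 'workspace_sf').
    # Graph-pool buffers must stay alive for as long as a captured graph may
    # replay, since replay reuses the device addresses recorded at capture
    # time; runtime buffers can be dropped and reallocated freely.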
    allocated_buffer_in_graph_pool: dict[str, list[torch.Tensor]] = {}
    allocated_buffer_in_runtime: dict[str, torch.Tensor] = {}

    def __init__(
        self,
        *,
@@ -410,28 +415,94 @@ def __init__(
        )

    def get_workspace(self, m_max: int, group_size: int):

        def select_buffer_with_more_elements(
                graph_buffer: Optional[torch.Tensor],
                runtime_buffer: Optional[torch.Tensor]
        ) -> tuple[Optional[torch.Tensor], bool]:
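            # Returns whichever candidate holds more elements, together with a
            # flag that is True when the chosen buffer came from the graph pool.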
            if graph_buffer is None:
                return runtime_buffer, False
            if runtime_buffer is None:
                return graph_buffer, True
            use_graph = graph_buffer.numel() >= runtime_buffer.numel()
            return (graph_buffer if use_graph else runtime_buffer, use_graph)

        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
                      cache_name: str) -> torch.Tensor:
            capture_graph = torch.cuda.is_current_stream_capturing()
            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
                numel_like = math.prod(tensor_shape)
                runtime_buffer = None
                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
                    buffer = DeepGemmFusedMoE.allocated_buffer_in_runtime[
                        cache_name]
                    if buffer.numel() >= numel_like:
                        runtime_buffer = buffer

                graph_buffer = None
                # Safely get the list of candidates; defaults to an empty list
                # if the key is missing.
                candidate_buffers = DeepGemmFusedMoE.allocated_buffer_in_graph_pool.get(
                    cache_name, [])
                for buffer in candidate_buffers:
                    # The buffer just needs to be large enough.
                    if buffer.numel() >= numel_like:
                        graph_buffer = buffer
                        break

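                # During capture, a pooled buffer must be used when available:
                # graph replay reuses the device addresses recorded at capture
                # time, so the workspace memory has to outlive the graph.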
                if capture_graph and graph_buffer is not None:
                    return graph_buffer[0:numel_like].view(tensor_shape)
                else:
                    buffer, use_graph = select_buffer_with_more_elements(
                        graph_buffer, runtime_buffer)
                    if buffer is not None:
                        if not use_graph and capture_graph:
                            # We are capturing, so promote the runtime buffer
                            # into the graph pool.
                            DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
                                cache_name, []).append(buffer)
                            del DeepGemmFusedMoE.allocated_buffer_in_runtime[
                                cache_name]

                        return buffer[0:numel_like].view(tensor_shape)

            # No cached buffer is large enough. Drop the undersized runtime
            # buffer (if any) so its memory can be released, then create a
            # fresh one below.
            if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
                del DeepGemmFusedMoE.allocated_buffer_in_runtime[cache_name]

            new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
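            # zeros rather than empty, presumably so newly allocated workspace
            # never exposes uninitialized memory (assumption, not stated in the PR).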
            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
                if capture_graph:
                    DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
                        cache_name, []).append(new_buffer)
                else:
                    DeepGemmFusedMoE.allocated_buffer_in_runtime[
                        cache_name] = new_buffer
            return new_buffer

        hidden_size = self.hidden_size
        intermediate_size = self.intermediate_size_per_partition
        num_experts = self.expert_size_per_partition

        # create workspace
        fp8_dim = max(hidden_size, intermediate_size)
-        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
-                                  dtype=torch.float8_e4m3fn,
-                                  device='cuda')
-        workspace_1 = torch.empty(
-            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+        workspace_0 = get_empty((num_experts * m_max * fp8_dim, ),
+                                dtype=torch.float8_e4m3fn,
+                                cache_name='workspace_0')
+        workspace_1 = get_empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size), ),
             dtype=torch.bfloat16,
-            device='cuda')
+            cache_name='workspace_1')

        # create workspace for scaling factors
        m_padded = fp8_utils.align(m_max, 4)
        scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
        scale_k_padded = fp8_utils.align(scale_k, 4)
-        workspace_sf = torch.empty(
-            (num_experts * (scale_k_padded // 4) * m_padded),
+        workspace_sf = get_empty(
+            (num_experts * (scale_k_padded // 4) * m_padded, ),
             dtype=torch.int32,
-            device='cuda')
+            cache_name='workspace_sf')

        workspace = {
            "workspace_0": workspace_0,
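
Note on the reuse pattern: the caching contract of get_empty reduces to a few lines. The sketch below is illustrative only; BufferPool and its method names are hypothetical, it collapses the graph-pool/runtime split into a single per-name cache, and it allocates on CPU so it runs anywhere. It shows the core rule: one flat buffer per cache name is re-viewed as long as it is large enough, and otherwise replaced by a bigger allocation.

import math

import torch


class BufferPool:
    """Hypothetical, simplified stand-in for get_empty's reuse rule."""
    _buffers: dict[str, torch.Tensor] = {}

    @classmethod
    def get(cls, name: str, shape: tuple[int, ...],
            dtype: torch.dtype) -> torch.Tensor:
        numel = math.prod(shape)
        buf = cls._buffers.get(name)
        if buf is None or buf.numel() < numel:
            # Missing or undersized: replace with a larger flat allocation.
            buf = torch.zeros(numel, dtype=dtype)
            cls._buffers[name] = buf
        # Large enough: return a view of the existing flat storage.
        return buf[:numel].view(shape)


big = BufferPool.get("workspace_0", (8, 256), torch.bfloat16)
small = BufferPool.get("workspace_0", (4, 256), torch.bfloat16)
assert small.data_ptr() == big.data_ptr()  # same storage is reused

In the PR itself the same rule is applied per pool, with the extra twist that a runtime buffer selected (or newly created) while a graph is being captured migrates into allocated_buffer_in_graph_pool so it stays alive for replay.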