diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
index 392fff09115..6f5ed94d0f4 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -1,3 +1,4 @@
+import math
 from typing import Dict, List, Optional, Union
 
 import torch
@@ -365,6 +366,10 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
     3. moe_finalize_scale_op: finalize the scale of the output tensor.
     """
 
+    # Class-level caches to reuse PyTorch memory segments allocated during graph capture.
+    allocated_buffer_in_graph_pool: dict[str, list[torch.Tensor]] = {}
+    allocated_buffer_in_runtime: dict[str, torch.Tensor] = {}
+
     def __init__(
         self,
         *,
@@ -410,28 +415,94 @@ def __init__(
         )
 
     def get_workspace(self, m_max: int, group_size: int):
+
+        def select_buffer_with_more_elements(
+                graph_buffer: Optional[torch.Tensor],
+                runtime_buffer: Optional[torch.Tensor]
+        ) -> tuple[Optional[torch.Tensor], bool]:
+            if graph_buffer is None:
+                return runtime_buffer, False
+            if runtime_buffer is None:
+                return graph_buffer, True
+            use_graph = graph_buffer.numel() >= runtime_buffer.numel()
+            return (graph_buffer if use_graph else runtime_buffer, use_graph)
+
+        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
+                      cache_name: str) -> torch.Tensor:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                numel_like = math.prod(tensor_shape)
+                runtime_buffer = None
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    buffer = DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name]
+                    numel_buffer = buffer.numel()
+                    runtime_buffer = buffer if numel_buffer >= numel_like else runtime_buffer
+
+                graph_buffer = None
+                # Safely get the list of candidates; default to an empty list if the key is missing.
+                candidate_buffers = DeepGemmFusedMoE.allocated_buffer_in_graph_pool.get(
+                    cache_name, [])
+                for buffer in candidate_buffers:
+                    numel_buffer = buffer.numel()
+                    # The buffer just needs to be large enough.
+                    if numel_buffer >= numel_like:
+                        graph_buffer = buffer
+                        break
+
+                if capture_graph and graph_buffer is not None:
+                    return graph_buffer[0:numel_like].view(tensor_shape)
+                else:
+                    buffer, use_graph = select_buffer_with_more_elements(
+                        graph_buffer, runtime_buffer)
+                    if buffer is not None:
+                        if not use_graph and capture_graph:
+                            # Move the buffer into the graph pool since it is used during graph capture.
+                            DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                                cache_name, []).append(buffer)
+                            del DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                                cache_name]
+
+                        return buffer[0:numel_like].view(tensor_shape)
+
+                # No cached buffer is large enough; a new buffer will replace the small one, so release the old memory first.
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    del DeepGemmFusedMoE.allocated_buffer_in_runtime[cache_name]
+
+            # If we get here, no suitable buffer was found in the cache. Create a new one.
+            new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                if capture_graph:
+                    DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                        cache_name, []).append(new_buffer)
+                else:
+                    DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name] = new_buffer
+            return new_buffer
+
         hidden_size = self.hidden_size
         intermediate_size = self.intermediate_size_per_partition
         num_experts = self.expert_size_per_partition
         # create workspace
         fp8_dim = max(hidden_size, intermediate_size)
-        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
-                                  dtype=torch.float8_e4m3fn,
-                                  device='cuda')
-        workspace_1 = torch.empty(
-            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+        workspace_0 = get_empty((num_experts * m_max * fp8_dim, ),
+                                dtype=torch.float8_e4m3fn,
+                                cache_name='workspace_0')
+        workspace_1 = get_empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size), ),
             dtype=torch.bfloat16,
-            device='cuda')
+            cache_name='workspace_1')
         # create workspace for scaling factors
         m_padded = fp8_utils.align(m_max, 4)
         scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
         scale_k_padded = fp8_utils.align(scale_k, 4)
-        workspace_sf = torch.empty(
-            (num_experts * (scale_k_padded // 4) * m_padded),
+
+        workspace_sf = get_empty(
+            (num_experts * (scale_k_padded // 4) * m_padded, ),
             dtype=torch.int32,
-            device='cuda')
+            cache_name='workspace_sf')
         workspace = {
             "workspace_0": workspace_0,
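The caching in get_empty boils down to: hand out a view over a cached flat buffer when one with enough elements already exists, otherwise allocate a larger buffer and remember it (in the graph pool during CUDA graph capture, in the runtime cache otherwise). The standalone sketch below, not part of the patch, illustrates that reuse pattern with a single hypothetical cache; the names _cache and cached_empty are illustrative, CPU tensors are used so it runs without a GPU, and the graph-capture bookkeeping is omitted.

import math
from typing import Dict, List

import torch

_cache: Dict[str, torch.Tensor] = {}  # hypothetical per-name cache of flat buffers


def cached_empty(shape: List[int], dtype: torch.dtype, name: str) -> torch.Tensor:
    # Reuse the cached flat buffer when it already has enough elements.
    numel = math.prod(shape)
    buf = _cache.get(name)
    if buf is None or buf.numel() < numel:
        # Missing or too small: allocate a larger flat buffer and cache it.
        buf = torch.zeros(numel, dtype=dtype)
        _cache[name] = buf
    # Hand out a view over the first `numel` elements of the cached storage.
    return buf[:numel].view(shape)


a = cached_empty([4, 8], torch.bfloat16, "workspace_1")
b = cached_empty([2, 8], torch.bfloat16, "workspace_1")
assert a.data_ptr() == b.data_ptr()  # the second call reuses the first allocation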