diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py new file mode 100644 index 000000000..12bfefdd1 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py @@ -0,0 +1,228 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _offload_gpu_kv_to_cpu( + token_indexes_ptr, + gpu_kv_cache_ptr, + gpu_stride0, + gpu_stride1, + gpu_stride2, + gpu_stride3, + cpu_kv_cache_ptr, + cpu_stride0, + cpu_stride1, + cpu_stride2, + cpu_stride3, + cpu_stride4, + page_indexes_ptr, + page_readies_ptr, + layer_num, + head_all_dim, + BLOCK_HEAD_ALL_DIM: tl.constexpr, + TOKEN_BLOCK: tl.constexpr, +): + block_index = tl.program_id(0) + cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64) + if cpu_page_index == -1: + return + + ready_state = tl.load(page_readies_ptr + block_index) + if ready_state: + return + + token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK) + token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64) + head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM) + + gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64) + + for layer_index in range(layer_num): + gpu_ptr = ( + gpu_kv_cache_ptr + + layer_index * gpu_stride0 + + token_indexes[:, None] * gpu_stride1 + + head_all_dim_range[None, :] + ) + gpu_data = tl.load(gpu_ptr, mask=(head_all_dim_range[None, :] < head_all_dim), other=0.0) + cpu_ptr = ( + cpu_kv_cache_ptr + + cpu_page_index * cpu_stride0 + + layer_index * cpu_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2 + + head_all_dim_range[None, :] + ) + tl.store( + cpu_ptr, + gpu_data, + mask=(head_all_dim_range[None, :] < head_all_dim), + ) + return + + + +@torch.no_grad() +def offload_gpu_kv_to_cpu( + token_indexes: torch.Tensor, + gpu_kv_cache: torch.Tensor, + cpu_kv_cache: torch.Tensor, + page_indexes: torch.Tensor, + page_readies: torch.Tensor, +): + """ + this function is used to offload GPU KV cache to CPU KV cache. 
+    Args:
+        token_indexes: (token_num,)
+        gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
+        cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim)
+        page_indexes: (page_num,)
+        page_readies: (page_num,)
+    """
+    token_block_size = cpu_kv_cache.shape[2]
+    token_num = page_indexes.shape[0] * token_block_size
+    assert token_indexes.shape[0] >= token_num
+    assert page_indexes.shape == page_readies.shape
+    page_num = page_indexes.shape[0]
+    head_all_dim = gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2]
+    BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])
+
+    grid = (page_num,)
+    num_warps = 4
+
+    _offload_gpu_kv_to_cpu[grid](
+        token_indexes_ptr=token_indexes,
+        gpu_kv_cache_ptr=gpu_kv_cache,
+        gpu_stride0=gpu_kv_cache.stride(0),
+        gpu_stride1=gpu_kv_cache.stride(1),
+        gpu_stride2=gpu_kv_cache.stride(2),
+        gpu_stride3=gpu_kv_cache.stride(3),
+        cpu_kv_cache_ptr=cpu_kv_cache,
+        cpu_stride0=cpu_kv_cache.stride(0),
+        cpu_stride1=cpu_kv_cache.stride(1),
+        cpu_stride2=cpu_kv_cache.stride(2),
+        cpu_stride3=cpu_kv_cache.stride(3),
+        cpu_stride4=cpu_kv_cache.stride(4),
+        page_indexes_ptr=page_indexes,
+        page_readies_ptr=page_readies,
+        layer_num=gpu_kv_cache.shape[0],
+        head_all_dim=head_all_dim,
+        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
+        TOKEN_BLOCK=token_block_size,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
+
+
+@triton.jit
+def _load_cpu_cache_to_gpu(
+    token_indexes_ptr,
+    gpu_kv_cache_ptr,
+    gpu_stride0,
+    gpu_stride1,
+    gpu_stride2,
+    gpu_stride3,
+    cpu_kv_cache_ptr,
+    cpu_stride0,
+    cpu_stride1,
+    cpu_stride2,
+    cpu_stride3,
+    cpu_stride4,
+    page_indexes_ptr,
+    layer_num,
+    head_all_dim,
+    all_move_token_num,
+    BLOCK_HEAD_ALL_DIM: tl.constexpr,
+    TOKEN_BLOCK: tl.constexpr,
+):
+    block_index = tl.program_id(0)
+    cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
+    if cpu_page_index == -1:
+        return
+
+    gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64)
+    padded_size = TOKEN_BLOCK * tl.num_programs(0) - all_move_token_num
+    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
+    token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
+    token_range = token_range - padded_size
+
+    token_mask = token_range >= 0
+    head_dim_mask = head_all_dim_range < head_all_dim
+
+    token_indexes = tl.load(token_indexes_ptr + token_range, mask=token_mask, other=0).to(tl.int64)
+
+    cpu_page_index = tl.load(page_indexes_ptr + block_index)
+    for layer_index in range(layer_num):
+        cpu_ptr = (
+            cpu_kv_cache_ptr
+            + cpu_page_index * cpu_stride0
+            + layer_index * cpu_stride1
+            + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
+            + head_all_dim_range[None, :]
+        )
+        cpu_data = tl.load(cpu_ptr, mask=head_dim_mask[None, :], other=0.0)
+
+        gpu_ptr = (
+            gpu_kv_cache_ptr
+            + layer_index * gpu_stride0
+            + token_indexes[:, None] * gpu_stride1
+            + head_all_dim_range[None, :]
+        )
+        tl.store(
+            gpu_ptr,
+            cpu_data,
+            mask=token_mask[:, None] & head_dim_mask[None, :],
+        )
+    return
+
+
+@torch.no_grad()
+def load_cpu_kv_to_gpu(
+    mem_indexes: torch.Tensor,
+    gpu_kv_cache: torch.Tensor,
+    cpu_kv_cache: torch.Tensor,
+    page_indexes: torch.Tensor,
+):
+    """
+    this function is used to load CPU KV cache pages back into the GPU KV cache.
+    Args:
+        mem_indexes: (token_num,)
+        gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
+        cpu_kv_cache: (page_num, layer_num, token_block_size, head_num, head_dim)
+        page_indexes: (page_num,)
+    """
+    token_block_size = cpu_kv_cache.shape[2]
+    token_num = page_indexes.shape[0] * token_block_size
+    assert mem_indexes.shape[0] >= token_num
+    page_num = page_indexes.shape[0]
+    BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])
+
+    grid = (page_num,)
+    num_warps = 1
+
+    _load_cpu_cache_to_gpu[grid](
+        token_indexes_ptr=mem_indexes,
+        gpu_kv_cache_ptr=gpu_kv_cache,
+        gpu_stride0=gpu_kv_cache.stride(0),
+        gpu_stride1=gpu_kv_cache.stride(1),
+        gpu_stride2=gpu_kv_cache.stride(2),
+        gpu_stride3=gpu_kv_cache.stride(3),
+        cpu_kv_cache_ptr=cpu_kv_cache,
+        cpu_stride0=cpu_kv_cache.stride(0),
+        cpu_stride1=cpu_kv_cache.stride(1),
+        cpu_stride2=cpu_kv_cache.stride(2),
+        cpu_stride3=cpu_kv_cache.stride(3),
+        cpu_stride4=cpu_kv_cache.stride(4),
+        page_indexes_ptr=page_indexes,
+        layer_num=gpu_kv_cache.shape[0],
+        head_all_dim=gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2],
+        all_move_token_num=len(mem_indexes),
+        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
+        TOKEN_BLOCK=token_block_size,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index dfbf0c84b..84a2a6233 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -477,4 +477,25 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=0.03,
         help="""The interval of the schedule time, default is 30ms.""",
     )
+    parser.add_argument(
+        "--enable_cpu_cache",
+        action="store_true",
+        help="""enable cpu cache to store kv cache.""",
+    )
+    parser.add_argument(
+        "--cpu_cache_storage_size",
+        type=float,
+        default=2,
+        help="""The capacity of cpu cache. GB used.""",
+    )
+    parser.add_argument(
+        "--cpu_cache_token_page_size",
+        type=int,
+        default=256,
+        help="""The token page size of cpu cache""",
+    )
+    parser.add_argument("--enable_disk_cache", action="store_true", help="""enable disk cache to store kv cache.""")
+    parser.add_argument(
+        "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. 
GB used.""" + ) return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 3835994e5..11efd35a1 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -38,6 +38,7 @@ from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect from fastapi.responses import Response, StreamingResponse, JSONResponse from lightllm.server.core.objs.sampling_params import SamplingParams +from lightllm.server.core.objs import StartArgs from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster @@ -71,7 +72,7 @@ class G_Objs: httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None shared_token_load: TokenLoad = None - def set_args(self, args): + def set_args(self, args: StartArgs): self.args = args from .api_lightllm import lightllm_generate, lightllm_generate_stream from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl @@ -86,22 +87,13 @@ def set_args(self, args): if args.run_mode == "pd_master": self.metric_client = MetricClient(args.metric_port) self.httpserver_manager = HttpServerManagerForPDMaster( - args, - metric_port=args.metric_port, + args=args, ) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) self.metric_client = MetricClient(args.metric_port) - self.httpserver_manager = HttpServerManager( - args, - router_port=args.router_port, - cache_port=args.cache_port, - detokenization_pub_port=args.detokenization_pub_port, - visual_port=args.visual_port, - enable_multimodal=args.enable_multimodal, - metric_port=args.metric_port, - ) + self.httpserver_manager = HttpServerManager(args=args) dp_size_in_node = max(1, args.dp // args.nnodes) # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index c2a87b4c3..894cc28b7 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -70,6 +70,10 @@ def normal_or_p_d_start(args): if args.run_mode not in ["normal", "prefill", "decode"]: return + if args.enable_cpu_cache: + # 生成一个用于创建cpu kv cache的共享内存id。 + args.cpu_kv_cache_shm_id = uuid.uuid1().int % 123456789 + assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] # 确保单机上多实列不冲突 if args.zmq_mode == "ipc:///tmp/": @@ -213,19 +217,20 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=8 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( router_port, detokenization_port, - detokenization_pub_port, + http_server_port, visual_port, audio_port, cache_port, metric_port, - ) = can_use_ports[0:7] - can_use_ports = can_use_ports[7:] + multi_level_kv_cache_port, + ) = can_use_ports[0:8] + can_use_ports = can_use_ports[8:] visual_model_tp_ports = [] for _ in range(args.visual_dp): @@ -236,11 +241,12 @@ def normal_or_p_d_start(args): # 将申请好的端口放入args参数中 args.router_port = router_port args.detokenization_port = detokenization_port - args.detokenization_pub_port = detokenization_pub_port + args.http_server_port = http_server_port args.visual_port = visual_port args.audio_port = audio_port args.cache_port = cache_port args.metric_port = 
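The new offload/load wrappers in kv_cache_offload.py above copy whole token pages between the GPU layout (layer, token, head, dim) and the paged CPU layout (page, layer, token_in_page, head, dim). The following is a minimal shape-level sanity sketch, not production usage: the real server registers a pinned shared-memory segment as the CPU cache, so a CUDA tensor stands in for it here, and all sizes, page indexes and dtypes are arbitrary choices for illustration.

```python
# Toy round-trip through the new kernels; requires a CUDA device and Triton.
# The "cpu" cache below is a CUDA tensor standing in for the pinned
# shared-memory buffer that the server actually registers.
import torch
from lightllm.common.basemodel.triton_kernel.kv_cache_offload import (
    offload_gpu_kv_to_cpu,
    load_cpu_kv_to_gpu,
)

layer_num, total_tokens, head_num, head_dim = 2, 64, 4, 16
page_size = 16  # stands in for --cpu_cache_token_page_size

gpu_kv = torch.randn(layer_num, total_tokens, head_num, head_dim, device="cuda", dtype=torch.float16)
cpu_kv = torch.zeros(8, layer_num, page_size, head_num, head_dim, device="cuda", dtype=torch.float16)

token_indexes = torch.arange(32, device="cuda", dtype=torch.int32)     # 2 pages worth of gpu slots
page_indexes = torch.tensor([3, 5], device="cuda", dtype=torch.int32)  # destination cpu pages
page_readies = torch.zeros(2, device="cuda", dtype=torch.bool)         # False -> pages still need filling

offload_gpu_kv_to_cpu(token_indexes, gpu_kv, cpu_kv, page_indexes, page_readies)
torch.cuda.synchronize()
# page 3 now holds the first 16 tokens of every layer, page 5 the next 16
assert torch.equal(cpu_kv[3], gpu_kv[:, 0:16])
assert torch.equal(cpu_kv[5], gpu_kv[:, 16:32])

# load the two pages back into a fresh range of gpu slots; mem_indexes are
# right-aligned against the page grid inside the kernel
mem_indexes = torch.arange(32, 64, device="cuda", dtype=torch.int32)
load_cpu_kv_to_gpu(mem_indexes, gpu_kv, cpu_kv, page_indexes)
torch.cuda.synchronize()
assert torch.equal(gpu_kv[:, 32:48], gpu_kv[:, 0:16])
```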
metric_port + args.multi_level_kv_cache_port = multi_level_kv_cache_port # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] @@ -267,50 +273,51 @@ def normal_or_p_d_start(args): start_funcs=[ start_cache_manager, ], - start_args=[(cache_port, args)], + start_args=[(args,)], ) + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, visual_model_tp_ports), + ], + ) + if args.enable_multimodal_audio: from .audioserver.manager import start_audio_process - process_manager.start_submodule_processes( - start_funcs=[ - start_visual_process, - ], - start_args=[ - (args, audio_port, visual_port, cache_port, visual_model_tp_ports), - ], - ) process_manager.start_submodule_processes( start_funcs=[ start_audio_process, ], start_args=[ - (args, router_port, audio_port, cache_port), + (args,), ], ) - else: - process_manager.start_submodule_processes( - start_funcs=[ - start_visual_process, - ], - start_args=[ - (args, router_port, visual_port, cache_port, visual_model_tp_ports), - ], - ) + if args.enable_cpu_cache: + from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager + + process_manager.start_submodule_processes( + start_funcs=[ + start_multi_level_kv_cache_manager, + ], + start_args=[(args,)], + ) process_manager.start_submodule_processes( start_funcs=[ start_metric_manager, ], - start_args=[(metric_port, args)], + start_args=[(args,)], ) process_manager.start_submodule_processes( start_funcs=[start_router_process, start_detokenization_process], start_args=[ - (args, router_port, detokenization_port, metric_port), - (args, detokenization_port, detokenization_pub_port), + (args,), + (args,), ], ) @@ -380,7 +387,7 @@ def pd_master_start(args): start_funcs=[ start_metric_manager, ], - start_args=[(metric_port, args)], + start_args=[(args,)], ) command = [ diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py index 707fd11d0..5a90b6b84 100644 --- a/lightllm/server/audioserver/manager.py +++ b/lightllm/server/audioserver/manager.py @@ -10,7 +10,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.utils.log_utils import init_logger from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes -from lightllm.server.core.objs.shm_req_manager import ShmReqManager +from lightllm.server.core.objs.shm_req_manager import ShmReqManager, StartArgs from lightllm.server.multimodal_params import AudioItem from .model_infer.model_rpc import start_model_process, AudioModelRpcClient from lightllm.utils.graceful_utils import graceful_registry @@ -22,20 +22,22 @@ class AudioManager: def __init__( self, - args, - router_port, - audio_port, - cache_port, + args: StartArgs, infer_batch_size=4, ): context = zmq.asyncio.Context(2) - self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") - self.recv_from_visualserver = context.socket(zmq.PULL) - self.recv_from_visualserver.bind(f"{args.zmq_mode}127.0.0.1:{audio_port}") - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.cache_port = cache_port + if args.enable_cpu_cache: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") + else: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") + + self.zmq_recv_socket 
= context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) + self.cache_port = args.cache_port self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp @@ -90,7 +92,7 @@ async def loop_for_fwd(self): # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 # 需要一些一致的流程来保证不出现异步问题。 - self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) continue multimodal_params = group_req_indexes.multimodal_params @@ -106,24 +108,26 @@ async def loop_for_fwd(self): await self.infer_audios(audios_need_infer) audios_need_infer = [] for _group_req_indexes in processing_group_reqs: - self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_next_module.send_pyobj( + _group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL + ) processing_group_reqs = [] if len(audios_need_infer) == 0: - self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) else: processing_group_reqs.append(group_req_indexes) if len(audios_need_infer) > 0: await self.infer_audios(audios_need_infer) for _group_req_indexes in processing_group_reqs: - self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) processing_group_reqs = [] audios_need_infer = [] async def loop_for_netio_req(self): while True: - recv_req: GroupReqIndexes = await self.recv_from_visualserver.recv_pyobj() + recv_req: GroupReqIndexes = await self.zmq_recv_socket.recv_pyobj() if isinstance(recv_req, GroupReqIndexes): self.waiting_reqs.append(recv_req) else: @@ -137,12 +141,12 @@ def clean_up(self): return -def start_audio_process(args, router_port, audio_port, cache_port, pipe_writer): +def start_audio_process(args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) try: - audioserver = AudioManager(args, router_port, audio_port, cache_port) + audioserver = AudioManager(args=args) asyncio.run(audioserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index 5594df6a0..165e565aa 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -3,3 +3,4 @@ from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray from .start_args_type import StartArgs +from .atomic_lock import AtomicShmLock diff --git a/lightllm/server/core/objs/atomic_lock.py b/lightllm/server/core/objs/atomic_lock.py index d324f8c91..8d625cc12 100644 --- a/lightllm/server/core/objs/atomic_lock.py +++ b/lightllm/server/core/objs/atomic_lock.py @@ -1,4 +1,5 @@ import atomics +import time from multiprocessing import shared_memory from lightllm.utils.log_utils import init_logger @@ -40,3 +41,17 @@ def __exit__(self, exc_type, exc_val, exc_tb): while not a.cmpxchg_weak(1, 0): pass return False + + # acquire_sleep1ms 和 release 是某些特定场景下主动使用进行锁获取的操作函数 + def acquire_sleep1ms(self): + with atomics.atomicview(buffer=self.shm.buf, atype=atomics.INT) as a: + while not a.cmpxchg_weak(0, 
1): + logger.warning(f"acquire_sleep1ms wait for 1ms") + time.sleep(0.001) + pass + + def release(self): + with atomics.atomicview(buffer=self.shm.buf, atype=atomics.INT) as a: + while not a.cmpxchg_weak(1, 0): + pass + return diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index e4b345063..f07da9704 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -5,9 +5,11 @@ from .sampling_params import SamplingParams from .out_token_circlequeue import CircularQueue from .shm_array import ShmArray +from .token_chunck_hash_list import TokenHashList, CpuCachePageList from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.kv_cache_utils import compute_token_list_hash from typing import List, Any, Union @@ -71,7 +73,8 @@ class Req(ctypes.Structure): # 虽然某种程度上 cur_output_len 也有同样的功能,但是为了避免多进程访问导致的问题,添加 # candetoken_out_len 变量单独传输这个信息。 ("candetoken_out_len", ctypes.c_int), - ("prompt_cache_len", ctypes.c_int), # 用于记录prompt cache 的命中长度,用于统计 + ("prompt_cache_len", ctypes.c_int), # 用于记录prompt cache 的命中长度,用于统计,这里指gpu kv cache命中长度 + ("cpu_prompt_cache_len", ctypes.c_int), # 用于记录在 enable_cpu_cache 的场景下,命中的 cpu kv cache 的长度 ("is_paused", ctypes.c_bool), # 标记一个Req因为显存资源管理的原因被临时暂停了。 ("finish_status", FinishStatus), ("is_aborted", ctypes.c_bool), @@ -97,6 +100,12 @@ class Req(ctypes.Structure): ("mtp_accepted_token_num", ctypes.c_int), # mtp_step 保存一个mtp使用的常量参数,用于快速访问,不会被外部输入初始化 ("_mtp_step", ctypes.c_int), + # 用于在开启cpu cache 或者 硬盘 cache时,预先计算,分块输入token的hash值。 + ("token_hash_list", TokenHashList), + # 用于保存查找匹配到的可以被复用的cpu cache 页面信息。 + ("cpu_cache_match_page_indexes", CpuCachePageList), + # 分块hash的块大小 + ("cpu_cache_token_page_size", ctypes.c_int), ] def get_str(self): @@ -130,6 +139,7 @@ def init( self.shm_cur_output_len = 0 self.candetoken_out_len = 0 self.prompt_cache_len = 0 + self.cpu_prompt_cache_len = 0 self.finish_token_index = -1 self.can_released_mark = False self.reward_score = math.nan @@ -153,10 +163,23 @@ def init( self.post_init() + self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size + if get_env_start_args().enable_cpu_cache: + self._fill_input_token_hash() + return + def post_init(self): # 子类继承进行一些额外的初始化操作 pass + def _fill_input_token_hash(self): + self.token_hash_list = TokenHashList() + self.token_hash_list.clear() + hash_values = compute_token_list_hash(self.get_prompt_ids(),self.cpu_cache_token_page_size) + self.token_hash_list.fill(hash_values) + self.cpu_cache_match_page_indexes = CpuCachePageList() + return + def create_prompt_ids_shm_array(self): service_uni_name = get_unique_server_name() name = f"{service_uni_name}_shm_prompts_{self.index_in_shm_mem}" diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py index f88eac26e..2d363ada9 100644 --- a/lightllm/server/core/objs/shm_array.py +++ b/lightllm/server/core/objs/shm_array.py @@ -10,7 +10,7 @@ def __init__(self, name, shape, dtype): self.shm = None self.arr = None self.name = name - self.dtype_byte_num = np.array([1], dtype=dtype).dtype.itemsize + self.dtype_byte_num = np.dtype(dtype=dtype).itemsize self.dest_size = np.prod(shape) * self.dtype_byte_num self.shape = shape self.dtype = dtype diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d4a205a15..8c1ff8a37 100644 --- 
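Req._fill_input_token_hash above defers the actual hashing to compute_token_list_hash from lightllm.utils.kv_cache_utils, which is not part of this diff. Purely as an illustration of the idea (page-aligned chunks whose hash is chained to the previous chunk, so a hit implies the whole prefix matches), one plausible scheme might look like the sketch below; the function name, hash choice and truncation to 64 bits are assumptions, not the project's implementation.

```python
# Hypothetical sketch of page-aligned prefix hashing; the real
# compute_token_list_hash may differ in every detail.
import hashlib
from typing import List


def chunked_prefix_hashes(token_ids: List[int], page_size: int) -> List[int]:
    """Hash each full page of token_ids, chaining in the previous digest so a
    page hash identifies the entire prefix that ends at that page."""
    hashes: List[int] = []
    prev = b""
    for start in range(0, len(token_ids) - page_size + 1, page_size):
        page = token_ids[start : start + page_size]
        digest = hashlib.sha256(prev + ",".join(map(str, page)).encode()).digest()
        # keep 64 bits so the value fits the ctypes.c_uint64 slots of TokenHashList
        hashes.append(int.from_bytes(digest[:8], "little"))
        prev = digest
    return hashes


# a 600-token prompt with page_size=256 yields 2 hashes; the 88-token tail is not hashed
assert len(chunked_prefix_hashes(list(range(600)), 256)) == 2
```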
a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -96,3 +96,21 @@ class StartArgs: mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) kv_quant_calibration_config_path: Optional[str] = field(default=None) + enable_cpu_cache: bool = field(default=False) + cpu_cache_storage_size: float = field(default=2) + cpu_cache_token_page_size: int = field(default=256) + enable_disk_cache: bool = field(default=False) + disk_cache_storage_size: float = field(default=10) + # zmp ports + router_port: int = field(default=None) + detokenization_port: int = field(default=None) + http_server_port: int = field(default=None) + visual_port: int = field(default=None) + audio_port: int = field(default=None) + cache_port: int = field(default=None) + metric_port: int = field(default=None) + multinode_httpmanager_port: int = field(default=12345) + multi_level_kv_cache_port: int = field(default=None) + # multi_modal + enable_multimodal: bool = field(default=False) + enable_multimodal_audio: bool = field(default=False) diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py new file mode 100644 index 000000000..245ca5b98 --- /dev/null +++ b/lightllm/server/core/objs/token_chunck_hash_list.py @@ -0,0 +1,76 @@ +import os +import ctypes +from typing import List +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +LIGHTLLM_TOKEN_HASH_LIST_SIZE = int(os.getenv("LIGHTLLM_TOKEN_HASH_LIST_SIZE", 2048)) + + +class TokenHashList(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("items", ctypes.c_uint64 * LIGHTLLM_TOKEN_HASH_LIST_SIZE), # 元素静态数组 + ("size", ctypes.c_int), # 队列大小 + ] + + def __init__(self): + # 初始化头和尾 + self.size = 0 + return + + def is_empty(self): + return self.size == 0 + + def is_full(self): + return self.size == LIGHTLLM_TOKEN_HASH_LIST_SIZE + + def fill(self, data: List[int]): + if len(data) > LIGHTLLM_TOKEN_HASH_LIST_SIZE: + logger.warning( + f"Queue capcity is smaller than data size ({len(data)} > {LIGHTLLM_TOKEN_HASH_LIST_SIZE}), " + f"remove tail to write" + ) + data = data[0:LIGHTLLM_TOKEN_HASH_LIST_SIZE] + self.items[0 : len(data)] = data + self.size = len(data) + return + + def clear(self): + self.size = 0 + + def get_all(self): + return list(self.items[0 : self.size]) + + +class CpuCachePageList(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("items", ctypes.c_int * LIGHTLLM_TOKEN_HASH_LIST_SIZE), # 元素静态数组 + ("size", ctypes.c_int), # 队列大小 + ] + + def __init__(self): + # 初始化头和尾 + self.size = 0 + return + + def is_empty(self): + return self.size == 0 + + def is_full(self): + return self.size == LIGHTLLM_TOKEN_HASH_LIST_SIZE + + def fill(self, data: List[int]): + assert self.size == 0 + assert len(data) <= LIGHTLLM_TOKEN_HASH_LIST_SIZE + self.items[0 : len(data)] = data + self.size = len(data) + return + + def clear(self): + self.size = 0 + + def get_all(self): + return list(self.items[0 : self.size]) diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index f57eae333..4127b72dc 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -4,7 +4,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import zmq import inspect -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.core.objs.io_objs import GroupReqIndexes from 
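TokenHashList and CpuCachePageList above are fixed-capacity ctypes structures (capacity LIGHTLLM_TOKEN_HASH_LIST_SIZE), which is what lets them live inside the shared-memory Req object; they can also be exercised standalone. A small smoke test with made-up values:

```python
# Standalone use of the fixed-size, shm-friendly lists; the values are dummies.
from lightllm.server.core.objs.token_chunck_hash_list import TokenHashList, CpuCachePageList

hash_list = TokenHashList()
hash_list.clear()
hash_list.fill([0x1234ABCD, 0x9876FFFF, 0x42])  # longer inputs are truncated with a warning
assert hash_list.size == 3
assert hash_list.get_all() == [0x1234ABCD, 0x9876FFFF, 0x42]

page_list = CpuCachePageList()
page_list.clear()
page_list.fill([7, 12, -1])  # -1 is the "no page matched" marker used by the manager
assert page_list.get_all() == [7, 12, -1]
```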
lightllm.utils.graceful_utils import graceful_registry from typing import Union, Dict, List @@ -22,26 +22,20 @@ class DeTokenizationManager: def __init__( self, - args, - eos_id, - model_weightdir, - tokenizor_mode, - detokenization_port, - detokenization_pub_port, - trust_remote_code, + args: StartArgs, ): self.args = args context = zmq.Context(2) - self.recv_from_router = context.socket(zmq.PULL) - self.recv_from_router.bind(f"{args.zmq_mode}127.0.0.1:{detokenization_port}") + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}") self.pub_to_httpserver = context.socket(zmq.PUB) - self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{detokenization_pub_port}") + self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") - self.tokenizer = get_tokenizer(model_weightdir, tokenizor_mode, trust_remote_code=trust_remote_code) + self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) self.all_special_ids = set(self.tokenizer.all_special_ids) self.req_id_to_out: Dict[int, DecodeReq] = {} - self.eos_id = eos_id + self.eos_id = args.eos_id self._init_get_token_id_to_token_str() self.is_pd_decode_mode = self.args.run_mode == "decode" self.shm_req_manager = ShmReqManager() @@ -78,7 +72,7 @@ def handle_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.recv_from_router.recv_pyobj(zmq.NOBLOCK) + recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) assert isinstance(recv_obj, GroupReqIndexes) self._add_new_group_req_index(recv_obj=recv_obj) @@ -158,19 +152,13 @@ def remove_finished_reqs(self): return -def start_detokenization_process(args, detokenization_port, detokenization_pub_port, pipe_writer): +def start_detokenization_process(args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) try: manager = DeTokenizationManager( - args, - args.eos_id, - args.model_dir, - args.tokenizer_mode, - detokenization_port=detokenization_port, - detokenization_pub_port=detokenization_pub_port, - trust_remote_code=args.trust_remote_code, + args=args, ) except Exception as e: pipe_writer.send(str(e)) diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 557bdcf3b..fc0059206 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -2,6 +2,7 @@ import uuid import inspect from typing import Union, Optional +from lightllm.server.core.objs import StartArgs from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache from rpyc.utils.classic import obtain @@ -49,7 +50,7 @@ def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: return self._impl.get_items_embed(ids) -def start_cache_manager(port: int, args, pipe_writer): +def start_cache_manager(args: StartArgs, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) @@ -57,7 +58,7 @@ def start_cache_manager(port: int, args, pipe_writer): service = CacheServer(manager) from rpyc.utils.server import ThreadedServer - t = ThreadedServer(service, port=port, protocol_config={"allow_pickle": True}) + t = ThreadedServer(service, port=args.cache_port, protocol_config={"allow_pickle": 
True}) pipe_writer.send("init ok") t.start() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index de552d80c..27c2e23f4 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -20,7 +20,7 @@ from ..multimodal_params import AudioItem, MultimodalParams, ImageItem from ..req_id_generator import ReqIDGenerator from .async_queue import AsyncQueue -from lightllm.server.core.objs import Req, FinishStatus +from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE from lightllm.server.core.objs.io_objs import GroupReqObjs @@ -40,18 +40,12 @@ class HttpServerManager: def __init__( self, - args, - router_port, - cache_port, - detokenization_pub_port, - visual_port, - metric_port, - enable_multimodal, + args: StartArgs, ): self.args = args context = zmq.asyncio.Context(2) self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") + self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") self.multinode_req_manager = None self.nnodes = args.nnodes @@ -80,17 +74,21 @@ def __init__( f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}" ) - self.enable_multimodal = enable_multimodal + self.enable_multimodal = args.enable_multimodal if self.enable_multimodal: - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + if args.enable_cpu_cache and not self.args.enable_multimodal: + self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH) + self.send_to_multi_level_kv_cache.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") self.shm_req_manager = ShmReqManager() - self.recv_from_detokenization = context.socket(zmq.SUB) - self.recv_from_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_pub_port}") - self.recv_from_detokenization.setsockopt(zmq.SUBSCRIBE, b"") + # recv from detokenization + self.zmq_recv_socket = context.socket(zmq.SUB) + self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") + self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -98,7 +96,7 @@ def __init__( self.forwarding_queue: AsyncQueue = None # p d 分离模式使用的转发队列, 需要延迟初始化 self.max_req_total_len = args.max_req_total_len - self.metric_client = MetricClient(metric_port) + self.metric_client = MetricClient(args.metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL] @@ -437,38 +435,33 @@ async def transfer_to_next_module( group_req_objs: Optional[GroupReqObjs] = None, ): - if self.pd_mode == NodeRole.P: + if self.pd_mode.is_P_or_NORMAL(): if self.enable_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) - else: - self.send_to_router.send_pyobj( + return + + if self.args.enable_cpu_cache: + self.send_to_multi_level_kv_cache.send_pyobj( 
group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) - return + return - if self.pd_mode == NodeRole.D: - # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return - if self.pd_mode == NodeRole.NORMAL: - if self.enable_multimodal: - self.send_to_visual.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) - else: - self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) + if self.pd_mode == NodeRole.D: + # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了 + self.send_to_router.send_pyobj( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL, + ) return assert False, "dead code path" @@ -514,6 +507,7 @@ async def _wait_to_token_package( metadata["prompt_ids"] = prompt_ids prompt_cache_len = metadata.pop("prompt_cache_len", 0) + cpu_prompt_cache_len = metadata.pop("cpu_prompt_cache_len", 0) if is_first_token: first_token_cost_ms = (time.time() - start_time) * 1000 is_first_token = False @@ -550,6 +544,7 @@ async def _wait_to_token_package( f"prompt_token_num:{prompt_tokens} " f"prompt_cache_len:{prompt_cache_len} " f"prompt_cache_ratio:{prompt_cache_ratio} " + f"cpu_prompt_cache_len:{cpu_prompt_cache_len} " f"mtp_avg_token_per_step:{mtp_avg_token_per_step} " ) if group_request_id < 0: @@ -641,7 +636,7 @@ async def handle_loop(self): while True: try: - await asyncio.wait_for(self.recv_from_detokenization.recv_pyobj(), timeout=0.05) + await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) except asyncio.TimeoutError: pass @@ -670,6 +665,7 @@ async def handle_loop(self): "special": special, "count_output_tokens": count_output_tokens, "prompt_cache_len": req.prompt_cache_len, + "cpu_prompt_cache_len": req.cpu_prompt_cache_len, "mtp_accepted_token_num": req.mtp_accepted_token_num, } if self.args.return_all_prompt_logprobs: diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 05b2d987c..54635cff4 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -15,7 +15,7 @@ from typing import Union, List, Tuple, Dict from lightllm.server.core.objs import FinishStatus from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType -from lightllm.server.core.objs import SamplingParams +from lightllm.server.core.objs import SamplingParams, StartArgs from ..multimodal_params import MultimodalParams from ..tokenizer import get_tokenizer from ..req_id_generator import ReqIDGenerator, convert_sub_id_to_group_id @@ -32,11 +32,10 @@ class HttpServerManagerForPDMaster: def __init__( self, - args, - metric_port, + args: StartArgs, ): self.args = args - self.metric_client = MetricClient(metric_port) + self.metric_client = MetricClient(args.metric_port) self.id_gen = ReqIDGenerator() self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] diff --git a/lightllm/server/metrics/manager.py b/lightllm/server/metrics/manager.py index 486c6db2f..69adc1c7e 100644 --- a/lightllm/server/metrics/manager.py +++ b/lightllm/server/metrics/manager.py @@ -7,6 +7,7 @@ import queue from .metrics import Monitor from prometheus_client import generate_latest +from lightllm.server.core.objs import StartArgs from rpyc import SocketStream from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import 
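With the CPU cache enabled, HttpServerManager gains one more possible downstream hop, and transfer_to_next_module above picks it in a fixed order: visual server when multimodal is on, then the multi-level KV cache process, then the router directly (D nodes always go straight to the router). The helper below is only a restatement of that precedence for readability, not code from the PR:

```python
# Illustrative summary of the routing implemented by transfer_to_next_module.
def pick_next_module_socket(manager, is_p_or_normal: bool):
    if not is_p_or_normal:
        return manager.send_to_router  # D nodes: straight to the router
    if manager.enable_multimodal:
        return manager.send_to_visual  # visual/audio servers forward onwards themselves
    if manager.args.enable_cpu_cache:
        return manager.send_to_multi_level_kv_cache  # cpu cache matching first
    return manager.send_to_router
```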
graceful_registry @@ -133,7 +134,7 @@ def run(self): logger.error(f"monitor error {str(e)}") -def start_metric_manager(port: int, args, pipe_writer): +def start_metric_manager(args: StartArgs, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) @@ -144,6 +145,6 @@ def start_metric_manager(port: int, args, pipe_writer): from rpyc.utils.server import ThreadedServer - t = ThreadedServer(service, port=port) + t = ThreadedServer(service, port=args.metric_port) pipe_writer.send("init ok") t.start() diff --git a/lightllm/server/multi_level_kv_cache/__init__.py b/lightllm/server/multi_level_kv_cache/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/server/multi_level_kv_cache/cpu_cache_client.py b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py new file mode 100644 index 000000000..ec7bba02f --- /dev/null +++ b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py @@ -0,0 +1,275 @@ +import ctypes +import torch +import numpy as np +from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name +from typing import List, Optional, Tuple +from lightllm.utils.log_utils import init_logger +from .shm_objs import ShmDict, ShmLinkedList, _LinkedListItem, IntList +from lightllm.server.core.objs import AtomicShmLock +from lightllm.utils.kv_cache_utils import ( + calcu_cpu_cache_meta, + create_shm_kv_cache_ptr, + attach_shm_kv_cache_ptr, + register_shm_ptr_to_pin, +) + +logger = init_logger(__name__) + + +class CpuKvCacheClient(object): + """ + This class is responsible for handling cpu kv cache meta data. + """ + + def __init__(self, init_shm_data: bool): + self.args = get_env_start_args() + # to do here need calcu from from settings. + self.kv_cache_tensor_meta = calcu_cpu_cache_meta() + self.page_num: int = self.kv_cache_tensor_meta.page_num + self.lock = AtomicShmLock(lock_name=f"{get_unique_server_name()}_cpu_kv_cache_client_lock") + self._create_cpu_status_list(init_shm_data) + if init_shm_data: + self._create_shm_cpu_kv_cache() + else: + self._attach_shm_cpu_kv_cache() + return + + def get_one_empty_page(self, hash_key: int, disk_offload_enable: bool) -> Optional[int]: + assert self.page_hash_dict.get(hash_key) is None + head = self.page_items.head + tail = self.page_items.tail + cur_page: _CpuPageStatus = head.get_next_item() + if cur_page.self_index == tail.self_index: + return None + + if cur_page.can_realloc(disk_offload_enable=disk_offload_enable): + page_index = cur_page.self_index + cur_page.del_self_from_list() + if not cur_page.is_empty(): + self.page_hash_dict.remove(cur_page.hash_key) + cur_page.hash_key = hash_key + cur_page.status = cur_page.LOADING + cur_page.ref_count += 1 + self.page_hash_dict.put(hash_key, page_index) + self.page_items.add_item_to_tail(cur_page.self_index) + return page_index + else: + return None + + def allocate_one_page(self, hash_key: int, disk_offload_enable: bool) -> Tuple[Optional[int], bool]: + page_index = self.page_hash_dict.get(hash_key) + if page_index is not None: + page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + if page_item.is_data_ready(): + page_item.ref_count += 1 + page_item.del_self_from_list() + self.page_items.add_item_to_tail(index=page_index) + return page_index, True + else: + page_item.ref_count += 1 + page_item.del_self_from_list() + self.page_items.add_item_to_tail(index=page_index) + return page_index, False + else: + page_index = self.get_one_empty_page(hash_key=hash_key, disk_offload_enable=disk_offload_enable) + 
if page_index is not None: + return page_index, False + else: + return None, False + + def allocate_pages(self, hash_keys: List[int], disk_offload_enable: bool) -> Tuple[List[int], List[bool]]: + """ + allocate_pages will add _CpuPageStaus ref_count + """ + page_list = [] + ready_list = [] + for hash_key in hash_keys: + page_index, ready = self.allocate_one_page(hash_key=hash_key, disk_offload_enable=disk_offload_enable) + if page_index is not None: + page_list.append(page_index) + ready_list.append(ready) + else: + page_list.append(-1) + ready_list.append(False) + break + + left_num = len(hash_keys) - len(page_list) + page_list.extend([-1 for _ in range(left_num)]) + ready_list.extend([False for _ in range(left_num)]) + return page_list, ready_list + + def update_pages_status_to_ready(self, page_list: List[int], deref: bool = True, disk_offload_enable: bool = False): + for page_index in page_list: + if page_index != -1: + cur_page: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + if cur_page.status < cur_page.READY: + cur_page.status = cur_page.READY + if disk_offload_enable: + self.offload_page_indexes.add_item(value=cur_page.self_index) + if deref: + assert cur_page.ref_count > 0 + cur_page.ref_count -= 1 + return + + def query_one_page(self, hash_key: int) -> Tuple[Optional[int], bool]: + page_index = self.page_hash_dict.get(hash_key) + if page_index is not None: + page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + if page_item.is_data_ready(): + page_item.ref_count += 1 + # lru 更新 + page_item.del_self_from_list() + self.page_items.add_item_to_tail(index=page_index) + return page_index, True + else: + # lru 更新 + page_item.del_self_from_list() + self.page_items.add_item_to_tail(index=page_index) + return None, False + else: + return None, False + + def check_allpages_ready(self, page_list: List[int]) -> bool: + for page_index in page_list: + if page_index == -1: + continue + page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + if not page_item.is_data_ready(): + return False + return True + + def deref_pages(self, page_list: List[int]): + """ + deref_pages + """ + for page_index in page_list: + if page_index != -1: + self.deref_one_page(page_index=page_index) + return + + def deref_one_page(self, page_index: int): + page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + assert page_item.ref_count > 0 + page_item.ref_count -= 1 + return + + def get_pages_to_offloading(self) -> List[int]: + page_list = self.offload_page_indexes.pop_all_item() + ans_list = [] + if page_list is not None: + for page_index in page_list: + page_item: _CpuPageStatus = self.page_items.get_item_by_index(index=page_index) + if page_item.is_ready(): + page_item.ref_count += 1 + page_item.status = page_item.OFFLOADING + ans_list.append(page_index) + return ans_list + + def update_pages_status_to_ready_recycle(self, page_list: List[int], deref: bool = True): + for page_index in page_list: + if page_index != -1: + cur_page: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + assert cur_page.is_offloading() + cur_page.status = cur_page.READY_RECYCLE + if deref: + assert cur_page.ref_count > 0 + cur_page.ref_count -= 1 + return + + def _create_cpu_status_list(self, init_shm_data: bool): + self.page_items = ShmLinkedList( + name=f"{get_unique_server_name()}_cpu_kv_cache_page_items", + item_class=_CpuPageStatus, + capacity=self.page_num, + init_shm_data=init_shm_data, + ) + self.page_hash_dict = ShmDict( + 
name=f"{get_unique_server_name()}_cpu_kv_cache_hash", + capacity=self.page_num * 2, + init_shm_data=init_shm_data, + ) + self.offload_page_indexes = IntList( + name=f"{get_unique_server_name()}_cpu_kv_cache_offload_page_indexes", + capacity=self.page_num, + init_shm_data=init_shm_data, + ) + return + + def _create_shm_cpu_kv_cache(self): + shm_ptr = create_shm_kv_cache_ptr() + numpy_array = np.frombuffer( + memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8 + ) + # 将 NumPy 数组转换为 PyTorch 张量 + shape = ( + self.kv_cache_tensor_meta.page_num, + self.kv_cache_tensor_meta.layer_num, + self.kv_cache_tensor_meta.token_page_size, + self.kv_cache_tensor_meta.num_heads, + self.kv_cache_tensor_meta.head_dim, + ) + self.cpu_kv_cache_tensor = torch.from_numpy(numpy_array).view(dtype=torch.bfloat16).view(shape) + return + + def _attach_shm_cpu_kv_cache(self): + shm_ptr = attach_shm_kv_cache_ptr() + device_ptr = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size()) + shape = ( + self.kv_cache_tensor_meta.page_num, + self.kv_cache_tensor_meta.layer_num, + self.kv_cache_tensor_meta.token_page_size, + self.kv_cache_tensor_meta.num_heads, + self.kv_cache_tensor_meta.head_dim, + ) + self.cpu_kv_cache_tensor = torch.empty(size=shape, dtype=torch.bfloat16, device="meta") + # 将指针绑定到 tensor上,方便triton获取真实的地址。 + self.cpu_kv_cache_tensor.data_ptr = lambda: device_ptr + return + + +class _CpuPageStatus(_LinkedListItem): + _pack_ = 4 + _fields_ = [("status", ctypes.c_int), ("ref_count", ctypes.c_int), ("hash_key", ctypes.c_uint64)] + + EMPTY = 0 # 空闲 + LOADING = 1 # 从 gpu buffer 加载到 cpu 的状态,或者是从磁盘加载到 cpu 的状态 + READY = 2 # 数据已经加载到 cpu ok 的状态 + OFFLOADING = 3 # 从 cpu 卸载到 硬盘的状态 + READY_RECYCLE = 4 # 因为卸载到硬盘已经完成,所以可以进行回收使用 + + def __init__(self): + self.init() + + def init(self): + super().init() + self.ref_count = 0 + self.status = self.EMPTY + self.hash_key = 0 + return + + def is_empty(self): + return self.status == self.EMPTY + + def is_loading(self): + return self.status == self.LOADING + + def is_ready(self): + return self.status == self.READY + + def is_offloading(self): + return self.status == self.OFFLOADING + + def is_ready_recycle(self): + return self.status == self.READY_RECYCLE + + def is_data_ready(self): + """ + 判断数据是否是填充ok的,可能包含多种状态下属于数据是可填充的状态。 + """ + return self.status >= self.READY + + def can_realloc(self, disk_offload_enable: bool): + if disk_offload_enable: + return (self.is_empty() or self.is_ready_recycle()) and self.ref_count == 0 + else: + return (self.is_empty() or self.is_data_ready()) and self.ref_count == 0 diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py new file mode 100644 index 000000000..8fba48809 --- /dev/null +++ b/lightllm/server/multi_level_kv_cache/manager.py @@ -0,0 +1,160 @@ +import uvloop +import asyncio + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +import zmq +import inspect +import pickle +import time +import threading +import concurrent.futures +from queue import Queue +from lightllm.server.core.objs import ShmReqManager, Req, StartArgs +from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.utils.graceful_utils import graceful_registry +from .cpu_cache_client import CpuKvCacheClient +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MultiLevelKVCacheManager: + def __init__( + self, + args: StartArgs, + ): + self.args: StartArgs = args + context = 
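_CpuPageStatus above encodes the page lifecycle EMPTY -> LOADING -> READY -> OFFLOADING -> READY_RECYCLE. The subtle part is can_realloc: when disk offload is enabled, a READY page may not be reclaimed until it has been flushed to disk and marked READY_RECYCLE. A plain-Python restatement of that rule, for reading convenience only:

```python
# Condensed restatement of _CpuPageStatus.can_realloc; the real check lives on
# the ctypes structure stored in shared memory.
EMPTY, LOADING, READY, OFFLOADING, READY_RECYCLE = range(5)


def can_realloc(status: int, ref_count: int, disk_offload_enable: bool) -> bool:
    if ref_count != 0:
        return False
    if disk_offload_enable:
        # only never-filled pages or pages already persisted to disk are reusable
        return status in (EMPTY, READY_RECYCLE)
    # without disk offload, any page whose data is settled (READY or later) can be reused
    return status == EMPTY or status >= READY
```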
zmq.Context(2) + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") + + self.send_to_router = context.socket(zmq.PUSH) + self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") + logger.info(f"send_to_router sendhwm {self.send_to_router.getsockopt(zmq.SNDHWM)}") + self.cpu_cache_client = CpuKvCacheClient(init_shm_data=True) + self.shm_req_manager = ShmReqManager() + # 控制同时进行cpu cache 匹配操作的数量。 + self.semaphore = threading.Semaphore(3) + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=6) + # 控制 cpu cache time out的时间,如果超过这个时间无法获取信号量则直接转发。 + self.cpu_cache_time_out = 0.3 + self.recv_queue = Queue(maxsize=1024) + self.cpu_cache_thread = threading.Thread(target=self.cpu_cache_hanle_loop, daemon=True) + self.cpu_cache_thread.start() + return + + def cpu_cache_hanle_loop(self): + while True: + try: + current_group_req = self.recv_queue.get() + + self.executor.submit(self._handle_group_req_cpu_cache_match, current_group_req, time.time()) + except BaseException as e: + logger.exception(str(e)) + return + + def _handle_group_req_cpu_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float): + """ + match cpu cache pages + """ + # 进行超时判定,如果太长时间拿不到信号量,则说明匹配任务繁忙, + # 放弃进行 cpu cache page 的匹配。 + while True: + current_time = time.time() + if current_time - start_time >= self.cpu_cache_time_out: + self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + if self.semaphore.acquire(blocking=False): + break + time.sleep(0.005) + + reqs_shm_index = group_req_indexes.shm_req_indexes + reqs = [self.shm_req_manager.get_req_obj_by_index(index) for index in reqs_shm_index] + + # 对每个请求进行cpu cache page 的匹配操作。 + for req in reqs: + # diverse_mode 只有主请求一个初始化 cpu cache 信息。 + if self.args.diverse_mode and req.request_id != req.group_req_id: + continue + if req.is_aborted: + continue + + self.cpu_cache_client.lock.acquire_sleep1ms() + req: Req = req + finded_page_indexes = [] + for token_chuncked_hash_value in req.token_hash_list.get_all(): + page_index, ready = self.cpu_cache_client.query_one_page(token_chuncked_hash_value) + if page_index is not None: + assert ready + finded_page_indexes.append(page_index) + else: + break + self.cpu_cache_client.lock.release() + + # 等待所有的cpu cache 页面ready + while not self.cpu_cache_client.check_allpages_ready(finded_page_indexes): + time.sleep(0.01) + + req.cpu_cache_match_page_indexes.fill(finded_page_indexes) + + for req in reqs: + self.shm_req_manager.put_back_req_obj(req) + + # 释放信号量 + self.semaphore.release() + + self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + def recv_loop(self): + try: + recv_max_count = 128 + + while True: + recv_objs = [] + try: + # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 + for _ in range(recv_max_count): + recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + assert isinstance(recv_obj, GroupReqIndexes) + recv_objs.append(recv_obj) + + start_time = recv_obj.time_mark + logger.info( + f"multi_level_kv_cache recive group req id {recv_obj.group_req_id} " + f"cost time {time.time() - start_time} s" + ) + + # 当队列中存在较多的请求时,将一次接受的数量上调 + recv_max_count = min(int(recv_max_count * 1.3), 256) + except zmq.ZMQError: + # 当队列已经开始清空的时候,将一次接受的数量下调 + recv_max_count = 128 + + for recv_obj in recv_objs: + self.recv_queue.put(recv_obj) + + if len(recv_objs) == 0: + time.sleep(0.01) + + except Exception as e: + 
logger.exception(f"detoken process has exception {str(e)}") + return + + +def start_multi_level_kv_cache_manager(args, pipe_writer): + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + + try: + manager = MultiLevelKVCacheManager( + args=args, + ) + except Exception as e: + pipe_writer.send(str(e)) + raise + + pipe_writer.send("init ok") + manager.recv_loop() + return diff --git a/lightllm/server/multi_level_kv_cache/shm_objs.py b/lightllm/server/multi_level_kv_cache/shm_objs.py new file mode 100644 index 000000000..2e8892115 --- /dev/null +++ b/lightllm/server/multi_level_kv_cache/shm_objs.py @@ -0,0 +1,270 @@ +import ctypes +import numpy as np +from multiprocessing import shared_memory +from typing import List, Optional +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class IntList(object): + def __init__(self, name: str, capacity: int, init_shm_data: bool): + self.capacity: int = capacity + byte_size = np.dtype(np.int32).itemsize * (self.capacity + 1) + shm_name = name + shm = _create_shm(name=shm_name, byte_size=byte_size) + self.shm = shm + + if self.shm.size != byte_size: + logger.info(f"size not same, unlink lock shm {self.shm.name} and create again") + self.shm.close() + self.shm.unlink() + self.shm = None + self.shm = _create_shm(name=shm_name, byte_size=byte_size) + + self.arr = np.ndarray((self.capacity + 1), dtype=np.int32, buffer=self.shm.buf) + if init_shm_data: + self.arr.fill(0) + return + + def size(self): + return self.arr[-1] + + def add_item(self, value: int): + write_index = self.arr[-1] + self.arr[write_index] = value + self.arr[-1] += 1 + return + + def pop_all_item(self) -> Optional[List[int]]: + if self.size() == 0: + return None + + ans = self.arr[0 : self.size()].tolist() + self.arr[-1] = 0 + return ans + + +class ShmLinkedList(object): + def __init__(self, name: str, item_class: "_LinkedListItem.__class__", capacity: int, init_shm_data: bool): + self.capacity: int = capacity + # add head and tail node. 
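MultiLevelKVCacheManager above bounds concurrent CPU-cache matching with a semaphore and gives up after cpu_cache_time_out, forwarding the request to the router unmatched rather than stalling it. In isolation that back-pressure pattern looks roughly like the sketch below (names invented for illustration):

```python
# Minimal model of the "match before a deadline or just forward" pattern used
# by MultiLevelKVCacheManager; not the PR code itself.
import threading
import time

match_slots = threading.Semaphore(3)  # mirrors the manager's concurrency cap
MATCH_TIMEOUT_S = 0.3                 # mirrors cpu_cache_time_out


def handle(group_req, forward, match, enqueue_time):
    while True:
        if time.time() - enqueue_time >= MATCH_TIMEOUT_S:
            forward(group_req)  # matching is too busy: skip the cache lookup entirely
            return
        if match_slots.acquire(blocking=False):
            break
        time.sleep(0.005)
    try:
        match(group_req)  # look up cpu cache pages for the request
    finally:
        match_slots.release()
    forward(group_req)  # always continue to the router afterwards
```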
+ byte_size = ctypes.sizeof(item_class) * (self.capacity + 2) + shm_name = name + shm = _create_shm(name=shm_name, byte_size=byte_size) + self.shm = shm + + if self.shm.size != byte_size: + logger.info(f"size not same, unlink lock shm {self.shm.name} and create again") + self.shm.close() + self.shm.unlink() + self.shm = None + self.shm = _create_shm(name=shm_name, byte_size=byte_size) + # 构建 hash table 表 + self.linked_items: List[_LinkedListItem] = (item_class * (self.capacity + 2)).from_buffer(self.shm.buf) + # 如果不转变存储,set_list_obj 的对象上绑定的非shm信息在下一次从 shm 中获取对象时将丢失 + self.linked_items = [item for item in self.linked_items] + for e in self.linked_items: + e.set_list_obj(self) + + self.head = self.linked_items[self.capacity] + self.tail = self.linked_items[self.capacity + 1] + + if init_shm_data: + for e in self.linked_items: + e.init() + + self.head.self_index = self.capacity + self.tail.self_index = self.capacity + 1 + self.head.next_index = self.tail.self_index + self.tail.pre_index = self.head.self_index + + for i in range(self.capacity): + item = self.linked_items[i] + item.self_index = i + self.add_item_to_tail(i) + return + + def add_item_to_tail(self, index: int): + item = self.linked_items[index] + pre_node = self.linked_items[self.tail.pre_index] + pre_node.next_index = item.self_index + item.pre_index = pre_node.self_index + item.next_index = self.tail.self_index + self.tail.pre_index = item.self_index + return + + def get_item_by_index(self, index: int) -> "_LinkedListItem": + item = self.linked_items[index] + return item + + def pop_head_item(self) -> "_LinkedListItem": + head_item = self.linked_items[self.head.next_index] + if head_item.self_index == self.tail.self_index: + return None + head_item.del_self_from_list() + return head_item + + +class ShmDict(object): + def __init__(self, name: str, capacity: int, init_shm_data: bool): + self.capacity: int = capacity + self.link_items: ShmLinkedList = ShmLinkedList( + name=name, item_class=_HashLinkItem, capacity=self.capacity * 2, init_shm_data=init_shm_data + ) + # 将前capacity个item,作为hash item的链表头。 + if init_shm_data: + for i in range(self.capacity): + self.link_items.pop_head_item() + item: _HashLinkItem = self.link_items.get_item_by_index(i) + item.pre_index = -1 + item.next_index = -1 + return + + def put(self, key: int, value: int): + dest_index = key % self.capacity + hash_item: _HashLinkItem = self.link_items.get_item_by_index(dest_index) + if hash_item.next_index == -1: # 空的 + add_link_item: _HashLinkItem = self.link_items.pop_head_item() + add_link_item.key = key + add_link_item.value = value + hash_item.next_index = add_link_item.self_index + add_link_item.pre_index = hash_item.self_index + add_link_item.next_index = -1 + return + + # 存在元素,先遍历是否已经存在 + start_link_item: _HashLinkItem = hash_item.get_next_item() + cur_link_item = start_link_item + # 找到对应key的元素,并设置对应的value + while True: + if cur_link_item.key == key: + cur_link_item.value = value + return + else: + next_item = cur_link_item.get_next_item() + if next_item is None: + break + else: + cur_link_item = next_item + + # 没有找到时候,直接插入一个新的节点 + add_link_item: _HashLinkItem = self.link_items.pop_head_item() + add_link_item.key = key + add_link_item.value = value + + cur_link_item.next_index = add_link_item.self_index + add_link_item.pre_index = cur_link_item.self_index + add_link_item.next_index = -1 + return + + def get(self, key: int) -> Optional[int]: + dest_index = key % self.capacity + hash_item: _HashLinkItem = self.link_items.get_item_by_index(dest_index) + if 
hash_item.next_index == -1: + return None + else: + start_link_item: _HashLinkItem = hash_item.get_next_item() + cur_link_item = start_link_item + # 找到对应key的元素,并设置对应的value + while cur_link_item is not None: + if cur_link_item.key == key: + return cur_link_item.value + else: + cur_link_item = cur_link_item.get_next_item() + return None + + def remove(self, key: int): + dest_index = key % self.capacity + hash_item: _HashLinkItem = self.link_items.get_item_by_index(dest_index) + if hash_item.next_index == -1: + logger.warning(f"shm dict not contain key {key}") + return + + start_link_item: _HashLinkItem = hash_item.get_next_item() + cur_link_item = start_link_item + + # 找到对应key的元素,并设置对应的value + while cur_link_item is not None: + if cur_link_item.key == key: + break + else: + cur_link_item = cur_link_item.get_next_item() + + if cur_link_item is not None: + # remove item + pre_item = cur_link_item.get_pre_item() + pre_item.next_index = cur_link_item.next_index + if cur_link_item.next_index != -1: + next_item = cur_link_item.get_next_item() + next_item.pre_index = pre_item.self_index + + self.link_items.add_item_to_tail(index=cur_link_item.self_index) + else: + logger.warning(f"shm dict not contain key {key}") + return + + +class _LinkedListItem(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("self_index", ctypes.c_int), + ("pre_index", ctypes.c_int), + ("next_index", ctypes.c_int), + ] + + def __init__(self): + self.init() + + def init(self): + self.self_index = -1 + self.pre_index = -1 + self.next_index = -1 + return + + def set_list_obj(self, parent_list: ShmLinkedList): + self.linked_items = parent_list.linked_items + return + + def get_next_item(self) -> "_LinkedListItem": + if self.next_index == -1: + return None + return self.linked_items[self.next_index] + + def get_pre_item(self) -> "_LinkedListItem": + if self.pre_index == -1: + return None + return self.linked_items[self.pre_index] + + def del_self_from_list(self): + pre_node = self.get_pre_item() + next_node = self.get_next_item() + pre_node.next_index = next_node.self_index + next_node.pre_index = pre_node.self_index + return + + +class _HashLinkItem(_LinkedListItem): + _pack_ = 4 + _fields_ = [ + ("key", ctypes.c_uint64), + ("value", ctypes.c_int), + ] + + def __init__(self): + self.init() + + def init(self): + super().init() + self.key = 0 + self.value = -1 + + +def _create_shm(name: str, byte_size: int): + try: + shm = shared_memory.SharedMemory(name=name, create=True, size=byte_size) + logger.info(f"create lock shm {name}") + except: + shm = shared_memory.SharedMemory(name=name, create=False, size=byte_size) + logger.info(f"link lock shm {name}") + return shm diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 65ec4354b..3db6b32b0 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -1,7 +1,7 @@ # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py import torch import numpy as np -from typing import Tuple, Dict, Set, List +from typing import Tuple, Dict, Set, List, Optional from sortedcontainers import SortedSet from .shared_arr import SharedArray from lightllm.common.mem_manager import MemoryManager @@ -123,16 +123,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: Memo ) self.tree_total_tokens_num.arr[0] = 0 - def insert(self, key, value=None): + def insert(self, key, 
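ShmLinkedList, ShmDict and IntList above keep all of the CPU-page bookkeeping in named shared memory, so any process that attaches with the same name sees the same state. A single-process smoke test of the hash map and the pending-page list (names are throwaway and the segments are not unlinked here):

```python
# Smoke test of the shared-memory containers; a second process attaching with
# the same names and init_shm_data=False would observe the same contents.
from lightllm.server.multi_level_kv_cache.shm_objs import ShmDict, IntList

d = ShmDict(name="demo_cpu_cache_hash", capacity=64, init_shm_data=True)
d.put(key=12345, value=7)
d.put(key=12345 + 64, value=9)  # same bucket as the first key, chained in place
assert d.get(12345) == 7 and d.get(12345 + 64) == 9
d.remove(12345)
assert d.get(12345) is None and d.get(12345 + 64) == 9

pending = IntList(name="demo_offload_pages", capacity=16, init_shm_data=True)
pending.add_item(3)
pending.add_item(5)
assert pending.pop_all_item() == [3, 5]
assert pending.pop_all_item() is None
```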
value=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key assert len(key) == len(value) # and len(key) >= 1 if len(key) == 0: - return 0 + return 0, None return self._insert_helper(self.root_node, key, value) - def _insert_helper(self, node: TreeNode, key, value): + def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]: if node.is_leaf(): self.evict_tree_set.discard(node) @@ -147,7 +147,7 @@ def _insert_helper(self, node: TreeNode, key, value): child.update_time() if child.is_leaf(): self.evict_tree_set.add(child) - return prefix_len + return prefix_len, child elif prefix_len < len(key) and prefix_len < len(child.token_id_key): if child.is_leaf(): @@ -167,9 +167,10 @@ def _insert_helper(self, node: TreeNode, key, value): if child.is_leaf(): self.evict_tree_set.add(child) - return prefix_len + return prefix_len, new_node elif prefix_len < len(key) and prefix_len == len(child.token_id_key): - return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + _prefix_len, ans_node = self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + return prefix_len + _prefix_len, ans_node else: assert False, "can not run to here" @@ -179,7 +180,7 @@ def _insert_helper(self, node: TreeNode, key, value): self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) if new_node.is_leaf(): self.evict_tree_set.add(new_node) - return 0 + return 0, new_node finally: node.update_time() if node.is_leaf(): @@ -313,6 +314,25 @@ def dec_node_ref_counter(self, node: TreeNode): self.evict_tree_set.add(old_node) return + def add_node_ref_counter(self, node: TreeNode): + if node is None: + return + # 如果减引用的是叶节点,需要先从 evict_tree_set 中移除 + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 0: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + node.ref_counter += 1 + node = node.parent + + # 加回。 + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + def get_refed_tokens_num(self): return self.refed_tokens_num.arr[0] diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 24b8a9ddb..16da52175 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -33,7 +33,7 @@ class RouterManager: - def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port): + def __init__(self, args: StartArgs): self.args = args self.model_weightdir = args.model_dir self.world_size = args.tp @@ -70,11 +70,11 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.running_batch: Batch = None context = zmq.Context(2) - self.recv_from_httpserver = context.socket(zmq.PULL) - self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}") + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.router_port}") self.send_to_detokenization = context.socket(zmq.PUSH) - self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}") + self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}") if self.is_multinode_tp: self.mulitnode_group = dist.init_process_group( @@ -84,7 +84,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por rank=args.node_rank, ) - self.metric_client = MetricClient(metric_port) + self.metric_client = MetricClient(args.metric_port) self.is_pd_run_mode = 
self.args.run_mode in ["prefill", "decode"] self.is_pd_decode_mode = self.args.run_mode == "decode" # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 @@ -453,7 +453,7 @@ async def _recv_new_reqs_and_schedule(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): - recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): self._add_req(recv_req) else: @@ -477,7 +477,7 @@ def clean_up(self): return -def start_router_process(args, router_port, detokenization_port, metric_port, pipe_writer): +def start_router_process(args, pipe_writer): # 注册 graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) start_parent_check_thread() @@ -491,10 +491,7 @@ def handle_exception(loop, context): try: router = RouterManager( - args, - router_port=router_port, - detokenization_port=detokenization_port, - metric_port=metric_port, + args=args, ) loop.run_until_complete(router.wait_to_model_ready()) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 67d69aa38..3a2a0eb47 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -1,6 +1,4 @@ -import os -import copy -import time +import enum import torch import torch.distributed as dist import numpy as np @@ -33,10 +31,15 @@ class InferenceContext: vocab_size = None overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 + cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream def register( - self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int + self, backend, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int ): + self.args = get_env_start_args() + from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend + + self.backend: ModeBackend = backend self.req_manager = req_manager self.req_sampling_manager = self.req_manager.req_sampling_params_manager self.radix_cache = radix_cache @@ -54,6 +57,11 @@ def get_overlap_stream(self) -> torch.cuda.Stream: self.overlap_stream = torch.cuda.Stream() return self.overlap_stream + def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: + if self.cpu_kv_cache_stream is None: + self.cpu_kv_cache_stream = torch.cuda.Stream() + return self.cpu_kv_cache_stream + def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] request_ids = [] @@ -100,7 +108,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() if is_group_finished: - prefix_len = self.radix_cache.insert(key, value) + prefix_len, _ = self.radix_cache.insert(key, value) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -274,6 +282,20 @@ def has_constraint_setting(self) -> bool: class InferReq: + class _CpuCacheTaskStatus(enum.Enum): + NOT_STARTED = 0 + RUNNING = 1 + FINISHED = 2 + + def is_not_started(self): + return self == self.NOT_STARTED + + def is_running(self): + return self == self.RUNNING + + def is_finished(self): + return 
self == self.FINISHED + def __init__( self, req_id: int, @@ -299,6 +321,10 @@ def __init__( self.need_out_token_id_statistics = True self.out_token_id_count: Dict[int, int] = None + # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache + # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 + self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED + # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 self.mtp_step: int = get_env_start_args().mtp_step diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index fd75afdbf..e01016693 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -31,6 +31,7 @@ from lightllm.server.router.shm_reqs_io_buffer import ShmReqsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +from .multi_level_kv_cache import MultiLevelKvCacheModule class ModeBackend: @@ -156,6 +157,7 @@ def init_model(self, kvargs): self.logger.info(f"loaded model class {self.model.__class__}") g_infer_context.register( + backend=self, req_manager=self.model.req_manager, radix_cache=self.radix_cache, shm_req_manager=self.shm_req_manager, @@ -190,6 +192,9 @@ def init_model(self, kvargs): # 开启 mtp 模式,需要完成mtp model的初始化 if self.args.mtp_mode: self.init_mtp_draft_model(kvargs) + + if self.args.enable_cpu_cache: + self.multi_level_cache_module = MultiLevelKvCacheModule(self) # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景 # 可以降低 cpu overhead,大幅提升gpu得使用率。 @@ -322,7 +327,10 @@ def _read_reqs_buffer_and_init_reqs(self): req: InferReq = g_infer_context.requests_mapping[obj.req_id] req.infer_aborted = True else: - self._init_reqs(reqs=cmds) + req_ids = self._init_reqs(reqs=cmds) + if self.args.enable_cpu_cache: + self._fill_cpu_cache_to_reqs(req_ids=req_ids) + return # 一些可以复用的通用功能函数 @@ -344,6 +352,13 @@ def _init_reqs(self, reqs: List[Tuple]): req_ids = [e[0] for e in reqs] return req_ids + def _fill_cpu_cache_to_reqs(self, req_ids): + req_objs: List[InferReq] = [g_infer_context.requests_mapping[req_id] for req_id in req_ids] + g_infer_state_lock.acquire() + self.multi_level_cache_module.fill_cpu_cache_to_reqs(reqs=req_objs) + g_infer_state_lock.release() + return + # 一些可以复用的通用功能函数 def _get_classed_reqs( self, @@ -370,6 +385,8 @@ def _get_classed_reqs( 4. prefill_reqs 需要进行prefill操作的请求 5. 
decode_reqs 需要进行decode操作的请求 """ + if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0: + self.multi_level_cache_module.update_cpu_cache_task_states() if req_ids is None: req_ids = g_infer_context.infer_req_ids @@ -450,8 +467,13 @@ def _get_classed_reqs( g_infer_state_lock.release() self._pre_handle_finished_reqs(finished_reqs=finished_reqs) - g_infer_context.filter_reqs(finished_reqs=finished_reqs) + # 如果使能了 cpu cache 功能,对于已经完成的请求,进行 gpu kv 卸载到 cpu cache的操作。 + if self.args.enable_cpu_cache: + true_finished_reqs = self.multi_level_cache_module.handle_finished_reqs(finished_reqs=finished_reqs) + else: + true_finished_reqs = finished_reqs + g_infer_context.filter_reqs(finished_reqs=true_finished_reqs) g_infer_context.pause_reqs(wait_pause_reqs, is_master_in_dp=self.is_master_in_dp) if recover_paused: diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py index 1ff167960..1e5cccb1f 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py @@ -131,7 +131,7 @@ def _put_kv_received_to_radix_cache(self, group_req_id: int): radix_cache = self.backend.radix_cache key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") value = torch.tensor(fused_token_indexes + move_task.decode_token_indexes, dtype=torch.int64, device="cpu") - prefix_len = radix_cache.insert(key, value) + prefix_len, _ = radix_cache.insert(key, value) assert len(fused_token_indexes) <= prefix_len self.backend.model.mem_manager.free(value[len(fused_token_indexes) : prefix_len]) self.backend.radix_cache.dec_node_ref_counter(tree_node) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py index be9c434d8..d127702b3 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py @@ -74,17 +74,17 @@ def _prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, finished_reqs: key = req.get_input_token_ids()[0 : req.cur_kv_len] key = torch.tensor(key, dtype=torch.int64, device="cpu") value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len = self.radix_cache.insert(key, value) + prefix_len, new_shared_kv_node = self.radix_cache.insert(key, value) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len self.model.mem_manager.free( self.model.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len] ) if req.shared_kv_node is not None: + # 将原有共享节点替换为新共享节点,新共享节点对应的长度为当前的cur_kv_len self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - req.shared_kv_node = None - - req.cur_kv_len = 0 - req.shm_req.shm_cur_kv_len = 0 + self.radix_cache.add_node_ref_counter(new_shared_kv_node) + req.shared_kv_node = new_shared_kv_node + assert new_shared_kv_node.node_prefix_total_len == req.cur_kv_len if req.shm_req.sample_params.move_kv_to_decode_node.exists: # 注意兼容纯tp 和 tp dp 混合模式的逻辑 diff --git 
a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py new file mode 100644 index 000000000..605684519 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -0,0 +1,225 @@ +import threading +import torch.distributed as dist +import torch +import dataclasses +from typing import Optional, List, Deque +from collections import deque +from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient +from lightllm.utils.envs_utils import get_env_start_args +from ..infer_batch import InferReq +from lightllm.utils.dist_utils import create_new_group_for_current_dp +from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu, load_cpu_kv_to_gpu +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +class MultiLevelKvCacheModule(object): + def __init__(self, backend): + self.args = get_env_start_args() + from .base_backend import ModeBackend + + self.backend: ModeBackend = backend + self.gloo_group = create_new_group_for_current_dp("gloo") + self.filter_group = create_new_group_for_current_dp("gloo") + self.init_sync_group = create_new_group_for_current_dp("nccl") + dist.barrier(group=self.init_sync_group) + + self.sync_group = create_new_group_for_current_dp("nccl") + dist.barrier(group=self.sync_group) + self.sync_tensor = torch.zeros((1,), dtype=torch.int64, device="cuda") + + + self.cpu_cache_handle_queue: Deque[TransTask] = deque() + self.cpu_cache_client = CpuKvCacheClient(init_shm_data=False) + + def handle_finished_reqs(self, finished_reqs: List[InferReq]) -> List[InferReq]: + """ + 将满足cpu kv cache 卸载条件的请求进行处理,并返回需要真正退出的请求列表。 + """ + + if self.args.enable_cpu_cache: + # 如果开启了cpu cache,将达到finished状态的请求开启将gpu kv cache 卸载到 cpu cache中的操作。 + # 当 kv cache 卸载完成后,才会进行请求的真实退出操作。 + true_finished_reqs = [] + for req in finished_reqs: + # 只有 group_req_id 和 request_id 相同的请求才会被卸载到 cpu cache 中。 + # 这个限制是为了兼容 diverse 模式下的请求处理。 + if req.shm_req.group_req_id != req.shm_req.request_id: + true_finished_reqs.append(req) + continue + + # 过滤不适合进行 kv 卸载到 cpu cache 的请求。 + if req.cur_kv_len < self.args.cpu_cache_token_page_size: + true_finished_reqs.append(req) + continue + + # 如果请求已经完成了 cpu cache 的任务,则满足了退出条件 + if req.cpu_cache_task_status.is_finished(): + true_finished_reqs.append(req) + elif req.cpu_cache_task_status.is_running(): + # 如果请求已经发起过卸载任务,则在当前轮不进行处理 + continue + else: + assert req.cpu_cache_task_status.is_not_started() + # 发起将请求的 kv cache 卸载到 cpu cache 中的任务 + trans_task = self._start_kv_cache_offload_task( + req=req, cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_stream() + ) + + if trans_task is not None: + self.cpu_cache_handle_queue.append(trans_task) + else: + true_finished_reqs.append(req) + + return true_finished_reqs + else: + return finished_reqs + + def _start_kv_cache_offload_task( + self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream + ) -> Optional["TransTask"]: + with torch.cuda.stream(cpu_kv_cache_stream): + all_token_hash_list = req.shm_req.token_hash_list.get_all() + block_size = req.cur_kv_len // self.args.cpu_cache_token_page_size + move_block_size = min(block_size, len(all_token_hash_list)) + if move_block_size == 0: + req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + return None + if self.backend.is_master_in_dp: + self.cpu_cache_client.lock.acquire_sleep1ms() + 
page_list, ready_list = self.cpu_cache_client.allocate_pages( + all_token_hash_list[:move_block_size], + disk_offload_enable=self.args.enable_disk_cache, + ) + self.cpu_cache_client.lock.release() + item_size = len(page_list) + dist.broadcast_object_list([item_size], group=self.gloo_group, group_src=0) + if item_size == 0: + req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + return None + dist.broadcast_object_list(page_list, group=self.gloo_group, group_src=0) + dist.broadcast_object_list(ready_list, group=self.gloo_group, group_src=0) + else: + recv_list = [None] + dist.broadcast_object_list(recv_list, group=self.gloo_group, group_src=0) + item_size = recv_list[0] + if item_size == 0: + req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + return None + page_list = [None] * item_size + ready_list = [None] * item_size + dist.broadcast_object_list(page_list, group=self.gloo_group, group_src=0) + dist.broadcast_object_list(ready_list, group=self.gloo_group, group_src=0) + + page_indexes = torch.tensor(page_list, dtype=torch.int32, device="cpu", pin_memory=True) + page_readies = torch.tensor(ready_list, dtype=torch.bool, device="cpu", pin_memory=True) + + token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0 : req.cur_kv_len] + offload_gpu_kv_to_cpu( + token_indexes=token_indexes, + gpu_kv_cache=self.backend.model.mem_manager.kv_buffer, + cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor, + page_indexes=page_indexes, + page_readies=page_readies, + ) + + # 用一个allreduce 操作和 sync_event 来确保所有gpu worker都完成对cpu kv cache的写入。 + dist.all_reduce(tensor=self.sync_tensor, group=self.sync_group, async_op=False) + sync_event = torch.cuda.Event() + sync_event.record() + req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING + trans_task = TransTask( + page_indexes=page_indexes, page_readies=page_readies, req_obj=req, sync_event=sync_event + ) + + return trans_task + + def update_cpu_cache_task_states(self): + if self.backend.is_master_in_dp: + trans_ok_tasks = [] + while len(self.cpu_cache_handle_queue) != 0: + task: TransTask = self.cpu_cache_handle_queue.popleft() + if task.sync_event.query(): + trans_ok_tasks.append(task) + else: + self.cpu_cache_handle_queue.appendleft(task) + break + item_size = len(trans_ok_tasks) + dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0) + + else: + recv_list = [None] + dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0) + item_size = recv_list[0] + trans_ok_tasks: List[TransTask] = [self.cpu_cache_handle_queue.popleft() for _ in range(item_size)] + + if item_size > 0: + page_array_list = [task.page_indexes for task in trans_ok_tasks] + page_list = torch.cat(page_array_list, dim=0).tolist() + if self.backend.is_master_in_dp: + self.cpu_cache_client.lock.acquire_sleep1ms() + self.cpu_cache_client.update_pages_status_to_ready( + page_list=page_list, deref=True, disk_offload_enable=self.args.enable_disk_cache + ) + self.cpu_cache_client.lock.release() + for task in trans_ok_tasks: + task.req_obj.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + return + + def fill_cpu_cache_to_reqs(self, reqs: List[InferReq]): + idle_token_num = g_infer_context.get_can_alloc_token_num() + token_page_size = self.args.cpu_cache_token_page_size + all_page_list = [] + is_master_in_dp = self.backend.is_master_in_dp + for req in reqs: + page_list = req.shm_req.cpu_cache_match_page_indexes.get_all() + match_tokens = len(page_list) * token_page_size + # 更新命中的 cpu 
kv cache 长度. + if is_master_in_dp: + req.shm_req.cpu_prompt_cache_len = match_tokens + + need_token_num = match_tokens - req.cur_kv_len + # 多匹配了一定数量的token 才进行复制操作,不然操作效率不高 + if need_token_num > 256: + if need_token_num <= idle_token_num: + if self.backend.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num) + + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num) + + # 将 cpu page 的内容拷贝到 gpu 页面中 + load_cpu_kv_to_gpu( + mem_indexes=mem_indexes, + gpu_kv_cache=self.backend.model.mem_manager.kv_buffer, + cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor, + page_indexes=torch.tensor(page_list, dtype=torch.int32, device="cpu").cuda(non_blocking=True), + ) + + torch.cuda.current_stream().synchronize() + + idle_token_num -= need_token_num + g_infer_context.req_manager.req_to_token_indexs[ + req.req_idx, req.cur_kv_len : (req.cur_kv_len + need_token_num) + ] = mem_indexes + req.cur_kv_len = req.cur_kv_len + need_token_num + if self.backend.is_master_in_dp: + req.shm_req.shm_cur_kv_len = req.cur_kv_len + + all_page_list.extend(page_list) + + dist.barrier(group=self.init_sync_group) + + if self.backend.is_master_in_dp: + self.cpu_cache_client.lock.acquire_sleep1ms() + self.cpu_cache_client.deref_pages(page_list=all_page_list) + self.cpu_cache_client.lock.release() + return + + +@dataclasses.dataclass +class TransTask: + page_indexes: torch.Tensor + page_readies: torch.Tensor + req_obj: InferReq + sync_event: torch.cuda.Event diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py index abddea356..913e8ab67 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py @@ -2,6 +2,7 @@ from typing import List from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue +from lightllm.utils.envs_utils import get_env_start_args class ChunkedBeamContinuesBatchQueue(BaseQueue): @@ -90,8 +91,13 @@ def generate_new_batch(self, current_batch: Batch): new_batch_first_router_need_tokens = 0 # 主要是对 prefill 大块计算时候的token数量限制 aborted_count = 0 cur_group_reqs = [] + # 在开启 cpu cache 功能的情况下,由于multi_level_kv_cache 模块会对请求申请一些cpu kv cache + # 页面,这些页面的释放是在推理进程中完成的,所以如果直接在调度的时候就退出,会导致这些页面无法回收 + # ,所以在使能 cpu cache 的情况下,不在调度的过程中进行 cpu cache页面的释放,而是延迟到推理的 + # 过程中进行回收 + disable_queue_aborted = get_env_start_args().enable_cpu_cache for req in self.waiting_req_list: - if req.is_aborted: + if req.is_aborted and not disable_queue_aborted: aborted_count += 1 abort_req_list.append(req) continue diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 660271ab6..93bf74777 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -3,6 +3,7 @@ from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.common.basemodel.infer_lock import g_router_lock +from lightllm.utils.envs_utils import get_env_start_args class ChunkedPrefillQueue(BaseQueue): @@ -76,9 +77,13 @@ def generate_new_batch(self, current_batch: Batch): aborted_count = 0 waiting_queue = self.waiting_req_list - + # 在开启 cpu cache 功能的情况下,由于multi_level_kv_cache 模块会对请求申请一些cpu kv cache + # 页面,这些页面的释放是在推理进程中完成的,所以如果直接在调度的时候就退出,会导致这些页面无法回收 + # ,所以在使能 cpu cache 的情况下,不在调度的过程中进行 
cpu cache页面的释放,而是延迟到推理的 + # 过程中进行回收 + disable_queue_aborted = get_env_start_args().enable_cpu_cache for req in waiting_queue: - if req.is_aborted: + if req.is_aborted and not disable_queue_aborted: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 4a3dec826..bc73adf03 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -7,7 +7,7 @@ import inspect from typing import List from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem @@ -24,20 +24,26 @@ class VisualManager: def __init__( self, - args, - next_module_port, - visual_port, - cache_port, + args: StartArgs, visual_model_rpc_ports, ): context = zmq.Context(2) - self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) - self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}") - self.recv_from_httpserver = context.socket(zmq.PULL) - self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}") - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.cache_port = cache_port + if args.enable_multimodal_audio: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + else: + if args.enable_cpu_cache: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") + else: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") + + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) + self.cache_port = args.cache_port self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp @@ -156,7 +162,7 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): self.waiting_reqs.append(recv_req) else: @@ -175,12 +181,12 @@ def clean_up(self): return -def start_visual_process(args, next_module_port, visual_port, cache_port, model_rpc_ports, pipe_writer): +def start_visual_process(args, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) start_parent_check_thread() try: - visualserver = VisualManager(args, next_module_port, visual_port, cache_port, model_rpc_ports) + visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py new file mode 100644 index 000000000..fbc2d819f --- /dev/null +++ 
b/lightllm/utils/kv_cache_utils.py @@ -0,0 +1,184 @@ +import ctypes +import dataclasses +import xxhash +import numpy as np +from functools import lru_cache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger +from lightllm.utils.config_utils import get_config_json +from typing import List, Tuple, Optional + +logger = init_logger(__name__) + + +def compute_token_list_hash(tokens: List[int], cpu_cache_token_page_size: int) -> List[int]: + if len(tokens) == 0: + return [] + + chunks_hash_value = [] + hsum = xxhash.xxh3_64() + + # 计算每个分块的哈希值, 但是输入token需要少一个,因为 + # 如果计算所有的token,会导致输入input_len 命中全长的 + # cpu cache, 导致prefill 过程无法有输入来导出下一个输出。 + calcu_num = (len(tokens) - 1) // cpu_cache_token_page_size + + for i in range(calcu_num): + start_index = i * cpu_cache_token_page_size + end_index = (i + 1) * cpu_cache_token_page_size + chunk = tokens[start_index:end_index] + chunk_np = np.array(chunk, dtype=np.uint64) + hsum.update(chunk_np.tobytes()) + + hash_value = hsum.intdigest() + chunks_hash_value.append(hash_value) + + return chunks_hash_value + + +@lru_cache(maxsize=None) +def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": + args = get_env_start_args() + assert args.enable_cpu_cache + model_config = get_config_json(args.model_dir) + item_size = 2 + head_dim = model_config["hidden_size"] // model_config["num_attention_heads"] + num_key_value_heads = model_config["num_key_value_heads"] * 2 # key and value + layer_num = model_config["num_hidden_layers"] + + one_token_byte_size = layer_num * num_key_value_heads * head_dim * item_size + one_page_byte_size = args.cpu_cache_token_page_size * one_token_byte_size + cpu_cache_page_num = int((args.cpu_cache_storage_size * 1024 * 1024 * 1024) / one_page_byte_size) + + cpu_cache_meta = CpuKVCacheMeta( + page_num=cpu_cache_page_num, + layer_num=layer_num, + token_page_size=args.cpu_cache_token_page_size, + num_heads=num_key_value_heads, + head_dim=head_dim, + item_size=item_size, + ) + + logger.info(f"cpu kv cache page num: {cpu_cache_meta.page_num}") + + return cpu_cache_meta + + +@lru_cache(maxsize=None) +def create_shm_kv_cache_ptr() -> int: + args = get_env_start_args() + + # 加载 libc + libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6") + + # 设置 shmget 函数的参数类型和返回类型 + libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int) + libc.shmget.restype = ctypes.c_int + + # 设置 shmat 函数的参数类型和返回类型 + libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int) + libc.shmat.restype = ctypes.c_void_p + + # 创建共享内存 + key = args.cpu_kv_cache_shm_id # 共享内存的键 + size = calcu_cpu_cache_meta().calcu_size() # 共享内存大小 + shmflg = 0o666 | 0o1000 # 权限和 IPC_CREAT 标志 + + shmid = libc.shmget(key, size, shmflg) + + if shmid < 0: + raise Exception("Error creating shared memory") + + logger.info(f"Shared memory ID: {shmid}") + + # 附加共享内存 + shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0) + + if shm_addr == ctypes.c_void_p(-1).value: + raise Exception("Error attaching shared memory") + + logger.info(f"Shared memory attached at address: {shm_addr}") + + return shm_addr + + +@lru_cache(maxsize=None) +def attach_shm_kv_cache_ptr() -> int: + """ + Attach to the shared memory segment with the given shmid. 
+ """ + args = get_env_start_args() + libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6") + + # 设置 shmget 和 shmat 函数的参数类型和返回类型 + libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int) + libc.shmget.restype = ctypes.c_int + libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int) + libc.shmat.restype = ctypes.c_void_p + + # 通过键获取共享内存 ID + key = args.cpu_kv_cache_shm_id # 共享内存的键 + shmid = libc.shmget(key, 0, 0) + + if shmid < 0: + raise Exception("Error getting shared memory") + + logger.info(f"Shared memory ID: {shmid}") + + # 附加共享内存 + shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0) + + if shm_addr == ctypes.c_void_p(-1).value: + raise Exception("Error attaching shared memory") + + logger.info(f"Shared memory attached at address: {shm_addr}") + return shm_addr + + +@dataclasses.dataclass +class CpuKVCacheMeta: + page_num: int + layer_num: int + token_page_size: int + num_heads: int + head_dim: int + item_size: int + + def calcu_size(self): + return self.page_num * self.layer_num * self.token_page_size * self.num_heads * self.head_dim * self.item_size + + +def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> int: + # 加载 CUDA 库 + cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so") # Linux 下的 CUDA 库路径 + + # 定义 cudaHostRegister 函数的参数和返回类型 + cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] + cuda.cudaHostRegister.restype = ctypes.c_int + + # 定义 cudaHostGetDevicePointer 函数原型 + cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int] + cuda.cudaHostGetDevicePointer.restype = ctypes.c_int + + # 定义常量 + cudaHostRegisterDefault = 0 # 默认注册标志 + + # 调用 cudaHostRegister + result = cuda.cudaHostRegister(shm_ptr, size, cudaHostRegisterDefault) + + if result != 0: + raise Exception(f"Error registering host memory: {result}") + else: + logger.info("Host memory registered successfully.") + + device_ptr = ctypes.c_void_p() # 输出设备指针 + host_ptr = ctypes.c_void_p(shm_ptr) # 输入主机指针 + + result = cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0) + + if result != 0: + raise RuntimeError(f"cudaHostGetDevicePointer failed with error code {result}") + + logger.info(f"get Host memory registered Device ptr {device_ptr.value}") + + return device_ptr.value diff --git a/test.py b/test.py new file mode 100755 index 000000000..db8f87d5b --- /dev/null +++ b/test.py @@ -0,0 +1,150 @@ +import torch +import numpy as np +from torch.profiler import profile, record_function, ProfilerActivity + +data_o = torch.zeros((128 * 1024), dtype=torch.int32, device="cuda") +in_data = list(range(0, 1000)) +in_datas = [list(range(0, 1000)) for _ in range(100)] +import time + +cpu_tensor = torch.zeros((128 * 1024), dtype=torch.int32, device="cpu", pin_memory=False) +pin_mem_tensor = torch.zeros((128 * 1024), dtype=torch.int32, device="cpu", pin_memory=True) +gpu_tensor = torch.zeros((128 * 1024), dtype=torch.int32, device="cuda") + +a = torch.arange(0, 10).cuda() +b = torch.arange(0, 10).cuda() + +print((gpu_tensor == 1).dtype) +# max_data = tmp.max() +with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + profile_memory=False, + on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile/profile.file"), +) as prof: + # gpu_tensor[:] = pin_mem_tensor + # torch.cuda.current_stream().synchronize() + # a = torch.tensor([1,3, 7], device="cuda") + # gpu_tensor[:] = pin_mem_tensor + for _ in range(100): + cpu_tensor.cuda(non_blocking=True) + 
+print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=16), flush=True) + + +# CUDA_VISIBLE_DEVICES=4,5,6,7 LOADWORKER=16 python -m lightllm.server.api_server --port 8019 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 4 --dp 1 --diverse_mode | tee log.txt + +# CUDA_VISIBLE_DEVICES=4,5,6,7 LOADWORKER=16 python -m lightllm.server.api_server --port 8019 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 1 --dp 1 --diverse_mode | tee log.txt 你试试这个 + + +# CUDA_VISIBLE_DEVICES=4,5,6,7 LOADWORKER=16 python -m lightllm.server.api_server --port 8019 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 4 --dp 1 --output_constraint_mode xgrammar | tee log.txt + + +pin_mem_tensor.numpy()[0:10] = list(range(10)) + +print("ok") + +# CUDA_VISIBLE_DEVICES=4,5,6,7 LOADWORKER=16 python -m lightllm.server.api_server --port 8019 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 8 --dp 8 | tee log.txt 你试试这个 + + + + +# LOADWORKER=16 python -m lightllm.server.api_server --model_dir /mtc/DeepSeek-R1 --mtp_draft_model_dir /mtc/DeepSeek-R1-NextN/ --mtp_mode deepseekv3 --mtp_step 1 --enable_fa3 --graph_max_batch_size 64 --tp 8 --port 15001 | tee debug.txt + + +# LOADWORKER=16 python -m lightllm.server.api_server --model_dir /mtc/DeepSeek-R1 --mtp_draft_model_dir /mtc/DeepSeek-R1-NextN/ --mtp_mode deepseekv3 --mtp_step 1 --enable_fa3 --graph_max_batch_size 64 --tp 8 --port 15001 | tee debug.txt + + +# LOADWORKER=16 python -m lightllm.server.api_server --model_dir /mtc/DeepSeek-R1 --mtp_draft_model_dir /mtc/DeepSeek-R1-NextN/ --mtp_mode deepseekv3 --mtp_step 1 --enable_fa3 --graph_max_batch_size 64 --tp 8 --port 15001 | tee debug.txt + + +# MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8019 \ +# --model_dir /mtc/DeepSeek-R1 \ +# --tp 8 \ +# --dp 8 \ +# --enable_fa3 \ +# --enable_prefill_microbatch_overlap \ +# --enable_decode_microbatch_overlap \ +# --mem_fraction 0.8 \ +# --batch_max_tokens 4096 + + +# MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8019 \ +# --model_dir /mtc/DeepSeek-R1 \ +# --tp 8 \ +# --dp 8 \ +# --enable_fa3 \ +# --mem_fraction 0.8 \ +# --batch_max_tokens 4096 \ +# --mtp_draft_model_dir /mtc/DeepSeek-R1-NextN/ --mtp_mode deepseekv3 --mtp_step 1 + + + +# CUDA_VISIBLE_DEVICES=0,1 LOADWORKER=18 python -m lightllm.server.api_server --port 8019 \ +# --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ \ +# --tp 4 \ +# --enable_fa3 \ +# --nnodes 2 \ +# --node_rank 0 \ +# --nccl_host 127.0.0.1 \ +# --nccl_port 2732 + +# CUDA_VISIBLE_DEVICES=2,3 LOADWORKER=18 python -m lightllm.server.api_server --port 8021 \ +# --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ \ +# --tp 4 \ +# --enable_fa3 \ +# --nnodes 2 \ +# --node_rank 1 \ +# --nccl_host 127.0.0.1 \ +# --nccl_port 2732 + + +# python -m lightllm.server.api_server --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --run_mode "pd_master" --host 127.0.0.1 --port 60011 + + +# CUDA_VISIBLE_DEVICES=0,1 MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ +# --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ \ +# --run_mode "prefill" \ +# --tp 2 \ +# --dp 1 \ +# --host 0.0.0.0 \ +# --port 8019 \ +# --nccl_port 2732 \ +# --enable_fa3 \ +# --disable_cudagraph \ +# --pd_master_ip 127.0.0.1 \ +# --pd_master_port 60011 + + +# CUDA_VISIBLE_DEVICES=2,3 MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ +# --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ \ +# 
--run_mode "decode" \ +# --tp 2 \ +# --dp 1 \ +# --host 0.0.0.0 \ +# --port 8121 \ +# --nccl_port 27321 \ +# --enable_fa3 \ +# --pd_master_ip 127.0.0.1 \ +# --pd_master_port 60011 + + +# CUDA_VISIBLE_DEVICES=0,1 LOADWORKER=16 python -m lightllm.server.api_server --port 8019 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 4 --dp 4 --nccl_port 27321 --node_rank 0 --nnodes 2 | tee log.txt + +# CUDA_VISIBLE_DEVICES=2,3 LOADWORKER=16 python -m lightllm.server.api_server --port 8011 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 4 --dp 4 --nccl_port 27321 --node_rank 1 --nnodes 2 + + +# LOADWORKER=16 python -m lightllm.server.api_server --port 8019 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 8 --dp 1 --enable_fa3 + + + +# lightllm v1.0.1-4209c8c4-deepep + + +# docker run -itd --gpus all --privileged=true --shm-size=128G -v /mtc:/mtc --name wzj 44feca8a0c86 + + +# CUDA_VISIBLE_DEVICES=2,3 LOADWORKER=16 python -m lightllm.server.api_server --port 8011 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 4 --dp 4 --nccl_port 27321 --node_rank 1 --nnodes 2 + + +# LOADWORKER=16 python -m lightllm.server.api_server --port 8011 --model_dir /mtc/niushengxiao/Qwen/Qwen2.5-14B-Instruct/ --tp 1 --dp 1 --nccl_port 27321 --enable_cpu_cache \ No newline at end of file