diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 241fc1447ce..4901c8027ad 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -34,6 +34,7 @@ def __init__(
         attn_metadata: AttentionMetadata,
         spec_metadata: Optional[SpecMetadata] = None,
         use_mrope: bool = False,
+        lora_params: Optional[dict] = None,
     ) -> None:
         """
         Stores a CUDA graph and its associated input buffers.
@@ -68,6 +69,7 @@ def __init__(
         self.attn_metadata = attn_metadata
         self.spec_metadata = spec_metadata
+        self.lora_params = lora_params
         self._output = None
         self._graph = None
         self.optional_extra_model_inputs = ["mrope_position_deltas"]
@@ -90,6 +92,9 @@ def capture(
             "mrope_position_deltas": self.mrope_position_deltas,
         }
 
+        if self.lora_params is not None:
+            inputs["lora_params"] = self.lora_params
+
         # We have to do warm up runs to initialize PyTorch's
         # internal states according to the docs:
         # https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 98eb2e870d4..8fa9db806f2 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -23,7 +23,7 @@
                              torch_dtype_to_str, trace_func)
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
+from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2
@@ -282,6 +282,15 @@ def __init__(
             )
         attn_backend = pytorch_backend_config.attn_backend
 
+        self.lora_manager: Optional[LoraManager] = None
+        if lora_config is not None:
+            self.lora_manager = LoraManager()
+
+        self.lora_prefetch_requests_list = None  # TODO(smor): fix the "LoRARequest" import so this list can be typed
+        if lora_config is not None and lora_config.lora_request is not None:
+            self.lora_prefetch_requests_list = lora_config.lora_request
+        self.has_lora_prefetched = False
+
         self.model = self._load_model(
             model_path,
             mapping=self.mapping,
@@ -427,6 +436,27 @@ def set_lora_model_config(self, lora_target_modules: list[str],
             hidden_size=self.model.config.hidden_size,
             dtype=torch_dtype_to_str(self.model.config.torch_dtype))
 
+    def set_lora_manager_cpp_peft_cache_manager(
+            self, resource_manager: ResourceManager):
+        cpp_peft_cache_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.PEFT_CACHE_MANAGER)
+        if cpp_peft_cache_manager is not None and self.lora_manager is not None:
+            self.lora_manager.set_cpp_peft_cache_manager(
+                cpp_peft_cache_manager.impl)
+
+    def prefetch_lora_dirs(self):
+        if self.lora_prefetch_requests_list is None:
+            return
+
+        for request in self.lora_prefetch_requests_list:
+            self.lora_manager.load_from_ckpt(
+                [request.path],
+                model_config=self.lora_model_config,
+                runtime_mapping=None,
+                uids=[request.adapter_id])
+
+        self.has_lora_prefetched = True
+
     @property
     def use_mrope(self):
         use_mrope = False
@@ -481,6 +511,16 @@ def warmup(self, resource_manager: ResourceManager) -> None:
             self.cuda_graph_dummy_request = None
 
         def get_cuda_graph_warmup_request(batch_size):
+            lora_config = None
+            if self.has_lora_prefetched:
+                # TODO(smor): currently assumes a single prefetched adapter with uid 0; generalize this
+                uid = 0
+                from tensorrt_llm.bindings import executor as tllm
+                lora_config = tllm.LoraConfig(
+                    task_id=uid,
+                    weights=self.lora_manager.cpp_lora_weights[uid],
+                    config=self.lora_manager.cpp_lora_config[uid])
+
             available_blocks = kv_cache_manager.get_num_free_blocks()
             if available_blocks >= batch_size:
                 result = ScheduledRequests()
@@ -492,6 +532,8 @@ def get_cuda_graph_warmup_request(batch_size):
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
                     use_mrope=use_mrope,
+                    lora_request=
+                    lora_config,  # TODO(smor): tests assume BS1, so this is ignored for now; needs to be resolved
                 )
                 available_tokens = kv_cache_manager.get_num_available_tokens(
                     self.max_draft_len)
@@ -505,6 +547,7 @@ def get_cuda_graph_warmup_request(batch_size):
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
                     use_mrope=use_mrope,
+                    lora_request=lora_config,
                 )[0]
             # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
             # This batch contains both the longest request and the shortest requests,
@@ -863,7 +906,8 @@ def _round_up_batch_size(self, batch_size: int) -> int:
     def _maybe_get_cuda_graph(
             self,
             batch: ScheduledRequests,
-            spec_config: Optional["DecodingBaseConfig"] = None
+            spec_config: Optional["DecodingBaseConfig"] = None,
+            resource_manager: Optional[ResourceManager] = None
     ) -> Optional[DecodingCUDAGraphRunner]:
         """
         Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled
@@ -908,8 +952,57 @@ def _maybe_get_cuda_graph(
         else:
             spec_metadata = None
 
+        lora_params = None
+
+        if self.has_lora_prefetched:
+            peft_cache_manager = resource_manager.get_resource_manager(
+                ResourceManagerType.PEFT_CACHE_MANAGER)
+
+            context_requests = batch.context_requests
+            generation_requests = batch.generation_requests
+
+            if len(context_requests) > 0 and len(generation_requests) > 0:
+                raise ValueError(
+                    "Mixed context and generation requests are not tested yet with LoRA CUDA graphs"
+                )
+
+            if len(context_requests) > 0:
+                raise ValueError("Context requests are not tested yet with LoRA CUDA graphs")
+
+            if len(generation_requests) > 1:
+                raise ValueError("More than one generation request is not tested yet with LoRA CUDA graphs")
+
+            generation_request = generation_requests[0]
+            # TODO(smor): figure out why lora_weights/lora_config arrive without a leading batch dimension here
+            generation_request.lora_weights = generation_request.lora_weights.reshape(
+                [1] + list(generation_request.lora_weights.shape))
+            generation_request.lora_config = generation_request.lora_config.reshape(
+                [1] + list(generation_request.lora_config.shape))
+            peft_cache_manager.impl.add_request_peft(generation_request, True)
+
+            py_lora_task_layer_module_configs = peft_cache_manager.impl.ensure_batch(
+                context_requests, generation_requests, False)
+            for req in context_requests:
+                req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
+                    req.
+                    py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
+            for req in generation_requests:
+                req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
+                    req.
+                    py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
+
+            # TODO(smor): revisit how lora_params are built from the requests here.
+            # This needs the request objects rather than the scheduled batch, and the
+            # resource manager must also be involved so that prefetched adapters are
+            # registered with the PeftCacheManager, so this path is still incomplete.
+
+            lora_params = self._get_lora_params_from_requests(
+                batch, attn_metadata)
+            logger.debug("Prepared lora_params for CUDA graph capture")
+
         self._cuda_graphs[batch_size] = DecodingCUDAGraphRunner(
-            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope)
+            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope,
+            lora_params)
         return self._cuda_graphs[batch_size]
 
     def __del__(self) -> None:
@@ -2040,7 +2133,9 @@ def forward(
         with self._maybe_pad_batch(scheduled_requests,
                                    kv_cache_manager) as scheduled_requests:
             maybe_graph = self._maybe_get_cuda_graph(
-                scheduled_requests, spec_config=self.spec_config)
+                scheduled_requests,
+                spec_config=self.spec_config,
+                resource_manager=resource_manager)
             if maybe_graph is not None:
                 attn_metadata = maybe_graph.attn_metadata
                 if self.is_spec_decode:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index e5b302310fc..f6413a4f4a9 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -279,6 +279,9 @@ def __init__(self,
         self.inflight_req_ids = ReqIdsSet()
         self.canceled_req_ids = []
 
+        self.model_engine.set_lora_manager_cpp_peft_cache_manager(
+            self.resource_manager)
+        self.model_engine.prefetch_lora_dirs()
         self.model_engine.warmup(self.resource_manager)
         if self.draft_model_engine is not None:
             self.draft_model_engine.warmup(self.resource_manager)
@@ -316,6 +319,9 @@ def __init__(self,
         if start_worker:
             self.start_worker()
 
+    def get_lora_manager(self):
+        return self.model_engine.lora_manager
+
     def _event_loop_wrapper(self):
         try:
             with customized_gc_thresholds(
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index ecb58efc25c..cfad07cefea 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -373,6 +373,7 @@ def add_dummy_requests(
         prepare_resource: bool = True,
         max_num_draft_tokens: int = 0,
         use_mrope: bool = False,
+        lora_request=None,
     ):
         beam_width = 1  # TODO: more than 1 beam?
         requests = []
@@ -389,6 +390,17 @@ def add_dummy_requests(
             # Using 1 instead of 0 prevents NaN during warmup in e.g. Deepseek
             mrope_position_deltas = torch.zeros(
                 1, device="cuda", dtype=torch.int32) if use_mrope else None
+
+            lora_task_id = None
+            lora_weights = None
+            lora_config = None
+
+            if lora_request is not None:
+                # TODO(smor): currently supports a single adapter only; unclear how this should map to request ids
+                lora_task_id = lora_request.task_id
+                lora_weights = lora_request.weights
+                lora_config = lora_request.config
+
             req = LlmRequest(request_id=req_id,
                              max_new_tokens=1,
                              input_tokens=[1] * token_num,
@@ -396,7 +408,10 @@ def add_dummy_requests(
                                  sampling_params._get_sampling_config()),
                              is_streaming=False,
                              mrope_position_deltas=mrope_position_deltas,
-                             encoder_input_tokens=encoder_input_tokens)
+                             encoder_input_tokens=encoder_input_tokens,
+                             lora_task_id=lora_task_id,
+                             lora_weights=lora_weights,
+                             lora_config=lora_config)
             req.is_dummy_request = True
             req.paged_kv_block_ids = []
             if prepare_resource:
diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index aa793d30ea6..a1d3d216253 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -161,12 +161,7 @@ def _create_engine():
 
         if getattr(executor_config, "backend",
                    "") == "pytorch" and lora_config is not None:
-            from tensorrt_llm._torch.pyexecutor.resource_manager import \
-                ResourceManagerType
-            peft_cache_manager = self.engine.resource_manager.resource_managers.get(
-                ResourceManagerType.PEFT_CACHE_MANAGER)
-            self._lora_manager = LoraManager(
-                cpp_peft_cache_manager=peft_cache_manager.impl)
+            self._lora_manager = self.engine.get_lora_manager()
             lora_model_config = self.engine.model_engine.lora_model_config
             assert lora_model_config is not None
             self._lora_model_config = lora_model_config
diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py
index 3f87286024b..629807e11b6 100644
--- a/tensorrt_llm/lora_manager.py
+++ b/tensorrt_llm/lora_manager.py
@@ -5,7 +5,7 @@
 from collections import defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -146,6 +146,7 @@ class LoraConfig(DictConversion):
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
     max_loras: int = 4
     max_cpu_loras: int = 4
+    lora_request: Optional[List[Any]] = None  # TODO(smor): type as List[LoRARequest] once it can be imported here
 
     def __post_init__(self):
         assert self.lora_ckpt_source in ["hf", "nemo"], (
@@ -483,6 +484,11 @@ def __init__(
         self._cpp_lora_weights: Dict[str, torch.Tensor] = {}  # on cpu
         self._cpp_lora_config: Dict[str, torch.Tensor] = {}  # on cpu
         self.lora_target_modules: List[str] = []
+        self._cpp_peft_cache_manager: Optional[tb_internal.batch_manager.PeftCacheManager] = None
+
+    def set_cpp_peft_cache_manager(
+            self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager):
+        self._cpp_peft_cache_manager = cpp_peft_cache_manager
 
     def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index dd6d2b4be31..bda384e1748 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -1,6 +1,7 @@
 import pytest
 
 from tensorrt_llm import LLM
+from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
 from tensorrt_llm.sampling_params import SamplingParams
 
@@ -430,3 +431,33 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
                       lora_request=lora_requests)
 
     assert len(outputs) == 2
+
+
+def test_lora_dir_with_graph():
+    lora_req = LoRARequest(
+        "task-0", 0, f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1")
+
+    lora_config = LoraConfig(
+        lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"],
+        max_lora_rank=8,
+        lora_request=[lora_req])
+
+    llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
+              lora_config=lora_config,
+              cuda_graph_config=CudaGraphConfig(max_batch_size=1))
+    # cuda_graph_config=None)
+
+    prompts = [
+        "美国的首都在哪里? \n答案:",
+    ]
+    references = [
+        "美国的首都是华盛顿。\n\n美国的",
+    ]
+    sampling_params = SamplingParams(max_tokens=20)
+    lora_request = [lora_req]
+
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
+
+    assert similar(outputs[0].outputs[0].text, references[0])
+    print(f"lora output: {outputs[0].outputs[0].text}")
+    print(f"ref output: {references[0]}")
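
For reference, a minimal usage sketch (not part of the diff) of the flow this change enables: declaring a LoRA adapter through the new LoraConfig.lora_request field so the engine can prefetch it and capture it together with CUDA graphs. It condenses test_lora_dir_with_graph above; the model and adapter paths are placeholders, the LoRARequest import path is an assumption, and the WIP code currently handles a single prefetched adapter at batch size 1.

# Sketch only: paths are placeholders and the LoRARequest import location is assumed.
from tensorrt_llm import LLM
from tensorrt_llm.executor.request import LoRARequest  # assumed import path
from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.sampling_params import SamplingParams

ADAPTER_DIR = "/path/to/luotuo-lora-7b-0.1"  # placeholder
BASE_MODEL = "/path/to/llama-7b-hf"          # placeholder

# Register the adapter up front so the engine can prefetch it before warmup
# and include the LoRA inputs in the captured CUDA graph.
lora_req = LoRARequest("task-0", 0, ADAPTER_DIR)
lora_config = LoraConfig(lora_dir=[ADAPTER_DIR],
                         max_lora_rank=8,
                         lora_request=[lora_req])

llm = LLM(model=BASE_MODEL,
          lora_config=lora_config,
          cuda_graph_config=CudaGraphConfig(max_batch_size=1))

# Generation still passes the same LoRARequest per prompt.
outputs = llm.generate(["美国的首都在哪里? \n答案:"],
                       SamplingParams(max_tokens=20),
                       lora_request=[lora_req])
print(outputs[0].outputs[0].text)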