-
Notifications
You must be signed in to change notification settings - Fork 2k
Feat/support lora cuda graph #7335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,7 +23,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch_dtype_to_str, trace_func) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.inputs.multimodal import MultimodalParams | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.logger import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.mapping import Mapping | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.models.modeling_utils import QuantAlgo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -282,6 +282,15 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attn_backend = pytorch_backend_config.attn_backend | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.lora_manager: Optional[LoraManager] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if lora_config is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.lora_manager = LoraManager() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.lora_prefetch_requests_list = None # TODO smor - fix "LoRARequest" import | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if lora_config is not None and lora_config.lora_request is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.lora_prefetch_requests_list = lora_config.lora_request | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.has_lora_prefetched = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+285
to
+293
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Initialize has_lora_prefetched unconditionally; avoid AttributeError. Currently only set when lora_request is not None but later read unguarded. - self.lora_manager: Optional[LoraManager] = None
+ self.lora_manager: Optional[lora_mgr.LoraManager] = None
if lora_config is not None:
- self.lora_manager = LoraManager()
+ self.lora_manager = lora_mgr.LoraManager()
- self.lora_prefetch_requests_list = None # TODO smor - fix "LoRARequest" import
- if lora_config is not None and lora_config.lora_request is not None:
- self.lora_prefetch_requests_list = lora_config.lora_request
- self.has_lora_prefetched = False
+ self.lora_prefetch_requests_list = None # LoRA prefetch requests (bindings executor side)
+ self.has_lora_prefetched = False
+ if lora_config is not None and getattr(lora_config, "lora_request", None):
+ self.lora_prefetch_requests_list = lora_config.lora_request📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = self._load_model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mapping=self.mapping, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -427,6 +436,27 @@ def set_lora_model_config(self, lora_target_modules: list[str], | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_size=self.model.config.hidden_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=torch_dtype_to_str(self.model.config.torch_dtype)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def set_lora_manager_cpp_peft_cache_manager( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, resource_manager: ResourceManager): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cpp_peft_cache_manager = resource_manager.get_resource_manager( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ResourceManagerType.PEFT_CACHE_MANAGER) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if cpp_peft_cache_manager is not None and self.lora_manager is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.lora_manager.set_cpp_peft_cache_manager( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cpp_peft_cache_manager.impl) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+439
to
+446
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Don’t reach into .impl; use the manager’s public API consistently. This helper should set the cpp manager via a public setter, but consumers must then call the Python PeftCacheManager methods, not impl. No change here; see refactor below in _maybe_get_cuda_graph to stop using impl. 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def prefetch_lora_dirs(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.lora_prefetch_requests_list is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for request in self.lora_prefetch_requests_list: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.lora_manager.load_from_ckpt( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [request.path], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_config=self.lora_model_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| runtime_mapping=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uids=[request.adapter_id]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.has_lora_prefetched = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def use_mrope(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_mrope = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -481,6 +511,16 @@ def warmup(self, resource_manager: ResourceManager) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.cuda_graph_dummy_request = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_cuda_graph_warmup_request(batch_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_config = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.has_lora_prefetched: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO smor currently I assume a single adapter with uid 0, change this | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uid = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tensorrt_llm.bindings import executor as tllm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_config = tllm.LoraConfig( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_id=uid, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights=self.lora_manager.cpp_lora_weights[uid], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config=self.lora_manager.cpp_lora_config[uid]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+514
to
+523
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Avoid shadowing LoraConfig; remove hard-coded uid=0. Use a distinct name for the bindings object and derive uid from prefetch requests. - lora_config = tllm.LoraConfig(
- task_id=uid,
- weights=self.lora_manager.cpp_lora_weights[uid],
- config=self.lora_manager.cpp_lora_config[uid])
+ lora_binding = tllm.LoraConfig(
+ task_id=uid,
+ weights=self.lora_manager.cpp_lora_weights[uid],
+ config=self.lora_manager.cpp_lora_config[uid])And when passing to add_dummy_requests: - lora_request=
- lora_config, # TODO smor- tests assume BS1 then this will be ignored for now, need to resolve
+ lora_request=lora_binding,Also, compute uid: - uid = 0
+ # Prefer the first prefetched adapter id
+ uid = getattr(self.lora_prefetch_requests_list[0], "adapter_id", 0)📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| available_blocks = kv_cache_manager.get_num_free_blocks() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if available_blocks >= batch_size: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = ScheduledRequests() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -492,6 +532,8 @@ def get_cuda_graph_warmup_request(batch_size): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_gen=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_num_draft_tokens=self.max_draft_len, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_mrope=use_mrope, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_request= | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_config, # TODO smor- tests assume BS1 then this will be ignored for now, need to resolve | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| available_tokens = kv_cache_manager.get_num_available_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.max_draft_len) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -505,6 +547,7 @@ def get_cuda_graph_warmup_request(batch_size): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_gen=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_num_draft_tokens=self.max_draft_len, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_mrope=use_mrope, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_request=lora_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| )[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # This batch contains both the longest request and the shortest requests, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -863,7 +906,8 @@ def _round_up_batch_size(self, batch_size: int) -> int: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _maybe_get_cuda_graph( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch: ScheduledRequests, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| spec_config: Optional["DecodingBaseConfig"] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| spec_config: Optional["DecodingBaseConfig"] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resource_manager: Optional[ResourceManager] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Optional[DecodingCUDAGraphRunner]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
906
to
911
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix Ruff F821: DecodingBaseConfig undefined (even in quotes). Gate a type-only import to satisfy static analysis without runtime dep. from typing import Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from ..speculative.decoding_config import DecodingBaseConfig📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.12.2)909-909: Undefined name (F821) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -908,8 +952,57 @@ def _maybe_get_cuda_graph( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| spec_metadata = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_params = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.has_lora_prefetched: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| peft_cache_manager = resource_manager.get_resource_manager( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ResourceManagerType.PEFT_CACHE_MANAGER) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context_requests = batch.context_requests | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_requests = batch.generation_requests | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(context_requests) > 0 and len(generation_requests) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "SMOR, non empty context and generation requests isn't tested yet" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(context_requests) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("SMOR, context requests isn't tested yet") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(generation_requests) > 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("SMOR, generation requests isn't tested yet") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_request = generation_requests[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO smor I have no idea why this is happening | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_request.lora_weights = generation_request.lora_weights.reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [1] + list(generation_request.lora_weights.shape)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_request.lora_config = generation_request.lora_config.reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [1] + list(generation_request.lora_config.shape)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| peft_cache_manager.impl.add_request_peft(generation_request, True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| py_lora_task_layer_module_configs = peft_cache_manager.impl.ensure_batch( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context_requests, generation_requests, False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for req in context_requests: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| req. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for req in generation_requests: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| req. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+955
to
+993
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion LoRA PEFT setup inside CUDA-graph path is brittle: direct .impl access, ad-hoc reshapes, and hard errors.
- lora_params = None
-
- if self.has_lora_prefetched:
- peft_cache_manager = resource_manager.get_resource_manager(
- ResourceManagerType.PEFT_CACHE_MANAGER)
-
- context_requests = batch.context_requests
- generation_requests = batch.generation_requests
-
- if len(context_requests) > 0 and len(generation_requests) > 0:
- raise ValueError(
- "SMOR, non empty context and generation requests isn't tested yet"
- )
-
- if len(context_requests) > 0:
- raise ValueError("SMOR, context requests isn't tested yet")
-
- if len(generation_requests) > 1:
- raise ValueError("SMOR, generation requests isn't tested yet")
-
- generation_request = generation_requests[0]
- # TODO smor I have no idea why this is happening
- generation_request.lora_weights = generation_request.lora_weights.reshape(
- [1] + list(generation_request.lora_weights.shape))
- generation_request.lora_config = generation_request.lora_config.reshape(
- [1] + list(generation_request.lora_config.shape))
- peft_cache_manager.impl.add_request_peft(generation_request, True)
-
- py_lora_task_layer_module_configs = peft_cache_manager.impl.ensure_batch(
- context_requests, generation_requests, False)
- for req in context_requests:
- req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
- req.
- py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
- for req in generation_requests:
- req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
- req.
- py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
-
- # TODO smor - look at get lora params from requests
- # You need something that isn't scheduled requests
- # It also appears that you should make sure resource manager is called, because prefetch
- # has to be added to peftCacheManager as well. So it still shouldn't work
-
- lora_params = self._get_lora_params_from_requests(
- batch, attn_metadata)
- print(f"SMOR, not failed on lora_params in maybe_get_cuda_graph")
+ lora_params = None
+ if self.has_lora_prefetched:
+ peft_cache_manager = resource_manager.get_resource_manager(
+ ResourceManagerType.PEFT_CACHE_MANAGER)
+ if peft_cache_manager is None:
+ logger.debug("LoRA prefetched, but no PEFT cache manager present; skipping LoRA for graphs.")
+ else:
+ # Only generation-only batches are CUDA-graphable today.
+ if len(batch.context_requests) == 0 and len(batch.generation_requests) >= 1:
+ for req in batch.generation_requests:
+ peft_cache_manager.add_request_peft(req)
+ py_cfgs = peft_cache_manager.ensure_batch(
+ batch.context_requests, batch.generation_requests, reset_gpu_cache=False)
+ for req in batch.generation_requests:
+ req.py_lora_task_layer_module_configs = py_cfgs.get(req.py_request_id)
+ lora_params = self._get_lora_params_from_requests(batch, attn_metadata)
+ else:
+ logger.debug("LoRA + CUDA graph currently supports generation-only batches; skipping LoRA params.")📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO smor - look at get lora params from requests | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # You need something that isn't scheduled requests | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # It also appears that you should make sure resource manager is called, because prefetch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # has to be added to peftCacheManager as well. So it still shouldn't work | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_params = self._get_lora_params_from_requests( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch, attn_metadata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"SMOR, not failed on lora_params in maybe_get_cuda_graph") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._cuda_graphs[batch_size] = DecodingCUDAGraphRunner( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lora_params) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self._cuda_graphs[batch_size] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __del__(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2040,7 +2133,9 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with self._maybe_pad_batch(scheduled_requests, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kv_cache_manager) as scheduled_requests: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| maybe_graph = self._maybe_get_cuda_graph( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scheduled_requests, spec_config=self.spec_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scheduled_requests, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| spec_config=self.spec_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resource_manager=resource_manager) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if maybe_graph is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attn_metadata = maybe_graph.attn_metadata | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.is_spec_decode: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -373,6 +373,7 @@ def add_dummy_requests( | |||||||||||||||||||||||||||||||||||||||
| prepare_resource: bool = True, | ||||||||||||||||||||||||||||||||||||||||
| max_num_draft_tokens: int = 0, | ||||||||||||||||||||||||||||||||||||||||
| use_mrope: bool = False, | ||||||||||||||||||||||||||||||||||||||||
| lora_request=None, | ||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+376
to
377
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Type and API contract for lora_request are unclear; initialize and type it. Document/annotate the expected type (bindings executor LoraConfig-like) and make has_lora_prefetched paths robust to None. Also initialize has_lora_prefetched at engine level to avoid AttributeError elsewhere. Apply: - lora_request=None,
+ lora_request: Optional[object] = None,
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||
| beam_width = 1 # TODO: more than 1 beam? | ||||||||||||||||||||||||||||||||||||||||
| requests = [] | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -389,14 +390,28 @@ def add_dummy_requests( | |||||||||||||||||||||||||||||||||||||||
| # Using 1 instead of 0 prevents NaN during warmup in e.g. Deepseek | ||||||||||||||||||||||||||||||||||||||||
| mrope_position_deltas = torch.zeros( | ||||||||||||||||||||||||||||||||||||||||
| 1, device="cuda", dtype=torch.int32) if use_mrope else None | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| lora_task_id = None | ||||||||||||||||||||||||||||||||||||||||
| lora_weights = None | ||||||||||||||||||||||||||||||||||||||||
| lora_config = None | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if lora_request is not None: | ||||||||||||||||||||||||||||||||||||||||
| # TODO smor currently work with single adapter only, not sure how this should work with request ids | ||||||||||||||||||||||||||||||||||||||||
| lora_task_id = lora_request.task_id | ||||||||||||||||||||||||||||||||||||||||
| lora_weights = lora_request.weights | ||||||||||||||||||||||||||||||||||||||||
| lora_config = lora_request.config | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+394
to
+403
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Guard against missing fields and wrong shapes on lora_request. Accessing task_id/weights/config blindly risks AttributeError; torch reshape later expects tensors. Add duck-typing checks and early validation. - lora_task_id = None
- lora_weights = None
- lora_config = None
-
- if lora_request is not None:
- # TODO smor currently work with single adapter only, not sure how this should work with request ids
- lora_task_id = lora_request.task_id
- lora_weights = lora_request.weights
- lora_config = lora_request.config
+ lora_task_id = None
+ lora_weights = None
+ lora_config = None
+ if lora_request is not None:
+ # Single-adapter warmup; multi-adapter not yet supported.
+ if not all(hasattr(lora_request, a) for a in ("task_id", "weights", "config")):
+ raise TypeError("lora_request must expose task_id, weights, and config")
+ lora_task_id = int(lora_request.task_id)
+ lora_weights = lora_request.weights
+ lora_config = lora_request.config📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||
| req = LlmRequest(request_id=req_id, | ||||||||||||||||||||||||||||||||||||||||
| max_new_tokens=1, | ||||||||||||||||||||||||||||||||||||||||
| input_tokens=[1] * token_num, | ||||||||||||||||||||||||||||||||||||||||
| sampling_config=SamplingConfig( | ||||||||||||||||||||||||||||||||||||||||
| sampling_params._get_sampling_config()), | ||||||||||||||||||||||||||||||||||||||||
| is_streaming=False, | ||||||||||||||||||||||||||||||||||||||||
| mrope_position_deltas=mrope_position_deltas, | ||||||||||||||||||||||||||||||||||||||||
| encoder_input_tokens=encoder_input_tokens) | ||||||||||||||||||||||||||||||||||||||||
| encoder_input_tokens=encoder_input_tokens, | ||||||||||||||||||||||||||||||||||||||||
| lora_task_id=lora_task_id, | ||||||||||||||||||||||||||||||||||||||||
| lora_weights=lora_weights, | ||||||||||||||||||||||||||||||||||||||||
| lora_config=lora_config) | ||||||||||||||||||||||||||||||||||||||||
| req.is_dummy_request = True | ||||||||||||||||||||||||||||||||||||||||
| req.paged_kv_block_ids = [] | ||||||||||||||||||||||||||||||||||||||||
| if prepare_resource: | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
| from collections import defaultdict | ||
| from dataclasses import dataclass, field | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Dict, List, Optional, Union | ||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
@@ -146,6 +146,7 @@ class LoraConfig(DictConversion): | |
| trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict) | ||
| max_loras: int = 4 | ||
| max_cpu_loras: int = 4 | ||
| lora_request: Optional[List[Any]] = None # TODO smor fix | ||
|
|
||
|
Comment on lines
+149
to
150
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Avoid Any for LoraConfig.lora_request; use typed forward reference to LoRARequest. Prevents loss of type-safety and documents intent, while avoiding import cycles with TYPE_CHECKING. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@
if TYPE_CHECKING:
from .runtime import ModelConfig
+ from .executor.request import LoRARequest
@@
- lora_request: Optional[List[Any]] = None # TODO smor fix
+ lora_request: Optional[List["LoRARequest"]] = None
🤖 Prompt for AI Agents |
||
| def __post_init__(self): | ||
| assert self.lora_ckpt_source in ["hf", "nemo"], ( | ||
|
|
@@ -483,6 +484,11 @@ def __init__( | |
| self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu | ||
| self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu | ||
| self.lora_target_modules: List[str] = [] | ||
| self._cpp_peft_cache_manager: Optional[tb_internal.batch_manager.PeftCacheManager] = None | ||
|
|
||
| def set_cpp_peft_cache_manager( | ||
| self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | ||
| ): | ||
| self._cpp_peft_cache_manager = cpp_peft_cache_manager | ||
|
|
||
| def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import pytest | ||
|
|
||
| from tensorrt_llm import LLM | ||
| from tensorrt_llm.llmapi.llm_args import CudaGraphConfig | ||
| from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer | ||
| from tensorrt_llm.sampling_params import SamplingParams | ||
|
|
||
|
|
@@ -430,3 +431,33 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None: | |
| lora_request=lora_requests) | ||
|
|
||
| assert len(outputs) == 2 | ||
|
|
||
|
|
||
| def test_lora_dir_with_graph(): | ||
| lora_req = LoRARequest( | ||
| "task-0", 0, f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1") | ||
|
|
||
| lora_config = LoraConfig( | ||
| lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"], | ||
| max_lora_rank=8, | ||
| lora_request=[lora_req]) | ||
|
|
||
| llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf", | ||
| lora_config=lora_config, | ||
| cuda_graph_config=CudaGraphConfig(max_batch_size=1)) | ||
| # cuda_graph_config=None) | ||
|
|
||
| prompts = [ | ||
| "美国的首都在哪里? \n答案:", | ||
| ] | ||
| references = [ | ||
| "美国的首都是华盛顿。\n\n美国的", | ||
| ] | ||
| sampling_params = SamplingParams(max_tokens=20) | ||
| lora_request = [lora_req] | ||
|
|
||
| outputs = llm.generate(prompts, sampling_params, lora_request=lora_request) | ||
|
|
||
| assert similar(outputs[0].outputs[0].text, references[0]) | ||
| print(f"lora output: {outputs[0].outputs[0].text}") | ||
| print(f"ref output: {references[0]}") | ||
|
Comment on lines
+436
to
+463
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Ensure resource cleanup, avoid redundant config, and gate by memory.
-@pytest.mark.parametrize(
+# keep above tests unchanged
@@
-def test_lora_dir_with_graph():
+@skip_gpu_memory_less_than_40gb
+def test_lora_dir_with_graph():
@@
- lora_config = LoraConfig(
- lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"],
- max_lora_rank=8,
- lora_request=[lora_req])
+ lora_config = LoraConfig(
+ lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"],
+ max_lora_rank=8)
@@
- llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
- lora_config=lora_config,
- cuda_graph_config=CudaGraphConfig(max_batch_size=1))
+ llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
+ lora_config=lora_config,
+ cuda_graph_config=CudaGraphConfig(max_batch_size=1))
@@
- outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
-
- assert similar(outputs[0].outputs[0].text, references[0])
- print(f"lora output: {outputs[0].outputs[0].text}")
- print(f"ref output: {references[0]}")
+ try:
+ outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
+ assert similar(outputs[0].outputs[0].text, references[0])
+ print(f"lora output: {outputs[0].outputs[0].text}")
+ print(f"ref output: {references[0]}")
+ finally:
+ llm.shutdown()
🤖 Prompt for AI Agents |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Follow import style guideline; avoid symbol imports and name collisions.
Import the module and use qualified names to avoid colliding with bindings’ LoraConfig.
And update usages, e.g.:
📝 Committable suggestion
🤖 Prompt for AI Agents