                                  KvCacheConfig, MoeConfig, MTPDecodingConfig,
                                  NGramDecodingConfig, SamplingParams,
                                  TorchCompileConfig)
+from defs.common import generate_dummy_loras
+from tensorrt_llm.lora_manager import LoraConfig
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,
@@ -590,6 +592,37 @@ def test_auto_dtype_chunked_prefill(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    # This is a smoke test to make sure LoRA works.
+    def test_lora(self):
+        model_path = f"{llm_models_root()}/gemma/gemma-3-1b-it/"
+        lora_rank = 32
+        num_loras = 1
+        print(f"Generating {num_loras} dummy LoRAs with rank {lora_rank}...")
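+        # Note: zero_weights=True should make the adapters numerical no-ops,
+        # so the GSM8K score is expected to match the base model and any
+        # regression points at the LoRA plumbing itself.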
+        lora_output_dirs = generate_dummy_loras(
+            hf_model_dir=model_path,
+            lora_output_dir="/tmp/lora_output",
+            num_loras=num_loras,
+            lora_rank=lora_rank,
+            # MLP modules ("gate_proj", "down_proj", "up_proj") are left out.
+            target_modules=["q_proj", "k_proj", "v_proj"],
+            zero_weights=True,
+        )
+        print("lora_output_dirs:", lora_output_dirs)
+        lora_config = LoraConfig(
+            lora_dir=lora_output_dirs,
+            lora_ckpt_source="hf",
+            max_lora_rank=lora_rank,
+            # MLP modules ("mlp_h_to_4h", "mlp_4h_to_h", "mlp_gate") are left out.
+            lora_target_modules=["attn_q", "attn_k", "attn_v"],
+            max_loras=num_loras,
+            max_cpu_loras=num_loras,
+        )
+        # Disable KV cache reuse as a workaround (WAR) for gaps in kernel
+        # support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(
+            enable_block_reuse=False,
+            enable_partial_reuse=False,
+        )
+        with LLM(model_path,
+                 lora_config=lora_config,
+                 enable_lora=True,
+                 kv_cache_config=kv_cache_config) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
 
 class TestMixtral8x7B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
595 | 628 | MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
|
|
0 commit comments