
Commit cb175f1

save initial changes for test
1 parent 2bbe9b6 commit cb175f1

File tree

1 file changed: +33 -0 lines changed


tests/integration/defs/accuracy/test_llm_api_pytorch.py

@@ -23,6 +23,8 @@
                           KvCacheConfig, MoeConfig, MTPDecodingConfig,
                           NGramDecodingConfig, SamplingParams,
                           TorchCompileConfig)
+from defs.common import generate_dummy_loras
+from tensorrt_llm.lora_manager import LoraConfig
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,
@@ -590,6 +592,37 @@ def test_auto_dtype_chunked_prefill(self):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    # This is a smoke test to make sure LoRA works.
+    def test_lora(self):
+        model_path = f"{llm_models_root()}/gemma/gemma-3-1b-it/"
+        lora_rank = 32
+        num_loras = 1
+        print(f"Generating {num_loras} dummy LoRAs with rank {lora_rank}...")
+        lora_output_dirs = generate_dummy_loras(
+            hf_model_dir=model_path,
+            lora_output_dir="/tmp/lora_output",
+            num_loras=num_loras,
+            lora_rank=lora_rank,
+            target_modules=["q_proj", "k_proj", "v_proj"],  # "gate_proj", "down_proj", "up_proj"
+            zero_weights=True,
+        )
+        print("lora_output_dirs: ", lora_output_dirs)
+        lora_config = LoraConfig(
+            lora_dir=lora_output_dirs,
+            lora_ckpt_source="hf",
+            max_lora_rank=lora_rank,
+            lora_target_modules=['attn_q', 'attn_k', 'attn_v'],  # "mlp_h_to_4h", "mlp_4h_to_h", "mlp_gate"
+            max_loras=num_loras,
+            max_cpu_loras=num_loras,
+        )
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(
+            enable_block_reuse=False,
+            enable_partial_reuse=False,
+        )
+        with LLM(model_path, lora_config=lora_config, enable_lora=True, kv_cache_config=kv_cache_config) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
 
 class TestMixtral8x7B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
