
Commit eb6d588

init
Signed-off-by: Superjomn <[email protected]>
1 parent 2206e49 commit eb6d588

File tree

4 files changed: +99 -7 lines

tensorrt_llm/llmapi/llm.py

Lines changed: 10 additions & 1 deletion

@@ -122,15 +122,21 @@ def __init__(self,
         self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
         self._llm_id = None

+        log_level = logger.level
+        logger.set_level("info")  # force display the backend
+
         try:
             backend = kwargs.get('backend', None)
-            if backend == 'pytorch':
+            if backend == "pytorch":
+                logger.info("Using LLM with PyTorch backend")
                 llm_args_cls = TorchLlmArgs
             elif backend == '_autodeploy':
+                logger.info("Using LLM with AutoDeploy backend")
                 from .._torch.auto_deploy.llm_args import \
                     LlmArgs as AutoDeployLlmArgs
                 llm_args_cls = AutoDeployLlmArgs
             else:
+                logger.info("Using LLM with TensorRT backend")
                 llm_args_cls = TrtLlmArgs

             # check the kwargs and raise ValueError directly

@@ -160,6 +166,9 @@ def __init__(self,
                 f"Failed to parse the arguments for the LLM constructor: {e}")
             raise e

+        finally:
+            logger.set_level(log_level)  # restore the log level
+
         print_colored_debug(f"LLM.args.mpi_session: {self.args.mpi_session}\n",
                             "yellow")
         self.mpi_session = self.args.mpi_session
tests/integration/defs/llmapi/_run_llmapi_llm.py

Lines changed: 13 additions & 6 deletions

@@ -1,25 +1,32 @@
 #!/usr/bin/env python3
 import os
+from typing import Optional

 import click

-from tensorrt_llm._tensorrt_engine import LLM
-from tensorrt_llm.llmapi import BuildConfig, SamplingParams
+from tensorrt_llm._tensorrt_engine import LLM as TrtLLM
+from tensorrt_llm.llmapi import LLM, BuildConfig, SamplingParams


 @click.command()
 @click.option("--model_dir", type=str, required=True)
 @click.option("--tp_size", type=int, default=1)
 @click.option("--engine_dir", type=str, default=None)
-def main(model_dir: str, tp_size: int, engine_dir: str):
+@click.option("--backend", type=str, default=None)
+def main(model_dir: str, tp_size: int, engine_dir: str, backend: Optional[str]):
     build_config = BuildConfig()
     build_config.max_batch_size = 8
     build_config.max_input_len = 256
     build_config.max_seq_len = 512

-    llm = LLM(model_dir,
-              tensor_parallel_size=tp_size,
-              build_config=build_config)
+    backend = backend or "tensorrt"
+    assert backend in ["pytorch", "tensorrt"]
+
+    llm_cls = TrtLLM if backend == "tensorrt" else LLM
+
+    kwargs = {} if backend == "pytorch" else {"build_config": build_config}
+
+    llm = llm_cls(model_dir, tensor_parallel_size=tp_size, **kwargs)

     if engine_dir is not None and os.path.abspath(
             engine_dir) != os.path.abspath(model_dir):
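
For reference, the updated helper script can be exercised with the new --backend option; a minimal sketch follows, using a plain subprocess call in place of the test harness's venv_check_output and a placeholder model directory:

    import subprocess

    cmd = [
        "python", "tests/integration/defs/llmapi/_run_llmapi_llm.py",
        "--model_dir", "/path/to/llama-v3-8b-instruct-hf",  # placeholder path
        "--backend", "pytorch",
    ]
    # Capture stdout and stderr so the backend log line is visible either way.
    output = subprocess.check_output(cmd, text=True, stderr=subprocess.STDOUT)
    # The output should contain "Using LLM with PyTorch backend".
    print(output)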
tests/integration/defs/llmapi/test_llm_api_qa.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+# Confirm that the default backend is changed
+import os
+
+from defs.common import venv_check_output
+
+from ..conftest import llm_models_root
+
+model_path = llm_models_root() + "/llama-models-v3/llama-v3-8b-instruct-hf"
+
+
+class TestLlmDefaultBackend:
+    """
+    Check that the default backend is PyTorch for v1.0 breaking change
+    """
+
+    def test_llm_args_type_default(self, llm_root, llm_venv):
+        # Keep the complete example code here
+        from tensorrt_llm.llmapi import LLM, KvCacheConfig, TorchLlmArgs
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
+        llm = LLM(model=model_path, kv_cache_config=kv_cache_config)
+
+        # The default backend should be PyTorch
+        assert llm.args.backend == "pytorch"
+        assert isinstance(llm.args, TorchLlmArgs)
+
+        for output in llm.generate(["Hello, world!"]):
+            print(output)
+
+    def test_llm_args_type_tensorrt(self, llm_root, llm_venv):
+        # Keep the complete example code here
+        from tensorrt_llm._tensorrt_engine import LLM
+        from tensorrt_llm.llmapi import KvCacheConfig, TrtLlmArgs
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
+
+        llm = LLM(model=model_path, kv_cache_config=kv_cache_config)
+
+        # If the backend is TensorRT, the args should be TrtLlmArgs
+        assert llm.args.backend in ("tensorrt", None)
+        assert isinstance(llm.args, TrtLlmArgs)
+
+        for output in llm.generate(["Hello, world!"]):
+            print(output)
+
+    def test_llm_args_logging(self, llm_root, llm_venv):
+        # It should print the backend in the log
+        script_path = os.path.join(os.path.dirname(__file__),
+                                   "_run_llmapi_llm.py")
+        print(f"script_path: {script_path}")
+
+        # Test with pytorch backend
+        pytorch_cmd = [
+            script_path, "--model_dir", model_path, "--backend", "pytorch"
+        ]
+
+        pytorch_output = venv_check_output(llm_venv, pytorch_cmd)
+
+        # Check that pytorch backend keyword appears in logs
+        assert "Using LLM with PyTorch backend" in pytorch_output, f"Expected 'pytorch' in logs, got: {pytorch_output}"
+
+        # Test with tensorrt backend
+        tensorrt_cmd = [
+            script_path, "--model_dir", model_path, "--backend", "tensorrt"
+        ]
+
+        tensorrt_output = venv_check_output(llm_venv, tensorrt_cmd)
+
+        # Check that tensorrt backend keyword appears in logs
+        assert "Using LLM with TensorRT backend" in tensorrt_output, f"Expected 'tensorrt' in logs, got: {tensorrt_output}"

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 6 additions & 0 deletions

@@ -677,3 +677,9 @@ disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyL
 # These tests will impact triton. They should be at the end of all tests (https://nvbugs/4904271)
 # examples/test_openai.py::test_llm_openai_triton_1gpu
 # examples/test_openai.py::test_llm_openai_triton_plugingen_1gpu
+
+# llm-api promote pytorch to default
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_tensorrt
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_default
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging
