diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 321ec11bd75..4f2fbfb015a 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -122,15 +122,21 @@ def __init__(self,
         self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
         self._llm_id = None
 
+        log_level = logger.level
+        logger.set_level("info")  # force display the backend
+        try:
             backend = kwargs.get('backend', None)
-            if backend == 'pytorch':
+            if backend == "pytorch":
+                logger.info("Using LLM with PyTorch backend")
                 llm_args_cls = TorchLlmArgs
             elif backend == '_autodeploy':
+                logger.info("Using LLM with AutoDeploy backend")
                 from .._torch.auto_deploy.llm_args import \
                     LlmArgs as AutoDeployLlmArgs
                 llm_args_cls = AutoDeployLlmArgs
             else:
+                logger.info("Using LLM with TensorRT backend")
                 llm_args_cls = TrtLlmArgs
 
             # check the kwargs and raise ValueError directly
@@ -160,6 +166,9 @@ def __init__(self,
                 f"Failed to parse the arguments for the LLM constructor: {e}")
             raise e
+        finally:
+            logger.set_level(log_level)  # restore the log level
+
         print_colored_debug(f"LLM.args.mpi_session: {self.args.mpi_session}\n",
                             "yellow")
 
         self.mpi_session = self.args.mpi_session
diff --git a/tests/integration/defs/llmapi/_run_llmapi_llm.py b/tests/integration/defs/llmapi/_run_llmapi_llm.py
index 854af24efa7..14dde170777 100644
--- a/tests/integration/defs/llmapi/_run_llmapi_llm.py
+++ b/tests/integration/defs/llmapi/_run_llmapi_llm.py
@@ -1,25 +1,32 @@
 #!/usr/bin/env python3
 import os
+from typing import Optional
 
 import click
 
-from tensorrt_llm._tensorrt_engine import LLM
-from tensorrt_llm.llmapi import BuildConfig, SamplingParams
+from tensorrt_llm._tensorrt_engine import LLM as TrtLLM
+from tensorrt_llm.llmapi import LLM, BuildConfig, SamplingParams
 
 
 @click.command()
 @click.option("--model_dir", type=str, required=True)
 @click.option("--tp_size", type=int, default=1)
 @click.option("--engine_dir", type=str, default=None)
-def main(model_dir: str, tp_size: int, engine_dir: str):
+@click.option("--backend", type=str, default=None)
+def main(model_dir: str, tp_size: int, engine_dir: str, backend: Optional[str]):
     build_config = BuildConfig()
     build_config.max_batch_size = 8
     build_config.max_input_len = 256
     build_config.max_seq_len = 512
 
-    llm = LLM(model_dir,
-              tensor_parallel_size=tp_size,
-              build_config=build_config)
+    backend = backend or "tensorrt"
+    assert backend in ["pytorch", "tensorrt"]
+
+    llm_cls = TrtLLM if backend == "tensorrt" else LLM
+
+    kwargs = {} if backend == "pytorch" else {"build_config": build_config}
+
+    llm = llm_cls(model_dir, tensor_parallel_size=tp_size, **kwargs)
 
     if engine_dir is not None and os.path.abspath(
             engine_dir) != os.path.abspath(model_dir):
diff --git a/tests/integration/defs/llmapi/test_llm_api_qa.py b/tests/integration/defs/llmapi/test_llm_api_qa.py
new file mode 100644
index 00000000000..def4be0895c
--- /dev/null
+++ b/tests/integration/defs/llmapi/test_llm_api_qa.py
@@ -0,0 +1,70 @@
+# Confirm that the default backend is changed
+import os
+
+from defs.common import venv_check_output
+
+from ..conftest import llm_models_root
+
+model_path = llm_models_root() + "/llama-models-v3/llama-v3-8b-instruct-hf"
+
+
+class TestLlmDefaultBackend:
+    """
+    Check that the default backend is PyTorch for v1.0 breaking change
+    """
+
+    def test_llm_args_type_default(self, llm_root, llm_venv):
+        # Keep the complete example code here
+        from tensorrt_llm.llmapi import LLM, KvCacheConfig, TorchLlmArgs
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
+        llm = LLM(model=model_path, kv_cache_config=kv_cache_config)
+
+        # The default backend should be PyTorch
+        assert llm.args.backend == "pytorch"
+        assert isinstance(llm.args, TorchLlmArgs)
+
+        for output in llm.generate(["Hello, world!"]):
+            print(output)
+
+    def test_llm_args_type_tensorrt(self, llm_root, llm_venv):
+        # Keep the complete example code here
+        from tensorrt_llm._tensorrt_engine import LLM
+        from tensorrt_llm.llmapi import KvCacheConfig, TrtLlmArgs
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
+
+        llm = LLM(model=model_path, kv_cache_config=kv_cache_config)
+
+        # If the backend is TensorRT, the args should be TrtLlmArgs
+        assert llm.args.backend in ("tensorrt", None)
+        assert isinstance(llm.args, TrtLlmArgs)
+
+        for output in llm.generate(["Hello, world!"]):
+            print(output)
+
+    def test_llm_args_logging(self, llm_root, llm_venv):
+        # It should print the backend in the log
+        script_path = os.path.join(os.path.dirname(__file__),
+                                   "_run_llmapi_llm.py")
+        print(f"script_path: {script_path}")
+
+        # Test with pytorch backend
+        pytorch_cmd = [
+            script_path, "--model_dir", model_path, "--backend", "pytorch"
+        ]
+
+        pytorch_output = venv_check_output(llm_venv, pytorch_cmd)
+
+        # Check that pytorch backend keyword appears in logs
+        assert "Using LLM with PyTorch backend" in pytorch_output, f"Expected 'pytorch' in logs, got: {pytorch_output}"
+
+        # Test with tensorrt backend
+        tensorrt_cmd = [
+            script_path, "--model_dir", model_path, "--backend", "tensorrt"
+        ]
+
+        tensorrt_output = venv_check_output(llm_venv, tensorrt_cmd)
+
+        # Check that tensorrt backend keyword appears in logs
+        assert "Using LLM with TensorRT backend" in tensorrt_output, f"Expected 'tensorrt' in logs, got: {tensorrt_output}"
diff --git a/tests/integration/test_lists/qa/llm_function_full.txt b/tests/integration/test_lists/qa/llm_function_full.txt
index 21d9f013755..05ee1a0f054 100644
--- a/tests/integration/test_lists/qa/llm_function_full.txt
+++ b/tests/integration/test_lists/qa/llm_function_full.txt
@@ -677,3 +677,9 @@ disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyL
 # These tests will impact triton. They should be at the end of all tests (https://nvbugs/4904271)
 # examples/test_openai.py::test_llm_openai_triton_1gpu
 # examples/test_openai.py::test_llm_openai_triton_plugingen_1gpu
+
+# llm-api promote pytorch to default
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_tensorrt
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_default
+llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging
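For reference, the core pattern in the llm.py hunk is: save the current log level, bump it to "info" so the backend banner is always emitted, log which backend was selected, and restore the original level in a finally block so the constructor never leaves logging noisier than it found it. The standalone sketch below illustrates that pattern using Python's stdlib logging module as a stand-in for tensorrt_llm's logger (which exposes set_level rather than setLevel); the helper name select_backend and the logger name "backend_demo" are illustrative only and not part of the library.

# Minimal sketch of the save / raise / restore log-level pattern, assuming
# the stdlib logging module in place of tensorrt_llm.logger.
import logging

logging.basicConfig()                      # attach a console handler
logger = logging.getLogger("backend_demo")
logger.setLevel(logging.WARNING)           # pretend the caller wants quiet logs


def select_backend(backend=None):
    previous_level = logger.level
    logger.setLevel(logging.INFO)          # force the backend message to show
    try:
        if backend == "pytorch":
            logger.info("Using LLM with PyTorch backend")
        elif backend == "_autodeploy":
            logger.info("Using LLM with AutoDeploy backend")
        else:
            logger.info("Using LLM with TensorRT backend")
        return backend or "tensorrt"
    finally:
        logger.setLevel(previous_level)    # restore the caller's level


if __name__ == "__main__":
    select_backend("pytorch")                    # banner is printed
    logger.info("suppressed after restore")      # level is back to WARNING

The test_llm_args_logging case in the diff exercises the same banner end to end: it runs _run_llmapi_llm.py with --backend pytorch or --backend tensorrt and asserts that the corresponding "Using LLM with ... backend" line appears in the captured output.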