33import functools
44import logging
55import os
6+ import shutil
67import subprocess
78import sys
89from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union , overload
@@ -1007,19 +1008,98 @@ def args_bounds_check(
10071008 return args [i ] if len (args ) > i and args [i ] is not None else replacement
10081009
10091010
1011+ def install_wget (platform : str ) -> None :
1012+ if shutil .which ("wget" ):
1013+ _LOGGER .debug ("wget is already installed" )
1014+ return
1015+ if platform .startswith ("linux" ):
1016+ try :
1017+ # if its root
1018+ if os .geteuid () == 0 :
1019+ subprocess .run (["apt-get" , "update" ], check = True )
1020+ subprocess .run (["apt-get" , "install" , "-y" , "wget" ], check = True )
1021+ else :
1022+ _LOGGER .debug ("Please run with sudo permissions" )
1023+ subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
1024+ subprocess .run (["sudo" , "apt-get" , "install" , "-y" , "wget" ], check = True )
1025+ except subprocess .CalledProcessError as e :
1026+ _LOGGER .debug ("Error installing wget:" , e )
1027+
1028+
1029+ def install_mpi (platform : str ) -> None :
1030+ if platform .startswith ("linux" ):
1031+ try :
1032+ # if its root
1033+ if os .geteuid () == 0 :
1034+ subprocess .run (["apt-get" , "update" ], check = True )
1035+ subprocess .run (["apt-get" , "install" , "-y" , "libmpich-dev" ], check = True )
1036+ subprocess .run (
1037+ ["apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
1038+ )
1039+ else :
1040+ _LOGGER .debug ("Please run with sudo permissions" )
1041+ subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
1042+ subprocess .run (
1043+ ["sudo" , "apt-get" , "install" , "-y" , "libmpich-dev" ], check = True
1044+ )
1045+ subprocess .run (
1046+ ["sudo" , "apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
1047+ )
1048+ except subprocess .CalledProcessError as e :
1049+ _LOGGER .debug ("Error installing mpi libs:" , e )
1050+
1051+
1052+ def download_plugin_lib_path (py_version : str , platform : str ) -> str :
1053+ plugin_lib_path = None
1054+ if py_version not in ("cp310" , "cp312" ):
1055+ _LOGGER .warning (
1056+ "No available wheel for python versions other than py3.10 and py3.12"
1057+ )
1058+ install_wget (platform )
1059+ base_url = "https://pypi.nvidia.com/tensorrt-llm/"
1060+ file_name = f"tensorrt_llm-0.17.0.post1-{ py_version } -{ py_version } -{ platform } .whl"
1061+ download_url = base_url + file_name
1062+ cmd = ["wget" , download_url ]
1063+ try :
1064+ if not (os .path .exists (file_name )):
1065+ _LOGGER .info (f"Running command: { ' ' .join (cmd )} " )
1066+ subprocess .run (cmd )
1067+ _LOGGER .info ("Download complete of wheel" )
1068+ if os .path .exists (file_name ):
1069+ _LOGGER .info ("filename now present" )
1070+ if os .path .exists ("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" ):
1071+ plugin_lib_path = (
1072+ "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1073+ )
1074+ else :
1075+ import zipfile
1076+
1077+ with zipfile .ZipFile (file_name , "r" ) as zip_ref :
1078+ zip_ref .extractall ("." ) # Extract to a folder named 'tensorrt_llm'
1079+ plugin_lib_path = (
1080+ "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1081+ )
1082+ except subprocess .CalledProcessError as e :
1083+ _LOGGER .debug (f"Error occurred while trying to download: { e } " )
1084+ except Exception as e :
1085+ _LOGGER .debug (f"An unexpected error occurred: { e } " )
1086+ return plugin_lib_path
1087+
1088+
10101089def load_tensorrt_llm () -> bool :
10111090 """
10121091 Attempts to load the TensorRT-LLM plugin and initialize it.
10131092
10141093 Returns:
10151094 bool: True if the plugin was successfully loaded and initialized, False otherwise.
10161095 """
1017-
10181096 plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
10191097 if not plugin_lib_path :
10201098 _LOGGER .warning (
10211099 "Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library" ,
10221100 )
1101+ for key , value in os .environ .items ():
1102+ print (f"{ key } : { value } " )
10231103 use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
10241104 "1" ,
10251105 "true" ,
@@ -1034,38 +1114,12 @@ def load_tensorrt_llm() -> bool:
10341114 else :
10351115 py_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
10361116 platform = Platform .current_platform ()
1037- if Platform == Platform .LINUX_X86_64 :
1038- platform = "linux_x86_64"
1039- elif Platform == Platform .LINUX_AARCH64 :
1040- platform = "linux_aarch64"
1041-
1042- if py_version not in ("cp310" , "cp312" ):
1043- _LOGGER .warning (
1044- "No available wheel for python versions other than py3.10 and py3.12"
1045- )
1046- if py_version == "cp310" and platform == "linux_aarch64" :
1047- _LOGGER .warning ("No available wheel for python3.10 with Linux aarch64" )
10481117
1049- base_url = "https://pypi.nvidia.com/tensorrt-llm/"
1050- file_name = (
1051- "tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
1052- )
1053- download_url = base_url + file_name
1054- cmd = ["wget" , download_url ]
1055- subprocess .run (cmd )
1056- if os .path .exists (file_name ):
1057- _LOGGER .info ("filename download is completed" )
1058- import zipfile
1059-
1060- with zipfile .ZipFile (file_name , "r" ) as zip_ref :
1061- zip_ref .extractall (
1062- "./tensorrt_llm"
1063- ) # Extract to a folder named 'tensorrt_llm'
1064- plugin_lib_path = (
1065- "./tensorrt_llm" + "libnvinfer_plugin_tensorrt_llm.so"
1066- )
1118+ platform = str (platform ).lower ()
1119+ plugin_lib_path = download_plugin_lib_path (py_version , platform )
10671120 try :
1068- # Load the shared library
1121+ # Load the shared
1122+ install_mpi (platform )
10691123 handle = ctypes .CDLL (plugin_lib_path )
10701124 _LOGGER .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
10711125 except OSError as e_os_error :
0 commit comments