
Commit 76f7068

Adding SBSA to the supported platforms of the TRT-LLM libs and installing MPI libs for the distributed tests

1 parent 43b5ade · commit 76f7068

File tree

5 files changed: +12 -17 lines changed

.github/scripts/install-mpi-linux-x86.sh

Lines changed: 0 additions & 4 deletions
This file was deleted.

.github/scripts/install-torch-tensorrt.sh

Lines changed: 0 additions & 6 deletions

@@ -12,12 +12,6 @@ if [[ $(uname -m) == "aarch64" ]]; then
     install_cuda_aarch64
 fi
 
-if [[ "$(uname -s)" == "Linux" && "$(uname -m)" == "x86_64" ]]; then
-    # install MPI for Linux x86_64
-    source .github/scripts/install-mpi-linux-x86.sh
-    install_mpi_linux_x86
-fi
-
 # Install all the dependencies required for Torch-TensorRT
 pip install --pre -r ${PWD}/tests/py/requirements.txt
 # dependencies in the tests/py/requirements.txt might install a different version of torch or torchvision

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 1 addition & 0 deletions

@@ -363,6 +363,7 @@ jobs:
             export USE_HOST_DEPS=1
             export CI_BUILD=1
             export USE_TRTLLM_PLUGINS=1
+            dnf install -y mpich mpich-devel openmpi openmpi-devel
             pushd .
             cd tests/py
             cd dynamo

py/torch_tensorrt/dynamo/utils.py

Lines changed: 7 additions & 4 deletions

@@ -847,18 +847,21 @@ def is_platform_supported_for_trtllm(platform: str) -> bool:
     Note:
         TensorRT-LLM plugins for NCCL backend are not supported on:
         - Windows platforms
-        - Jetson devices (aarch64 architecture)
+        - Orin, Xavier, or Tegra devices (aarch64 architecture)
+
     """
     if "windows" in platform:
         logger.info(
             "TensorRT-LLM plugins for NCCL backend are not supported on Windows"
         )
         return False
-    if "aarch64" in platform:
+    if torch.cuda.is_available():
+        device_name = torch.cuda.get_device_name().lower()
+        if any(keyword in device_name for keyword in ["orin", "xavier", "tegra"]):
+            return False
         logger.info(
-            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices (aarch64)"
+            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices"
         )
-        return False
     return True
 
 
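For reference, a minimal standalone sketch of the device check this hunk introduces is shown below. The helper name `_looks_like_jetson` is hypothetical; only the `torch.cuda` calls and the keyword list are taken from the diff, so the blanket aarch64 exclusion is replaced by a device-name check that lets aarch64 SBSA machines through.

```python
import torch


def _looks_like_jetson() -> bool:
    """Hypothetical helper mirroring the check added to is_platform_supported_for_trtllm."""
    if not torch.cuda.is_available():
        # No CUDA device visible, so the Jetson-specific exclusion does not apply.
        return False
    # Orin, Xavier, and Tegra boards report those names in the CUDA device string;
    # other aarch64 devices (e.g. SBSA servers) are no longer excluded wholesale.
    device_name = torch.cuda.get_device_name().lower()
    return any(keyword in device_name for keyword in ["orin", "xavier", "tegra"])
```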

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 4 additions & 3 deletions

@@ -9,6 +9,7 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt._enums import Platform
+from torch_tensorrt.dynamo.utils import is_platform_supported_for_trtllm
 
 
 class DistributedGatherModel(nn.Module):
@@ -44,10 +45,10 @@ def forward(self, x):
 platform_str = str(Platform.current_platform()).lower()
 
 
-class TestGatherNcclOpsConverter(DispatchTestCase):
+class TestNcclOpsConverter(DispatchTestCase):
     @unittest.skipIf(
-        "win" in platform_str or "aarch64" in platform_str,
-        "Skipped on Windows and Jetson: NCCL backend is not supported.",
+        not is_platform_supported_for_trtllm(platform_str),
+        "Skipped on Windows, Jetson: NCCL backend is not supported.",
     )
     @classmethod
     def setUpClass(cls):
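Below is a condensed, self-contained sketch of the skip pattern the test module now relies on. The placeholder test class and its body are invented for illustration (the real suite derives from `DispatchTestCase` and applies the decorator to `setUpClass`), while the imports and the skip condition mirror the diff.

```python
import unittest

from torch_tensorrt._enums import Platform
from torch_tensorrt.dynamo.utils import is_platform_supported_for_trtllm

platform_str = str(Platform.current_platform()).lower()


@unittest.skipIf(
    not is_platform_supported_for_trtllm(platform_str),
    "Skipped on Windows, Jetson: NCCL backend is not supported.",
)
class PlaceholderNcclOpsTest(unittest.TestCase):
    # Placeholder test: if the class is not skipped, the platform gate reported support.
    def test_platform_gate(self) -> None:
        self.assertTrue(is_platform_supported_for_trtllm(platform_str))


if __name__ == "__main__":
    unittest.main()
```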
