Skip to content

Commit b4b1185

Browse files
chzblych and yiqingy0 authored
[https://nvbugs/5450855][fix] Cherry pick #6700 and #6702 from main (#6808)
Signed-off-by: Yiqing Yan <[email protected]>
Signed-off-by: Yanchao Lu <[email protected]>
Co-authored-by: Yiqing Yan <[email protected]>
1 parent 751d5f1 commit b4b1185

File tree

16 files changed

+76
-20
lines changed

16 files changed

+76
-20
lines changed

cpp/kernels/fmha_v2/fmha_test.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,12 @@
11
import subprocess
22

33
import pytest
4-
from cuda import cuda, nvrtc
4+
5+
try:
6+
from cuda.bindings import driver as cuda
7+
from cuda.bindings import nvrtc
8+
except ImportError:
9+
from cuda import cuda, nvrtc
510

611

712
def ASSERT_DRV(err):

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
accelerate>=0.25.0
44
build
55
colored
6-
cuda-python # Do not override the custom version of cuda-python installed in the NGC PyTorch image.
6+
cuda-python>=12,<13
77
diffusers>=0.27.0
88
lark
99
mpi4py

tensorrt_llm/_ipc_utils.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,17 +17,20 @@
1717
import sys
1818
from typing import List, Tuple
1919

20-
from cuda import cuda, cudart
21-
from cuda.cudart import cudaError_t
20+
try:
21+
from cuda.bindings import driver as cuda
22+
from cuda.bindings import runtime as cudart
23+
except ImportError:
24+
from cuda import cuda, cudart
2225

2326
from ._utils import mpi_comm
2427
from .logger import logger
2528
from .mapping import Mapping
2629

2730

28-
def _raise_if_error(error: cudaError_t | cuda.CUresult):
29-
if isinstance(error, cudaError_t):
30-
if error != cudaError_t.cudaSuccess:
31+
def _raise_if_error(error: cudart.cudaError_t | cuda.CUresult):
32+
if isinstance(error, cudart.cudaError_t):
33+
if error != cudart.cudaError_t.cudaSuccess:
3134
raise RuntimeError(f"CUDA Runtime API error: {repr(error)}")
3235
if isinstance(error, cuda.CUresult):
3336
if error != cuda.CUresult.CUDA_SUCCESS:

tensorrt_llm/_mnnvl_utils.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,11 @@
1818

1919
import pynvml
2020
import torch
21-
from cuda import cuda
21+
22+
try:
23+
from cuda.bindings import driver as cuda
24+
except ImportError:
25+
from cuda import cuda
2226

2327
from ._dlpack_utils import pack_strided_memory
2428
from ._utils import mpi_comm

tensorrt_llm/auto_parallel/cluster_info.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,11 @@
55

66
import pynvml
77
import torch
8-
from cuda import cudart
8+
9+
try:
10+
from cuda.bindings import runtime as cudart
11+
except ImportError:
12+
from cuda import cudart
913

1014
from tensorrt_llm._utils import DictConversion
1115
from tensorrt_llm.logger import logger

tensorrt_llm/runtime/generation.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,10 @@
2929
import torch
3030
import tensorrt as trt
3131
# isort: on
32-
from cuda import cudart
32+
try:
33+
from cuda.bindings import runtime as cudart
34+
except ImportError:
35+
from cuda import cudart
3336

3437
from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \
3538
MemoryPoolsAllocator

tensorrt_llm/runtime/multimodal_model_runner.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,12 @@
1313
from typing import Optional, Tuple
1414

1515
import torch.nn.functional as F
16-
from cuda import cudart
16+
17+
try:
18+
from cuda.bindings import runtime as cudart
19+
except ImportError:
20+
from cuda import cudart
21+
1722
from huggingface_hub import hf_hub_download
1823
from PIL import Image, UnidentifiedImageError
1924
from safetensors import safe_open

tests/integration/defs/sysinfo/get_sysinfo.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,11 @@
2424

2525
import psutil
2626
import pynvml
27-
from cuda import cuda
27+
28+
try:
29+
from cuda.bindings import driver as cuda
30+
except ImportError:
31+
from cuda import cuda
2832

2933
# Logger
3034
logger = logging.getLogger(__name__)

tests/microbenchmarks/all_reduce.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,10 @@
1818
# isort: off
1919
import torch
2020
# isort: on
21-
from cuda import cuda, cudart
21+
try:
22+
from cuda.bindings import runtime as cudart
23+
except ImportError:
24+
from cuda import cudart
2225

2326
import tensorrt_llm as tllm
2427
from tensorrt_llm import Mapping, Tensor

tests/microbenchmarks/build_time_benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,11 @@
77
import traceback
88

99
import tensorrt as trt
10-
from cuda import cudart
10+
11+
try:
12+
from cuda.bindings import runtime as cudart
13+
except ImportError:
14+
from cuda import cudart
1115

1216
import tensorrt_llm
1317
from tensorrt_llm import (AutoConfig, AutoModelForCausalLM, BuildConfig,

0 commit comments

Comments
 (0)