Skip to content

Commit 4b3706f

Browse files
committed
Merge branch 'main' into lluo/wheel_release-rebased
2 parents 50be2f8 + 5052cc6 commit 4b3706f

File tree

3 files changed

+194
-14
lines changed

3 files changed

+194
-14
lines changed

mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ endmacro()
1818
# If multiple CUDA major versions are available (e.g. 12.9 and 13.0), we
1919
# check what version is used on the host CTK. Otherwise, if host CTK is
2020
# not found (or no listed version fits it), use the last listed version.
21-
function(mtrt_get_tensorrt_cuda_version trt_version out_var)
22-
if(NOT CUDAToolkit_VERSION_MAJOR)
23-
find_package(CUDAToolkit)
24-
endif()
25-
21+
function(mtrt_get_tensorrt_cuda_version trt_version target_arch out_var)
2622
set(ctk_version "")
2723
if(CUDAToolkit_FOUND)
28-
set(ctk_version "${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}")
24+
set(ctk_version "${CUDAToolkit_VERSION_MAJOR}.9999")
25+
endif()
26+
27+
set(is_aarch64 FALSE)
28+
if(target_arch MATCHES "aarch64")
29+
set(is_aarch64 TRUE)
2930
endif()
3031

3132
set(trt_available_cuda_versions "")
@@ -39,25 +40,34 @@ function(mtrt_get_tensorrt_cuda_version trt_version out_var)
3940
trt_version VERSION_LESS 10.10)
4041
set(trt_available_cuda_versions "11.8;12.8")
4142
elseif(trt_version VERSION_GREATER 10.10 AND
42-
trt_version VERSION_LESS 10.13)
43+
trt_version VERSION_LESS 10.13.2)
4344
set(trt_available_cuda_versions "11.8;12.9")
44-
elseif(trt_version VERSION_GREATER_EQUAL "10.13")
45-
set(trt_available_cuda_versions "12.9;13.0")
45+
elseif(trt_version VERSION_GREATER_EQUAL "10.13.2")
46+
if(is_aarch64)
47+
set(trt_available_cuda_versions "13.0")
48+
else()
49+
set(trt_available_cuda_versions "12.9;13.0")
50+
endif()
4651
else()
4752
message(FATAL_ERROR "Could not determine available CUDA versions for TensorRT version ${trt_version}")
4853
endif()
4954

5055
set(selected_cuda_version "")
5156
if(ctk_version)
52-
foreach(available_version IN_LISTS trt_available_cuda_versions)
53-
if(ctk_version VERSION_LESS_EQUAL available_version)
57+
foreach(available_version IN LISTS trt_available_cuda_versions)
58+
59+
if(available_version VERSION_LESS_EQUAL ctk_version)
5460
set(selected_cuda_version "${available_version}")
5561
endif()
5662
endforeach()
5763
endif()
5864

5965
if(NOT selected_cuda_version)
60-
list(GET trt_available_cuda_versions 0 selected_cuda_version)
66+
# if(ctk_version)
67+
list(GET trt_available_cuda_versions -1 selected_cuda_version)
68+
# else()
69+
# list(GET trt_available_cuda_versions 0 selected_cuda_version)
70+
# endif()
6171
endif()
6272

6373
message(STATUS "Selected CUDA version tag for TensorRT ${trt_version} is ${selected_cuda_version}")
@@ -134,10 +144,22 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
134144
set(ARG_VERSION "10.12.0.36")
135145
endif()
136146

137-
if(ARG_VERSION VERSION_EQUAL "10.13.0.35")
147+
if(ARG_VERSION VERSION_EQUAL "10.13.0")
138148
set(ARG_VERSION "10.13.0.35")
139149
endif()
140150

151+
if(ARG_VERSION VERSION_EQUAL "10.13.2")
152+
set(ARG_VERSION "10.13.2.6")
153+
endif()
154+
155+
if(ARG_VERSION VERSION_EQUAL "10.13" OR ARG_VERSION VERSION_EQUAL "10.13.3")
156+
set(ARG_VERSION "10.13.3.9")
157+
endif()
158+
159+
if(ARG_VERSION VERSION_EQUAL "10.14")
160+
set(ARG_VERSION "10.14.1.48")
161+
endif()
162+
141163
set(downloadable_versions
142164
"8.6.1.6"
143165
"9.0.1.4" "9.1.0.4" "9.2.0.5"
@@ -156,6 +178,9 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
156178
"10.9.0.34"
157179
"10.12.0.36"
158180
"10.13.0.35"
181+
"10.13.2.6"
182+
"10.13.3.9"
183+
"10.14.1.48"
159184
)
160185

161186
if(NOT ARG_VERSION IN_LIST downloadable_versions)
@@ -166,7 +191,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
166191
string(REGEX MATCH "[0-9]+\\.[0-9]+\\.[0-9]+" trt_short_version ${ARG_VERSION})
167192

168193
# Get the CUDA version tag.
169-
mtrt_get_tensorrt_cuda_version("${ARG_VERSION}" TRT_CUDA_VERSION)
194+
mtrt_get_tensorrt_cuda_version("${ARG_VERSION}" "${TARGET_ARCH}" TRT_CUDA_VERSION)
170195

171196
# For aarch64, the published packages are only for
172197
# "Ubuntu-20.04". I believe this corresponds to NVIDIA supported ARM server
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
cmake_minimum_required(VERSION 3.20)
2+
3+
# Add the parent directory to module path so we can include the module under test
4+
get_filename_component(_test_dir "${CMAKE_CURRENT_LIST_FILE}" DIRECTORY)
5+
list(APPEND CMAKE_MODULE_PATH "${_test_dir}/..")
6+
7+
include(TensorRTDownloadURL)
8+
9+
# Helper function to run a test case
10+
# Runs one table-driven check of mtrt_get_tensorrt_cuda_version().
#
# Arguments:
#   trt_version     - TensorRT version handed to the function under test.
#   ctk_major       - Mocked host CUDA Toolkit major version; pass "" to
#                     simulate a host on which no CUDA Toolkit was found.
#   ctk_minor       - Mocked host CUDA Toolkit minor version (unused when
#                     ctk_major is "").
#   target_arch     - Target architecture handed to the function under test.
#   expected_result - CUDA version tag the function is expected to select.
#
# Emits a STATUS line when the selected tag matches and stops the script with
# FATAL_ERROR when it does not.
function(run_test trt_version ctk_major ctk_minor target_arch expected_result)
  # Mock the variables that find_package(CUDAToolkit) would normally provide.
  # The human-readable host description is built here, once, so that both the
  # pass and the fail message can use it.
  if("${ctk_major}" STREQUAL "")
    # Simulate CUDAToolkit not found.
    set(CUDAToolkit_FOUND FALSE)
    set(ctk_str "Not Found")
  else()
    # Simulate CUDAToolkit found with a specific version.
    set(CUDAToolkit_FOUND TRUE)
    set(CUDAToolkit_VERSION_MAJOR "${ctk_major}")
    set(CUDAToolkit_VERSION_MINOR "${ctk_minor}")
    set(ctk_str "${ctk_major}.${ctk_minor}")
  endif()

  # Call the function under test; it reports its choice through `result`.
  mtrt_get_tensorrt_cuda_version("${trt_version}" "${target_arch}" result)

  # Compare as strings: version tags such as "12.5" must match exactly.
  if(NOT "${result}" STREQUAL "${expected_result}")
    message(FATAL_ERROR "Test Failed for TRT ${trt_version} (${target_arch}) with Host CUDA ${ctk_str}.\nExpected: ${expected_result}\nActual: ${result}")
  endif()

  message(STATUS "Test Passed: TRT ${trt_version} (${target_arch}) + CUDA ${ctk_str} -> ${result}")
endfunction()
42+
43+
message(STATUS "Running TestTensorRTCUDAVersion...")
44+
45+
# Test Group 1: TRT 10.3 (Available: 11.8, 12.5)
46+
# Logic:
47+
# If available <= host (host is taken as <major>.9999), set selected = available.
48+
# List is iterated in order.
49+
# If multiple match, the last one matching is kept (largest available that is <= host).
50+
51+
# Case 1.1: No CUDA found. Should pick last (12.5).
52+
run_test("10.3" "" "" "x86_64" "12.5")
53+
54+
# Case 1.2: Host CUDA 11.8.
55+
# 11.8 <= 11.9999 (True -> 11.8)
56+
# 12.5 <= 11.9999 (False)
57+
run_test("10.3" "11" "8" "x86_64" "11.8")
58+
59+
# Case 1.3: Host CUDA 12.0.
60+
# 11.8 <= 12.9999 (True -> 11.8)
61+
# 12.5 <= 12.9999 (True -> 12.5)
62+
run_test("10.3" "12" "0" "x86_64" "12.5")
63+
64+
# Case 1.4: Host CUDA 12.6.
65+
# 11.8 <= 12.9999 (True -> 11.8)
66+
# 12.5 <= 12.9999 (True -> 12.5)
67+
# 12.5 is selected by the loop itself; no fallback is needed
68+
run_test("10.3" "12" "6" "x86_64" "12.5")
69+
70+
# Test Group 2: TRT 10.5 (Available: 11.8, 12.6)
71+
# Host 12.6 -> 12.6 <= 12.9999 (True -> 12.6)
72+
run_test("10.5" "12" "6" "x86_64" "12.6")
73+
74+
# Test Group 3: TRT 10.8 (Available: 11.8, 12.8)
75+
# Host 12.7 -> 12.8 <= 12.9999 (True -> 12.8)
76+
run_test("10.8" "12" "7" "x86_64" "12.8")
77+
78+
# Test Group 4: TRT 10.11 (Available: 11.8, 12.9)
79+
# Host 12.9 -> 12.9 <= 12.9999 (True -> 12.9)
80+
run_test("10.11" "12" "9" "x86_64" "12.9")
81+
82+
# Test Group 5: TRT 10.13 (Available: 11.8, 12.9) because < 10.13.2
83+
# Host 12.0 -> 12.9 <= 12.9999 (True -> 12.9)
84+
run_test("10.13" "12" "0" "x86_64" "12.9")
85+
86+
# Host 13.1 -> 12.9 <= 13.9999 (True -> 12.9); selected by the loop, not the fallback
87+
run_test("10.13" "13" "1" "x86_64" "12.9")
88+
89+
message(STATUS "All tests passed successfully.")
90+
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
cmake_minimum_required(VERSION 3.20)
2+
3+
# Tests that all possible TensorRT download URLs point to valid resources.
4+
# Usage: 'cmake -P build_tools/cmake/TestTensorRTDownloadURL.cmake'
5+
cmake_policy(SET CMP0057 NEW)
6+
7+
# Mock CUDAToolkit to avoid find_package in script mode
8+
set(CUDAToolkit_FOUND TRUE)
9+
# CUDAToolkit_VERSION_MAJOR is set inside the loop
10+
set(CUDAToolkit_VERSION_MINOR "0")
11+
12+
# Add the parent directory to module path so we can include the module under test
13+
get_filename_component(_test_dir "${CMAKE_CURRENT_LIST_FILE}" DIRECTORY)
14+
list(APPEND CMAKE_MODULE_PATH "${_test_dir}/..")
15+
16+
include(TensorRTDownloadURL)
17+
18+
set(VERSIONS "10.13.2" "10.13.3" "10.14" "10.2" "10.3" "10.4"
19+
"10.5" "10.8" "10.9" "10.12")
20+
set(OSS "Linux")
21+
set(ARCHS "x86_64" "aarch64")
22+
23+
# The probe below shells out to wget. Locate it up front so that a missing
# tool is reported as such instead of as an unreachable URL.
find_program(WGET_EXECUTABLE wget REQUIRED)

# For every (TensorRT version, OS, architecture) combination, resolve the
# download URL and verify with a spider request that it points at a gzip
# tarball. The first failure aborts the script with FATAL_ERROR.
foreach(VERSION IN LISTS VERSIONS)
  foreach(OS IN LISTS OSS)
    foreach(ARCH IN LISTS ARCHS)
      # Combinations of a pre-10.0 TensorRT release with aarch64 are skipped.
      if(VERSION VERSION_LESS 10.0 AND "${ARCH}" MATCHES "aarch64")
        continue()
      endif()

      # Pick the mocked host CUDA major version. Only the major version is
      # set here; CUDAToolkit_VERSION_MINOR keeps the value assigned before
      # the loop.
      if(VERSION VERSION_GREATER_EQUAL "10.4" AND "${ARCH}" MATCHES "aarch64")
        # For 10.4+ on aarch64 (Ubuntu 24.04), CUDA 11.8 URLs seem
        # invalid/missing, so mock a CUDA 12 host to steer the selection
        # away from the 11.8 package.
        set(CUDAToolkit_VERSION_MAJOR "12")
      else()
        # Mock a CUDA 11 host. This selects a CUDA 11.x package where one is
        # listed for the TensorRT version; when every listed package targets
        # a newer CUDA major, the selection falls back to the last entry.
        set(CUDAToolkit_VERSION_MAJOR "11")
      endif()

      mtrt_get_tensorrt_download_url("${VERSION}" "${OS}" "${ARCH}" url modified_version)

      # Probe the URL with wget in spider mode (no body download). The server
      # response is captured from stderr into `headers`.
      execute_process(
        COMMAND "${WGET_EXECUTABLE}" --spider --server-response --max-redirect=2 "${url}"
        ERROR_VARIABLE headers
        RESULT_VARIABLE result
      )

      # A non-zero (or non-numeric) result means wget could not complete the
      # request.
      if(NOT result EQUAL 0)
        message(FATAL_ERROR "Failed to reach URL: ${url}\nwget result: ${result}\n${headers}")
      endif()

      # Accept the URL only if the response mentions a gzip content type or a
      # '.tar.gz' name.
      if(NOT headers MATCHES "(application/x-gzip|application/gzip|\\.tar\\.gz)")
        message(FATAL_ERROR "Not a valid .tar.gz URL: ${url}")
      endif()

      message(STATUS "Valid .tar.gz URL: ${url}")
    endforeach()
  endforeach()
endforeach()

0 commit comments

Comments
 (0)