Skip to content

Commit 104bf39

Browse files
[mlir-tensorrt] NFC: Integrate internal changes
Resolve some minor issues after recent migrations of PJRT and kernel projects. Co-authored-by: Sagar Shelke <[email protected]> GitOrigin: 704e99c3c8c187ed84c401d1664dae3e9ef03206
1 parent 574dbe5 commit 104bf39

File tree

8 files changed

+174
-12
lines changed

8 files changed

+174
-12
lines changed

mlir-tensorrt/DependencyProvider.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ set(mlir_patch_dir "${CMAKE_CURRENT_LIST_DIR}/build_tools/patches/mlir")
8383
if(NOT MTRT_BUILD_LLVM_FROM_SOURCE)
8484
message(WARNING "Using 'find_package' to locate pre-built LLVM. Please set MLIR_DIR to the directory containing MLIRConfig.cmake")
8585
else()
86-
8786
nv_register_package(
8887
NAME LLVM
8988
URL "https://github.com/llvm/llvm-project/archive/${MLIR_TRT_LLVM_COMMIT}.zip"
@@ -94,6 +93,7 @@ else()
9493
"${mlir_patch_dir}/0006-mlir-emitc-Fix-two-EmitC-bugs.patch"
9594
"${mlir_patch_dir}/0009-mlir-Support-FileLineColRange-in-LLVM-debug-translat.patch"
9695
"${mlir_patch_dir}/0011-MLIR-Fix-bufferization-interface-for-tensor-reshape.patch"
96+
"${mlir_patch_dir}/0001-NVPTX-Add-support-for-PTX-ISA-v8.8-136639.patch"
9797
# Set the CPM cache key to the Git hash for easy navigation.
9898
PRE_ADD_HOOK [[
9999
list(APPEND _vap_UNPARSED_ARGUMENTS
@@ -315,7 +315,7 @@ nv_register_package(
315315
NAME absl
316316
GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
317317
GIT_TAG fb3621f4f897824c0dbe0615fa94543df6192f30
318-
EXCLUDE_FROM_ALL TRUE
318+
EXCLUDE_FROM_ALL TRUE
319319
OPTIONS
320320
"ABSL_USE_SYSTEM_INCLUDES ON"
321321
"ABSL_PROPAGATE_CXX_STD ON"
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
From 659f5acd038afbb281b4d1d410762f40954e08c8 Mon Sep 17 00:00:00 2001
2+
From: Princeton Ferro <[email protected]>
3+
Date: Fri, 2 May 2025 17:42:32 -0700
4+
Subject: [PATCH 1/1] [NVPTX] Add support for PTX ISA v8.8 (#136639)
5+
6+
Support PTX version 8.8 (`-mattr=+ptx88`) from CUDA 12.9. The following
7+
new targets are also added:
8+
9+
- SM103 and SM121: sm_103, sm_103a, sm_121, sm_121a.
10+
11+
Also, some things were reformatted.
12+
13+
https://docs.nvidia.com/cuda/parallel-thread-execution/#changes-in-ptx-isa-version-8-8
14+
---
15+
llvm/lib/Target/NVPTX/NVPTX.td | 62 +++++++++++++++------------
16+
llvm/test/CodeGen/NVPTX/sm-version.ll | 16 +++++++
17+
2 files changed, 51 insertions(+), 27 deletions(-)
18+
19+
diff --git a/llvm/lib/Target/NVPTX/NVPTX.td b/llvm/lib/Target/NVPTX/NVPTX.td
20+
index 5467ae011a20..ff9a187ecf72 100644
21+
--- a/llvm/lib/Target/NVPTX/NVPTX.td
22+
+++ b/llvm/lib/Target/NVPTX/NVPTX.td
23+
@@ -36,17 +36,21 @@ class FeaturePTX<int version>:
24+
25+
foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53,
26+
60, 61, 62, 70, 72, 75, 80, 86, 87,
27+
- 89, 90, 100, 101, 120] in
28+
+ 89, 90, 100, 101, 103, 120, 121] in
29+
def SM#sm: FeatureSM<""#sm, !mul(sm, 10)>;
30+
31+
-def SM90a: FeatureSM<"90a", 901>;
32+
+// Arch-specific targets. PTX for these is not compatible with any other
33+
+// architectures.
34+
+def SM90a : FeatureSM<"90a", 901>;
35+
def SM100a: FeatureSM<"100a", 1001>;
36+
def SM101a: FeatureSM<"101a", 1011>;
37+
+def SM103a: FeatureSM<"103a", 1031>;
38+
def SM120a: FeatureSM<"120a", 1201>;
39+
+def SM121a: FeatureSM<"121a", 1211>;
40+
41+
foreach version = [32, 40, 41, 42, 43, 50, 60, 61, 62, 63, 64, 65,
42+
70, 71, 72, 73, 74, 75, 76, 77, 78,
43+
- 80, 81, 82, 83, 84, 85, 86, 87] in
44+
+ 80, 81, 82, 83, 84, 85, 86, 87, 88] in
45+
def PTX#version: FeaturePTX<version>;
46+
47+
//===----------------------------------------------------------------------===//
48+
@@ -56,33 +60,37 @@ foreach version = [32, 40, 41, 42, 43, 50, 60, 61, 62, 63, 64, 65,
49+
class Proc<string Name, list<SubtargetFeature> Features>
50+
: Processor<Name, NoItineraries, Features>;
51+
52+
-def : Proc<"sm_20", [SM20, PTX32]>;
53+
-def : Proc<"sm_21", [SM21, PTX32]>;
54+
-def : Proc<"sm_30", [SM30]>;
55+
-def : Proc<"sm_32", [SM32, PTX40]>;
56+
-def : Proc<"sm_35", [SM35, PTX32]>;
57+
-def : Proc<"sm_37", [SM37, PTX41]>;
58+
-def : Proc<"sm_50", [SM50, PTX40]>;
59+
-def : Proc<"sm_52", [SM52, PTX41]>;
60+
-def : Proc<"sm_53", [SM53, PTX42]>;
61+
-def : Proc<"sm_60", [SM60, PTX50]>;
62+
-def : Proc<"sm_61", [SM61, PTX50]>;
63+
-def : Proc<"sm_62", [SM62, PTX50]>;
64+
-def : Proc<"sm_70", [SM70, PTX60]>;
65+
-def : Proc<"sm_72", [SM72, PTX61]>;
66+
-def : Proc<"sm_75", [SM75, PTX63]>;
67+
-def : Proc<"sm_80", [SM80, PTX70]>;
68+
-def : Proc<"sm_86", [SM86, PTX71]>;
69+
-def : Proc<"sm_87", [SM87, PTX74]>;
70+
-def : Proc<"sm_89", [SM89, PTX78]>;
71+
-def : Proc<"sm_90", [SM90, PTX78]>;
72+
-def : Proc<"sm_90a", [SM90a, PTX80]>;
73+
-def : Proc<"sm_100", [SM100, PTX86]>;
74+
+def : Proc<"sm_20", [SM20, PTX32]>;
75+
+def : Proc<"sm_21", [SM21, PTX32]>;
76+
+def : Proc<"sm_30", [SM30]>;
77+
+def : Proc<"sm_32", [SM32, PTX40]>;
78+
+def : Proc<"sm_35", [SM35, PTX32]>;
79+
+def : Proc<"sm_37", [SM37, PTX41]>;
80+
+def : Proc<"sm_50", [SM50, PTX40]>;
81+
+def : Proc<"sm_52", [SM52, PTX41]>;
82+
+def : Proc<"sm_53", [SM53, PTX42]>;
83+
+def : Proc<"sm_60", [SM60, PTX50]>;
84+
+def : Proc<"sm_61", [SM61, PTX50]>;
85+
+def : Proc<"sm_62", [SM62, PTX50]>;
86+
+def : Proc<"sm_70", [SM70, PTX60]>;
87+
+def : Proc<"sm_72", [SM72, PTX61]>;
88+
+def : Proc<"sm_75", [SM75, PTX63]>;
89+
+def : Proc<"sm_80", [SM80, PTX70]>;
90+
+def : Proc<"sm_86", [SM86, PTX71]>;
91+
+def : Proc<"sm_87", [SM87, PTX74]>;
92+
+def : Proc<"sm_89", [SM89, PTX78]>;
93+
+def : Proc<"sm_90", [SM90, PTX78]>;
94+
+def : Proc<"sm_90a", [SM90a, PTX80]>;
95+
+def : Proc<"sm_100", [SM100, PTX86]>;
96+
def : Proc<"sm_100a", [SM100a, PTX86]>;
97+
-def : Proc<"sm_101", [SM101, PTX86]>;
98+
+def : Proc<"sm_101", [SM101, PTX86]>;
99+
def : Proc<"sm_101a", [SM101a, PTX86]>;
100+
-def : Proc<"sm_120", [SM120, PTX87]>;
101+
+def : Proc<"sm_103", [SM103, PTX88]>;
102+
+def : Proc<"sm_103a", [SM103a, PTX88]>;
103+
+def : Proc<"sm_120", [SM120, PTX87]>;
104+
def : Proc<"sm_120a", [SM120a, PTX87]>;
105+
+def : Proc<"sm_121", [SM121, PTX88]>;
106+
+def : Proc<"sm_121a", [SM121a, PTX88]>;
107+
108+
def NVPTXInstrInfo : InstrInfo {
109+
}
110+
diff --git a/llvm/test/CodeGen/NVPTX/sm-version.ll b/llvm/test/CodeGen/NVPTX/sm-version.ll
111+
index ce9a1b1b161d..9705a2f3ba73 100644
112+
--- a/llvm/test/CodeGen/NVPTX/sm-version.ll
113+
+++ b/llvm/test/CodeGen/NVPTX/sm-version.ll
114+
@@ -20,8 +20,12 @@
115+
; RUN: llc < %s -mtriple=nvptx -mcpu=sm_100a | FileCheck %s --check-prefix=SM100a
116+
; RUN: llc < %s -mtriple=nvptx -mcpu=sm_101 | FileCheck %s --check-prefix=SM101
117+
; RUN: llc < %s -mtriple=nvptx -mcpu=sm_101a | FileCheck %s --check-prefix=SM101a
118+
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_103 | FileCheck %s --check-prefix=SM103
119+
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_103a | FileCheck %s --check-prefix=SM103a
120+
; RUN: llc < %s -mtriple=nvptx -mcpu=sm_120 | FileCheck %s --check-prefix=SM120
121+
; RUN: llc < %s -mtriple=nvptx -mcpu=sm_120a | FileCheck %s --check-prefix=SM120a
122+
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_121 | FileCheck %s --check-prefix=SM121
123+
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_121a | FileCheck %s --check-prefix=SM121a
124+
125+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s --check-prefix=SM20
126+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_21 | FileCheck %s --check-prefix=SM21
127+
@@ -45,8 +49,12 @@
128+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100a | FileCheck %s --check-prefix=SM100a
129+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_101 | FileCheck %s --check-prefix=SM101
130+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_101a | FileCheck %s --check-prefix=SM101a
131+
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_103 | FileCheck %s --check-prefix=SM103
132+
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_103a | FileCheck %s --check-prefix=SM103a
133+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_120 | FileCheck %s --check-prefix=SM120
134+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_120a | FileCheck %s --check-prefix=SM120a
135+
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_121 | FileCheck %s --check-prefix=SM121
136+
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_121a | FileCheck %s --check-prefix=SM121a
137+
138+
; SM20: .version 3.2
139+
; SM21: .version 3.2
140+
@@ -70,8 +78,12 @@
141+
; SM100a: .version 8.6
142+
; SM101: .version 8.6
143+
; SM101a: .version 8.6
144+
+; SM103: .version 8.8
145+
+; SM103a: .version 8.8
146+
; SM120: .version 8.7
147+
; SM120a: .version 8.7
148+
+; SM121: .version 8.8
149+
+; SM121a: .version 8.8
150+
151+
; SM20: .target sm_20
152+
; SM21: .target sm_21
153+
@@ -95,5 +107,9 @@
154+
; SM100a: .target sm_100a
155+
; SM101: .target sm_101
156+
; SM101a: .target sm_101a
157+
+; SM103: .target sm_103
158+
+; SM103a: .target sm_103a
159+
; SM120: .target sm_120
160+
; SM120a: .target sm_120a
161+
+; SM121: .target sm_121
162+
+; SM121a: .target sm_121a
163+
--
164+
2.52.0
165+

mlir-tensorrt/integrations/PJRT/CMakeLists.txt

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
#-------------------------------------------------------------------------------
2-
# Options and settings
3-
#-------------------------------------------------------------------------------
4-
find_package(XLA REQUIRED)
5-
set(MLIR_TRT_XLA_SOURCE_DIR ${XLA_SOURCE_DIR} CACHE INTERNAL "")
1+
include(${CMAKE_CURRENT_LIST_DIR}/PJRTConfig.cmake)
62

73
#-------------------------------------------------------------------------------
84
# Project Setup
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
find_package(XLA REQUIRED)
2+
set(MLIR_TRT_XLA_SOURCE_DIR ${XLA_SOURCE_DIR} CACHE INTERNAL "")

mlir-tensorrt/integrations/PJRT/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
llvm_canonicalize_cmake_booleans(
44
LLVM_ENABLE_ASSERTIONS
55
ENABLE_ASAN
6-
${MLIR_TRT_FEATURE_FLAGS}w
6+
${MLIR_TRT_FEATURE_FLAGS}
77
)
88

99
# ==== Add unit test subfolder ==============================================

mlir-tensorrt/integrations/PJRT/test/JAX/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ def pytest_configure(config):
197197
config.addinivalue_line(
198198
"markers", "requires_trt_version: test requires specific TensorRT version"
199199
)
200-
201200
config.addinivalue_line("markers", "long_test: test takes a long time to run")
202201
config.addinivalue_line(
203202
"markers",

mlir-tensorrt/kernel/include/mlir-kernel/InitAllDialects.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ inline void registerAllRequiredDialects(mlir::DialectRegistry &registry) {
133133
mlir::ub::UBDialect,
134134
mlir::vector::VectorDialect
135135
>();
136+
// clang-format on
136137

137138
// Register all external models.
138139
// Register pointer-like type interfaces for external pointer types.
@@ -202,8 +203,6 @@ inline void registerAllRequiredDialects(mlir::DialectRegistry &registry) {
202203
registerConvertMemRefToLLVMInterface(registry);
203204
registerConvertComplexToLLVMInterface(registry);
204205
NVVM::registerConvertGpuToNVVMInterface(registry);
205-
206-
// clang-format on
207206
}
208207
} // namespace mlir::kernel
209208

mlir-tensorrt/kernel/lib/Kernel/TransformSchedules/InitialTransformSchedule.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ static FailureOr<TilingInterface> findRootOp(func::FuncOp funcOp) {
115115
return failure();
116116
}
117117

118-
/// Decide the transform schedule parameters for different ops.
118+
/// Decide the transform schedule parameters for different ops (matmul /
119+
/// elementwise).
119120
static FailureOr<DetermineTransformScheduleResult>
120121
decideTransformScheduleParameters(const TransformScheduleSelector &selector,
121122
func::FuncOp funcOp, TilingInterface rootOp) {

0 commit comments

Comments
 (0)