Skip to content

Commit 12724da

Browse files
committed
Add a self-contained python-calling-MatX (calling python calling MatX) integration example
1 parent a63cd2f commit 12724da

File tree

4 files changed

+330
-0
lines changed

4 files changed

+330
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
# This is a cmake project showing how to build a python importable library
# using pybind11, how to pass tensors between MatX and python, and
# how to call MatX operators from python

cmake_minimum_required(VERSION 3.26)

# Cache defaults must be established before project() enables the CUDA
# language; otherwise CMAKE_CUDA_ARCHITECTURES would already be fixed.
if(NOT DEFINED CMAKE_BUILD_TYPE)
  message(WARNING "CMAKE_BUILD_TYPE not defined. Defaulting to release.")
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type: Debug;Release;MinSizeRel;RelWithDebInfo")
  # Offer the valid choices in cmake-gui/ccmake drop-downs.
  set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS Debug Release MinSizeRel RelWithDebInfo)
endif()

if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  message(WARNING "CMAKE_CUDA_ARCHITECTURES not defined. Defaulting to 70")
  set(CMAKE_CUDA_ARCHITECTURES 70 CACHE STRING "Select compile target CUDA Compute Capabilities")
endif()

if(NOT DEFINED MATX_FETCH_REMOTE)
  message(WARNING "MATX_FETCH_REMOTE not defined. Defaulting to OFF, will use local MatX repo")
  set(MATX_FETCH_REMOTE OFF CACHE BOOL "Set MatX repo fetch location")
endif()

project(SAMPLE_MATX_PYTHON LANGUAGES CUDA CXX)
find_package(CUDAToolkit 12.6 REQUIRED)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Must enable pybind11 support
set(MATX_EN_PYBIND11 ON)

# Use this section if you want to configure other MatX options
#set(MATX_EN_VISUALIZATION ON) # Uncomment to enable visualizations
#set(MATX_EN_FILEIO ON) # Uncomment to enable file IO

# FetchContent is needed by both branches below, so include it once.
include(FetchContent)
if(MATX_FETCH_REMOTE)
  FetchContent_Declare(
    MatX
    GIT_REPOSITORY https://github.com/NVIDIA/MatX.git
    GIT_TAG main
  )
else()
  # Use the enclosing MatX checkout (this example lives two levels deep).
  FetchContent_Declare(
    MatX
    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../
  )
endif()
FetchContent_MakeAvailable(MatX)

# MODULE: a python-importable shared object that is never linked against.
add_library(matxutil MODULE matxutil.cu)
target_link_libraries(matxutil PRIVATE matx::matx)
# Python's importer expects "matxutil.so", not "libmatxutil.so".
set_target_properties(matxutil PROPERTIES SUFFIX ".so" PREFIX "")

# Copy the python scripts next to the built module so the example can be
# run directly from the build directory.
configure_file(
  ${CMAKE_CURRENT_SOURCE_DIR}/mypythonlib.py
  ${CMAKE_BINARY_DIR}
  COPYONLY
)

configure_file(
  ${CMAKE_CURRENT_SOURCE_DIR}/example_matxutil.py
  ${CMAKE_BINARY_DIR}
  COPYONLY
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import cupy as cp
2+
import sys
3+
sys.path.append('.')
4+
5+
import matxutil
6+
7+
# Show that consuming a dlpack capsule (via cp.from_dlpack) invalidates it,
# so later status checks on the same capsule fail.
def dlp_usage_error():
    arr = cp.empty((3,3), dtype=cp.float32)
    capsule = arr.toDlpack()
    assert(matxutil.check_dlpack_status(capsule) == 0)
    consumed = cp.from_dlpack(capsule)  # consuming the capsule invalidates it
    assert(matxutil.check_dlpack_status(capsule) != 0)
    return capsule
15+
16+
# Show that the backing cupy array remains alive while the (unconsumed)
# dlpack capsule produced from it is returned to the caller.
def scope_okay():
    arr = cp.empty((3,3), dtype=cp.float32)
    arr[1,1] = 2
    capsule = arr.toDlpack()
    assert(matxutil.check_dlpack_status(capsule) == 0)
    return capsule
23+
24+
# --- Demo 1: a consumed capsule is invalid afterwards ---
print("Demonstrate dlpack consumption invalidates it for future use:")
dlp = dlp_usage_error()
assert(matxutil.check_dlpack_status(dlp) != 0)
print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}")
print()

# --- Demo 2: an unconsumed capsule returned from a function stays valid ---
print("Demonstrate cupy array stays in scope when returning valid dlpack:")
dlp = scope_okay()
assert(matxutil.check_dlpack_status(dlp) == 0)
print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}")
print()

# --- Demo 3: inspect the dlpack metadata from C++ ---
print("Print info about the dlpack:")
matxutil.print_dlpack_info(dlp)
print()

# --- Demo 4: adopt the tensor in MatX and print it from C++ ---
print("Use MatX to print the tensor:")
matxutil.print_float_2D(dlp)
print()

# --- Demo 5: cupy memory pool statistics, for reference ---
print("Print current memory usage info:")
gpu_mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()
print(f" GPU mempool used bytes {gpu_mempool.used_bytes()}")
print(f" Pinned mempool n_free_blocks {pinned_mempool.n_free_blocks()}")
print()

# --- Demo 6: round-trip call chain through mypythonlib ---
print("Demonstrate python to C++ to python to C++ calling chain (uses mypythonlib.py):")
# This function calls back into python and executes a from_dlpack, consuming the dlp
matxutil.call_python_example(dlp)
assert(matxutil.check_dlpack_status(dlp) != 0)

# Other things to try
# (Done) Check dltensor still valid in C++ before use
# (Done) Assign dltensor to MatX tensor, how to determine number of dimensions
# Pass stream from cp to C++
# (Done) Negative case where pointer goes out of scope before C++ finishes
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// BSD 3-Clause License
3+
//
4+
// Copyright (c) 2024, NVIDIA Corporation
5+
// All rights reserved.
6+
//
7+
// Redistribution and use in source and binary forms, with or without
8+
// modification, are permitted provided that the following conditions are met:
9+
//
10+
// 1. Redistributions of source code must retain the above copyright notice, this
11+
// list of conditions and the following disclaimer.
12+
//
13+
// 2. Redistributions in binary form must reproduce the above copyright notice,
14+
// this list of conditions and the following disclaimer in the documentation
15+
// and/or other materials provided with the distribution.
16+
//
17+
// 3. Neither the name of the copyright holder nor the names of its
18+
// contributors may be used to endorse or promote products derived from
19+
// this software without specific prior written permission.
20+
//
21+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
/////////////////////////////////////////////////////////////////////////////////
32+
33+
#include <cinttypes>
#include <cstring>
#include <iostream>
#include <stdio.h>

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <matx.h>
#include <matx/core/dlpack.h>
39+
40+
namespace py = pybind11;
41+
42+
// Return the name stored in a PyCapsule (cupy uses the capsule name to
// track dlpack state, which this example prints).
const char* get_capsule_name(py::capsule capsule)
{
  const char* name = capsule.name();
  return name;
}
46+
47+
typedef DLManagedTensor* PTR_DLManagedTensor;
48+
int attempt_unpack_dlpack(py::capsule dlpack_capsule, PTR_DLManagedTensor& p_dlpack)
49+
{
50+
if (p_dlpack == nullptr)
51+
{
52+
return -1;
53+
}
54+
55+
const char* capsule_name = dlpack_capsule.name();
56+
57+
if (strncmp(capsule_name,"dltensor",8) != 0)
58+
{
59+
return -2;
60+
}
61+
62+
p_dlpack = static_cast<PTR_DLManagedTensor>(dlpack_capsule.get_pointer());
63+
64+
if (p_dlpack == nullptr) {
65+
return -3;
66+
}
67+
68+
return 0;
69+
}
70+
71+
// Returns 0 when the capsule holds a live "dltensor" dlpack tensor, or a
// negative error code (see attempt_unpack_dlpack) otherwise.
int check_dlpack_status(py::capsule dlpack_capsule)
{
  // NOTE(review): 'unused' is passed uninitialized, and attempt_unpack_dlpack
  // compares its output parameter against nullptr *before* assigning it --
  // that reads an indeterminate value (UB). Confirm, and fix by removing the
  // entry check in attempt_unpack_dlpack (or initializing there).
  PTR_DLManagedTensor unused;
  return attempt_unpack_dlpack(dlpack_capsule, unused);
}
76+
77+
// Map a DLDeviceType enumerator to its textual name.
const char* dlpack_device_type_to_string(DLDeviceType device_type)
{
// Every returned string is exactly the enumerator's name, so stringize it.
#define MATXUTIL_NAME_CASE(e) case e: return #e
  switch(device_type)
  {
    MATXUTIL_NAME_CASE(kDLCPU);
    MATXUTIL_NAME_CASE(kDLCUDA);
    MATXUTIL_NAME_CASE(kDLCUDAHost);
    MATXUTIL_NAME_CASE(kDLOpenCL);
    MATXUTIL_NAME_CASE(kDLVulkan);
    MATXUTIL_NAME_CASE(kDLMetal);
    MATXUTIL_NAME_CASE(kDLVPI);
    MATXUTIL_NAME_CASE(kDLROCM);
    MATXUTIL_NAME_CASE(kDLROCMHost);
    MATXUTIL_NAME_CASE(kDLExtDev);
    MATXUTIL_NAME_CASE(kDLCUDAManaged);
    MATXUTIL_NAME_CASE(kDLOneAPI);
    MATXUTIL_NAME_CASE(kDLWebGPU);
    MATXUTIL_NAME_CASE(kDLHexagon);
    default: return "Unknown DLDeviceType";
  }
#undef MATXUTIL_NAME_CASE
}
98+
99+
// Map a DLDataTypeCode value to its enumerator name.
const char* dlpack_code_to_string(uint8_t code)
{
// Every returned string is exactly the enumerator's name, so stringize it.
#define MATXUTIL_CODE_CASE(e) case e: return #e
  switch(code)
  {
    MATXUTIL_CODE_CASE(kDLInt);
    MATXUTIL_CODE_CASE(kDLUInt);
    MATXUTIL_CODE_CASE(kDLFloat);
    MATXUTIL_CODE_CASE(kDLOpaqueHandle);
    MATXUTIL_CODE_CASE(kDLBfloat);
    MATXUTIL_CODE_CASE(kDLComplex);
    MATXUTIL_CODE_CASE(kDLBool);
    default: return "Unknown DLDataTypeCode";
  }
#undef MATXUTIL_CODE_CASE
}
113+
114+
// Print the dlpack tensor's metadata (data pointer, device, rank, dtype,
// shape, strides, byte offset) to stdout; prints to stderr if the capsule
// is not a live "dltensor".
void print_dlpack_info(py::capsule dlpack_capsule) {
  PTR_DLManagedTensor p_tensor;
  if (attempt_unpack_dlpack(dlpack_capsule, p_tensor))
  {
    // Fixed: the error message was missing its newline.
    fprintf(stderr,"Error: capsule not valid dlpack\n");
    return;
  }

  printf(" data: %p\n",p_tensor->dl_tensor.data);
  printf(" device: device_type %s, device_id %d\n",
    dlpack_device_type_to_string(p_tensor->dl_tensor.device.device_type),
    p_tensor->dl_tensor.device.device_id
  );
  printf(" ndim: %d\n",p_tensor->dl_tensor.ndim);
  printf(" dtype: code %s, bits %u, lanes %u\n",
    dlpack_code_to_string(p_tensor->dl_tensor.dtype.code),
    p_tensor->dl_tensor.dtype.bits,
    p_tensor->dl_tensor.dtype.lanes
  );
  printf(" shape: ");
  for (int k=0; k<p_tensor->dl_tensor.ndim; k++)
  {
    // shape entries are int64_t; PRId64 is portable where the previous
    // %ld breaks (LLP64 platforms such as Windows).
    printf("%" PRId64 ", ",p_tensor->dl_tensor.shape[k]);
  }
  printf("\n");
  printf(" strides: ");
  if (p_tensor->dl_tensor.strides == nullptr)
  {
    // DLPack permits a null strides array, meaning compact row-major;
    // the previous code would have dereferenced null here.
    printf("(null: compact row-major)");
  }
  else
  {
    for (int k=0; k<p_tensor->dl_tensor.ndim; k++)
    {
      printf("%" PRId64 ", ",p_tensor->dl_tensor.strides[k]);
    }
  }
  printf("\n");
  printf(" byte_offset: %" PRIu64 "\n",p_tensor->dl_tensor.byte_offset);
}
147+
148+
template<typename T, int RANK>
149+
void print(py::capsule dlpack_capsule)
150+
{
151+
PTR_DLManagedTensor p_tensor;
152+
if (attempt_unpack_dlpack(dlpack_capsule, p_tensor))
153+
{
154+
fprintf(stderr,"Error: capsule not valid dlpack");
155+
return;
156+
}
157+
158+
matx::tensor_t<T, RANK> a;
159+
matx::make_tensor(a, *p_tensor);
160+
matx::print(a);
161+
}
162+
163+
// Demonstrates a python -> C++ -> python -> C++ call chain: unpacks the
// capsule, adopts it as a MatX tensor, then imports mypythonlib and calls
// my_func(dlpack_capsule), which consumes the capsule on the python side
// (the caller in example_matxutil.py asserts it is invalid afterwards).
void call_python_example(py::capsule dlpack_capsule)
{
  PTR_DLManagedTensor p_tensor;
  if (attempt_unpack_dlpack(dlpack_capsule, p_tensor))
  {
    fprintf(stderr,"Error: capsule not valid dlpack");
    return;
  }

  // 'a' is not used further below; it only demonstrates that the dlpack
  // tensor can be adopted on the C++ side before handing it back to python.
  matx::tensor_t<float, 2> a;
  matx::make_tensor(a, *p_tensor);

  // NOTE(review): presumably this guard object sets up MatX's pybind
  // support for the calls below -- confirm whether it is required here.
  auto pb = matx::detail::MatXPybind{};

  // Example use of python's print
  pybind11::print(" Example use of python's print function from C++: ", 1, 2.0, "three");
  pybind11::print(" The dlpack_capsule is a ", dlpack_capsule);

  auto mypythonlib = pybind11::module_::import("mypythonlib");
  mypythonlib.attr("my_func")(dlpack_capsule);
}
184+
185+
// Python module definition: exposes the helpers above as the importable
// module 'matxutil' (the CMake target builds it as matxutil.so).
PYBIND11_MODULE(matxutil, m) {
  m.def("get_capsule_name", &get_capsule_name, "Returns PyCapsule name");
  m.def("print_dlpack_info", &print_dlpack_info, "Print the DLPack tensor metadata");
  m.def("check_dlpack_status", &check_dlpack_status, "Returns 0 if DLPack is valid, negative error code otherwise");
  // Only the <float, 2> instantiation of the print template is exported.
  m.def("print_float_2D", &print<float,2>, "Prints a float32 2D tensor");
  m.def("call_python_example", &call_python_example, "Example C++ function that calls python code");
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import cupy as cp
2+
import sys
3+
sys.path.append('.')
4+
import matxutil
5+
6+
def my_func(dlp):
    """Consume a dlpack capsule handed over from C++ and print the tensor.

    Prints the capsule's type and name before and after cp.from_dlpack()
    to show that consumption changes the capsule's name (the C++ side
    treats a renamed capsule as invalid), then prints the resulting cupy
    array.
    """
    print(f" type(dlp) before cp.from_dlpack(): {type(dlp)}")
    print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}")
    a = cp.from_dlpack(dlp)
    print(f" type(dlp) after cp.from_dlpack(): {type(dlp)}")
    print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}")
    # Fixed label: the API is cp.from_dlpack, not "cp.from_dlPack".
    print(f" type(cp.from_dlpack(dlp)): {type(a)}")
    print()
    print("Finally, print the tensor we received from MatX using python:")
    print(a)

0 commit comments

Comments
 (0)