1+ // //////////////////////////////////////////////////////////////////////////////
2+ // BSD 3-Clause License
3+ //
4+ // Copyright (c) 2024, NVIDIA Corporation
5+ // All rights reserved.
6+ //
7+ // Redistribution and use in source and binary forms, with or without
8+ // modification, are permitted provided that the following conditions are met:
9+ //
10+ // 1. Redistributions of source code must retain the above copyright notice, this
11+ // list of conditions and the following disclaimer.
12+ //
13+ // 2. Redistributions in binary form must reproduce the above copyright notice,
14+ // this list of conditions and the following disclaimer in the documentation
15+ // and/or other materials provided with the distribution.
16+ //
17+ // 3. Neither the name of the copyright holder nor the names of its
18+ // contributors may be used to endorse or promote products derived from
19+ // this software without specific prior written permission.
20+ //
21+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+ // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+ // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+ // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+ // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+ // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+ // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+ // ///////////////////////////////////////////////////////////////////////////////
32+
33+ #include < pybind11/pybind11.h>
34+ #include < pybind11/numpy.h>
35+ #include < iostream>
36+ #include < stdio.h>
37+ #include < matx.h>
38+ #include < matx/core/dlpack.h>
39+
40+ namespace py = pybind11;
41+
42+ const char * get_capsule_name (py::capsule capsule)
43+ {
44+ return capsule.name ();
45+ }
46+
47+ typedef DLManagedTensor* PTR_DLManagedTensor;
48+ int attempt_unpack_dlpack (py::capsule dlpack_capsule, PTR_DLManagedTensor& p_dlpack)
49+ {
50+ if (p_dlpack == nullptr )
51+ {
52+ return -1 ;
53+ }
54+
55+ const char * capsule_name = dlpack_capsule.name ();
56+
57+ if (strncmp (capsule_name," dltensor" ,8 ) != 0 )
58+ {
59+ return -2 ;
60+ }
61+
62+ p_dlpack = static_cast <PTR_DLManagedTensor>(dlpack_capsule.get_pointer ());
63+
64+ if (p_dlpack == nullptr ) {
65+ return -3 ;
66+ }
67+
68+ return 0 ;
69+ }
70+
71+ int check_dlpack_status (py::capsule dlpack_capsule)
72+ {
73+ PTR_DLManagedTensor unused;
74+ return attempt_unpack_dlpack (dlpack_capsule, unused);
75+ }
76+
77+ const char * dlpack_device_type_to_string (DLDeviceType device_type)
78+ {
79+ switch (device_type)
80+ {
81+ case kDLCPU : return " kDLCPU" ;
82+ case kDLCUDA : return " kDLCUDA" ;
83+ case kDLCUDAHost : return " kDLCUDAHost" ;
84+ case kDLOpenCL : return " kDLOpenCL" ;
85+ case kDLVulkan : return " kDLVulkan" ;
86+ case kDLMetal : return " kDLMetal" ;
87+ case kDLVPI : return " kDLVPI" ;
88+ case kDLROCM : return " kDLROCM" ;
89+ case kDLROCMHost : return " kDLROCMHost" ;
90+ case kDLExtDev : return " kDLExtDev" ;
91+ case kDLCUDAManaged : return " kDLCUDAManaged" ;
92+ case kDLOneAPI : return " kDLOneAPI" ;
93+ case kDLWebGPU : return " kDLWebGPU" ;
94+ case kDLHexagon : return " kDLHexagon" ;
95+ default : return " Unknown DLDeviceType" ;
96+ }
97+ }
98+
99+ const char * dlpack_code_to_string (uint8_t code)
100+ {
101+ switch (code)
102+ {
103+ case kDLInt : return " kDLInt" ;
104+ case kDLUInt : return " kDLUInt" ;
105+ case kDLFloat : return " kDLFloat" ;
106+ case kDLOpaqueHandle : return " kDLOpaqueHandle" ;
107+ case kDLBfloat : return " kDLBfloat" ;
108+ case kDLComplex : return " kDLComplex" ;
109+ case kDLBool : return " kDLBool" ;
110+ default : return " Unknown DLDataTypeCode" ;
111+ }
112+ }
113+
114+ void print_dlpack_info (py::capsule dlpack_capsule) {
115+ PTR_DLManagedTensor p_tensor;
116+ if (attempt_unpack_dlpack (dlpack_capsule, p_tensor))
117+ {
118+ fprintf (stderr," Error: capsule not valid dlpack" );
119+ return ;
120+ }
121+
122+ printf (" data: %p\n " ,p_tensor->dl_tensor .data );
123+ printf (" device: device_type %s, device_id %d\n " ,
124+ dlpack_device_type_to_string (p_tensor->dl_tensor .device .device_type ),
125+ p_tensor->dl_tensor .device .device_id
126+ );
127+ printf (" ndim: %d\n " ,p_tensor->dl_tensor .ndim );
128+ printf (" dtype: code %s, bits %u, lanes %u\n " ,
129+ dlpack_code_to_string (p_tensor->dl_tensor .dtype .code ),
130+ p_tensor->dl_tensor .dtype .bits ,
131+ p_tensor->dl_tensor .dtype .lanes
132+ );
133+ printf (" shape: " );
134+ for (int k=0 ; k<p_tensor->dl_tensor .ndim ; k++)
135+ {
136+ printf (" %ld, " ,p_tensor->dl_tensor .shape [k]);
137+ }
138+ printf (" \n " );
139+ printf (" strides: " );
140+ for (int k=0 ; k<p_tensor->dl_tensor .ndim ; k++)
141+ {
142+ printf (" %ld, " ,p_tensor->dl_tensor .strides [k]);
143+ }
144+ printf (" \n " );
145+ printf (" byte_offset: %lu\n " ,p_tensor->dl_tensor .byte_offset );
146+ }
147+
148+ template <typename T, int RANK>
149+ void print (py::capsule dlpack_capsule)
150+ {
151+ PTR_DLManagedTensor p_tensor;
152+ if (attempt_unpack_dlpack (dlpack_capsule, p_tensor))
153+ {
154+ fprintf (stderr," Error: capsule not valid dlpack" );
155+ return ;
156+ }
157+
158+ matx::tensor_t <T, RANK> a;
159+ matx::make_tensor (a, *p_tensor);
160+ matx::print (a);
161+ }
162+
163+ void call_python_example (py::capsule dlpack_capsule)
164+ {
165+ PTR_DLManagedTensor p_tensor;
166+ if (attempt_unpack_dlpack (dlpack_capsule, p_tensor))
167+ {
168+ fprintf (stderr," Error: capsule not valid dlpack" );
169+ return ;
170+ }
171+
172+ matx::tensor_t <float , 2 > a;
173+ matx::make_tensor (a, *p_tensor);
174+
175+ auto pb = matx::detail::MatXPybind{};
176+
177+ // Example use of python's print
178+ pybind11::print (" Example use of python's print function from C++: " , 1 , 2.0 , " three" );
179+ pybind11::print (" The dlpack_capsule is a " , dlpack_capsule);
180+
181+ auto mypythonlib = pybind11::module_::import (" mypythonlib" );
182+ mypythonlib.attr (" my_func" )(dlpack_capsule);
183+ }
184+
185+ PYBIND11_MODULE (matxutil, m) {
186+ m.def (" get_capsule_name" , &get_capsule_name, " Returns PyCapsule name" );
187+ m.def (" print_dlpack_info" , &print_dlpack_info, " Print the DLPack tensor metadata" );
188+ m.def (" check_dlpack_status" , &check_dlpack_status, " Returns 0 if DLPack is valid, negative error code otherwise" );
189+ m.def (" print_float_2D" , &print<float ,2 >, " Prints a float32 2D tensor" );
190+ m.def (" call_python_example" , &call_python_example, " Example C++ function that calls python code" );
191+ }
0 commit comments