Skip to content

Commit 9b942d2

Browse files
ax3lroelof-groenewald
authored andcommitted
Bind DLDeviceType
1 parent 728fa4a commit 9b942d2

File tree

6 files changed

+44
-1
lines changed

6 files changed

+44
-1
lines changed

src/Base/Array4.H

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#pragma once
77

88
#include "pyAMReX.H"
9-
#include "dlpack.h"
9+
#include "dlpack/dlpack.h"
1010

1111
#include <AMReX_Array4.H>
1212
#include <AMReX_BLassert.H>

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
add_subdirectory(AmrCore)
33
add_subdirectory(Base)
44
#add_subdirectory(Boundary)
5+
add_subdirectory(dlpack)
56
#add_subdirectory(EB)
67
#add_subdirectory(Extern)
78
#add_subdirectory(LinearSolvers)

src/dlpack/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
foreach(D IN LISTS AMReX_SPACEDIM)
2+
target_sources(pyAMReX_${D}d
3+
PRIVATE
4+
DLPack.cpp
5+
)
6+
endforeach()

src/dlpack/DLPack.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "pyAMReX.H"
2+
3+
#include "dlpack.h"
4+
5+
6+
void init_DLPack(py::module& m)
7+
{
8+
using namespace amrex;
9+
10+
// register types only if not already present, e.g., from another library
11+
// that also implements DLPack bindings and exposes the types
12+
13+
py::type pyDLDeviceType = py::type::of<DLDeviceType>();
14+
if (!pyDLDeviceType) {
15+
py::native_enum<DLDeviceType>(m, "DLDeviceType", "enum.IntEnum")
16+
.value("kDLCPU", DLDeviceType::kDLCPU)
17+
.value("kDLCUDA", DLDeviceType::kDLCUDA)
18+
.value("kDLCUDAHost", DLDeviceType::kDLCUDAHost)
19+
.value("kDLOpenCL", DLDeviceType::kDLOpenCL)
20+
.value("kDLVulkan", DLDeviceType::kDLVulkan)
21+
.value("kDLMetal", DLDeviceType::kDLMetal)
22+
.value("kDLVPI", DLDeviceType::kDLVPI)
23+
.value("kDLROCM", DLDeviceType::kDLROCM)
24+
.value("kDLROCMHost", DLDeviceType::kDLROCMHost)
25+
.value("kDLExtDev", DLDeviceType::kDLExtDev)
26+
.value("kDLCUDAManaged", DLDeviceType::kDLCUDAManaged)
27+
.value("kDLOneAPI", DLDeviceType::kDLOneAPI)
28+
.value("kDLWebGPU", DLDeviceType::kDLWebGPU)
29+
.value("kDLHexagon", DLDeviceType::kDLHexagon)
30+
.value("kDLMAIA", DLDeviceType::kDLMAIA)
31+
;
32+
}
33+
34+
}
File renamed without changes.

src/pyAMReX.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void init_Arena(py::module&);
2020
void init_Array4(py::module&);
2121
void init_BaseFab(py::module&);
2222
void init_Box(py::module &);
23+
void init_DLPack(py::module &);
2324
void init_RealBox(py::module &);
2425
void init_BoxArray(py::module &);
2526
void init_CoordSys(py::module&);
@@ -98,6 +99,7 @@ PYBIND11_MODULE(amrex_3d_pybind, m) {
9899

99100
// note: order from parent to child classes and argument usage
100101

102+
init_DLPack(m);
101103
init_AMReX(m);
102104
init_Arena(m);
103105
init_Dim3(m);

0 commit comments

Comments
 (0)