Skip to content

Commit ed09e1c

Browse files
authored
Add zero-copy interface from MatX to NumPy (#653)
1 parent 7d1debb commit ed09e1c

File tree

2 files changed

+64
-32
lines changed

2 files changed

+64
-32
lines changed

include/matx/core/pybind.h

Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -397,50 +397,74 @@ class MatXPybind {
397397
}
398398

399399
template <typename TensorType>
400-
auto TensorViewToNumpy(const TensorType &ten)
401-
{
400+
auto TensorViewToNumpy(const TensorType &ten) {
401+
using tensor_type = typename TensorType::scalar_type;
402+
using ntype = matx_convert_complex_type<tensor_type>;
402403
constexpr int RANK = TensorType::Rank();
403-
static_assert(RANK <=5, "TensorViewToNumpy only supports max(RANK) = 5 at the moment.");
404-
405-
using ntype = matx_convert_complex_type<typename TensorType::scalar_type>;
406-
auto ften = pybind11::array_t<ntype>(ten.Shape());
407-
408-
for (index_t s1 = 0; s1 < ten.Size(0); s1++) {
409-
if constexpr (RANK > 1) {
410-
for (index_t s2 = 0; s2 < ten.Size(1); s2++) {
411-
if constexpr (RANK > 2) {
412-
for (index_t s3 = 0; s3 < ten.Size(2); s3++) {
413-
if constexpr (RANK > 3) {
414-
for (index_t s4 = 0; s4 < ten.Size(3); s4++) {
415-
if constexpr (RANK > 4) {
416-
for (index_t s5 = 0; s5 < ten.Size(4); s5++) {
417-
ften.mutable_at(s1, s2, s3, s4, s5) =
418-
ConvertComplex(ten(s1, s2, s3, s4, s5));
404+
405+
// If this is a half-precision type pybind/numpy doesn't support it, so we fall back to the
406+
// slow method where we convert everything
407+
if constexpr (is_matx_type<tensor_type>()) {
408+
auto ften = pybind11::array_t<ntype, pybind11::array::c_style | pybind11::array::forcecast>(ten.Shape());
409+
410+
for (index_t s1 = 0; s1 < ten.Size(0); s1++) {
411+
if constexpr (RANK > 1) {
412+
for (index_t s2 = 0; s2 < ten.Size(1); s2++) {
413+
if constexpr (RANK > 2) {
414+
for (index_t s3 = 0; s3 < ten.Size(2); s3++) {
415+
if constexpr (RANK > 3) {
416+
for (index_t s4 = 0; s4 < ten.Size(3); s4++) {
417+
if constexpr (RANK > 4) {
418+
for (index_t s5 = 0; s5 < ten.Size(4); s5++) {
419+
ften.mutable_at(s1, s2, s3, s4, s5) =
420+
ConvertComplex(ten(s1, s2, s3, s4, s5));
421+
}
422+
} else {
423+
ften.mutable_at(s1, s2, s3, s4) =
424+
ConvertComplex(ten(s1, s2, s3, s4));
419425
}
420-
} else {
421-
ften.mutable_at(s1, s2, s3, s4) =
422-
ConvertComplex(ten(s1, s2, s3, s4));
423426
}
424427
}
425-
}
426-
else {
427-
ften.mutable_at(s1, s2, s3) = ConvertComplex(ten(s1, s2, s3));
428+
else {
429+
ften.mutable_at(s1, s2, s3) = ConvertComplex(ten(s1, s2, s3));
430+
}
428431
}
429432
}
430-
}
431-
else {
432-
ften.mutable_at(s1, s2) = ConvertComplex(ten(s1, s2));
433+
else {
434+
ften.mutable_at(s1, s2) = ConvertComplex(ten(s1, s2));
435+
}
433436
}
434437
}
438+
else {
439+
ften.mutable_at(s1) = ConvertComplex(ten(s1));
440+
}
435441
}
436-
else {
437-
ften.mutable_at(s1) = ConvertComplex(ten(s1));
438-
}
439-
}
440442

441-
return ften;
443+
return ften;
444+
}
445+
else {
446+
const auto tshape = ten.Shape();
447+
const auto tstrides = ten.Strides();
448+
std::vector<pybind11::ssize_t> shape{tshape.begin(), tshape.end()};
449+
std::vector<pybind11::ssize_t> strides{tstrides.begin(), tstrides.end()};
450+
std::for_each(strides.begin(), strides.end(), [](pybind11::ssize_t &x) {
451+
x *= sizeof(tensor_type);
452+
});
453+
454+
auto buf = pybind11::buffer_info(
455+
ten.Data(),
456+
sizeof(tensor_type),
457+
pybind11::format_descriptor<ntype>::format(),
458+
RANK,
459+
shape,
460+
strides
461+
);
462+
463+
return pybind11::array_t<ntype, pybind11::array::c_style | pybind11::array::forcecast>(buf);
464+
}
442465
}
443466

467+
444468
template <typename TensorType,
445469
typename CT = matx_convert_cuda_complex_type<typename TensorType::scalar_type>>
446470
std::optional<TestFailResult<CT>>

include/matx/core/tensor_impl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,14 @@ class tensor_impl_t {
616616
*/
617617
__MATX_INLINE__ auto Shape() const noexcept { return this->desc_.Shape(); }
618618

619+
/**
620+
* Get the strides the tensor from the underlying data
621+
*
622+
* @return
623+
* A shape of the data with the appropriate strides set
624+
*/
625+
__MATX_INLINE__ auto Strides() const noexcept { return this->desc_.Strides(); }
626+
619627
/**
620628
* Set the size of a dimension
621629
*

0 commit comments

Comments
 (0)