@@ -397,50 +397,74 @@ class MatXPybind {
397397 }
398398
399399 template <typename TensorType>
400- auto TensorViewToNumpy (const TensorType &ten)
401- {
400+ auto TensorViewToNumpy (const TensorType &ten) {
401+ using tensor_type = typename TensorType::scalar_type;
402+ using ntype = matx_convert_complex_type<tensor_type>;
402403 constexpr int RANK = TensorType::Rank ();
403- static_assert (RANK <=5 , " TensorViewToNumpy only supports max(RANK) = 5 at the moment." );
404-
405- using ntype = matx_convert_complex_type<typename TensorType::scalar_type>;
406- auto ften = pybind11::array_t <ntype>(ten.Shape ());
407-
408- for (index_t s1 = 0 ; s1 < ten.Size (0 ); s1++) {
409- if constexpr (RANK > 1 ) {
410- for (index_t s2 = 0 ; s2 < ten.Size (1 ); s2++) {
411- if constexpr (RANK > 2 ) {
412- for (index_t s3 = 0 ; s3 < ten.Size (2 ); s3++) {
413- if constexpr (RANK > 3 ) {
414- for (index_t s4 = 0 ; s4 < ten.Size (3 ); s4++) {
415- if constexpr (RANK > 4 ) {
416- for (index_t s5 = 0 ; s5 < ten.Size (4 ); s5++) {
417- ften.mutable_at (s1, s2, s3, s4, s5) =
418- ConvertComplex (ten (s1, s2, s3, s4, s5));
404+
405+ // If this is a half-precision type pybind/numpy doesn't support it, so we fall back to the
406+ // slow method where we convert everything
407+ if constexpr (is_matx_type<tensor_type>()) {
408+ auto ften = pybind11::array_t <ntype, pybind11::array::c_style | pybind11::array::forcecast>(ten.Shape ());
409+
410+ for (index_t s1 = 0 ; s1 < ten.Size (0 ); s1++) {
411+ if constexpr (RANK > 1 ) {
412+ for (index_t s2 = 0 ; s2 < ten.Size (1 ); s2++) {
413+ if constexpr (RANK > 2 ) {
414+ for (index_t s3 = 0 ; s3 < ten.Size (2 ); s3++) {
415+ if constexpr (RANK > 3 ) {
416+ for (index_t s4 = 0 ; s4 < ten.Size (3 ); s4++) {
417+ if constexpr (RANK > 4 ) {
418+ for (index_t s5 = 0 ; s5 < ten.Size (4 ); s5++) {
419+ ften.mutable_at (s1, s2, s3, s4, s5) =
420+ ConvertComplex (ten (s1, s2, s3, s4, s5));
421+ }
422+ } else {
423+ ften.mutable_at (s1, s2, s3, s4) =
424+ ConvertComplex (ten (s1, s2, s3, s4));
419425 }
420- } else {
421- ften.mutable_at (s1, s2, s3, s4) =
422- ConvertComplex (ten (s1, s2, s3, s4));
423426 }
424427 }
425- }
426- else {
427- ften. mutable_at (s1, s2, s3) = ConvertComplex ( ten (s1, s2, s3));
428+ else {
429+ ften. mutable_at (s1, s2, s3) = ConvertComplex ( ten (s1, s2, s3));
430+ }
428431 }
429432 }
430- }
431- else {
432- ften. mutable_at (s1, s2) = ConvertComplex ( ten (s1, s2));
433+ else {
434+ ften. mutable_at (s1, s2) = ConvertComplex ( ten (s1, s2));
435+ }
433436 }
434437 }
438+ else {
439+ ften.mutable_at (s1) = ConvertComplex (ten (s1));
440+ }
435441 }
436- else {
437- ften.mutable_at (s1) = ConvertComplex (ten (s1));
438- }
439- }
440442
441- return ften;
443+ return ften;
444+ }
445+ else {
446+ const auto tshape = ten.Shape ();
447+ const auto tstrides = ten.Strides ();
448+ std::vector<pybind11::ssize_t > shape{tshape.begin (), tshape.end ()};
449+ std::vector<pybind11::ssize_t > strides{tstrides.begin (), tstrides.end ()};
450+ std::for_each (strides.begin (), strides.end (), [](pybind11::ssize_t &x) {
451+ x *= sizeof (tensor_type);
452+ });
453+
454+ auto buf = pybind11::buffer_info (
455+ ten.Data (),
456+ sizeof (tensor_type),
457+ pybind11::format_descriptor<ntype>::format (),
458+ RANK,
459+ shape,
460+ strides
461+ );
462+
463+ return pybind11::array_t <ntype, pybind11::array::c_style | pybind11::array::forcecast>(buf);
464+ }
442465 }
443466
467+
444468 template <typename TensorType,
445469 typename CT = matx_convert_cuda_complex_type<typename TensorType::scalar_type>>
446470 std::optional<TestFailResult<CT>>
0 commit comments