|
40 | 40 | #include "matx/core/make_tensor.h" |
41 | 41 | #include "matx/kernels/utility.cuh" |
42 | 42 |
|
| 43 | +static constexpr bool PRINT_ON_DEVICE = false; ///< print() uses printf on device |
| 44 | + |
43 | 45 | namespace matx |
44 | 46 | { |
45 | 47 | /** |
@@ -667,10 +669,23 @@ namespace detail { |
667 | 669 | } |
668 | 670 | } |
669 | 671 | } |
670 | | -} // end namespace detail |
671 | | - |
672 | | -static constexpr bool PRINT_ON_DEVICE = false; ///< print() uses printf on device |
673 | 672 |
|
| 673 | + template <typename Op, |
| 674 | + typename... Args, |
| 675 | + std::enable_if_t<((std::is_integral_v<Args>)&&...) && |
| 676 | + (Op::Rank() == 0 || sizeof...(Args) > 0), |
| 677 | + bool> = true> |
| 678 | + void DevicePrint(const Op &op, Args... dims) { |
| 679 | + if constexpr (PRINT_ON_DEVICE) { |
| 680 | + PrintKernel<<<1, 1>>>(op, dims...); |
| 681 | + } |
| 682 | + else { |
| 683 | + auto tmpv = make_tensor<typename Op::scalar_type>(op.Shape()); |
| 684 | + (tmpv = op).run(); |
| 685 | + PrintData(tmpv, dims...); |
| 686 | + } |
| 687 | + } |
| 688 | +} // end namespace detail |
674 | 689 |
|
675 | 690 | /** |
676 | 691 | * @brief Print a tensor's values to stdout |
@@ -714,22 +729,21 @@ void PrintData(const Op &op, Args... dims) { |
714 | 729 | data, |
715 | 730 | reinterpret_cast<CUdeviceptr>(op.Data())); |
716 | 731 | MATX_ASSERT_STR_EXP(ret, CUDA_SUCCESS, matxCudaError, "Failed to get memory type"); |
717 | | - MATX_ASSERT_STR(mtype == CU_MEMORYTYPE_HOST || mtype == 0, matxNotSupported, "Invalid memory type for printing"); |
| 732 | + MATX_ASSERT_STR(mtype == CU_MEMORYTYPE_HOST || mtype == 0 || mtype == CU_MEMORYTYPE_DEVICE, |
| 733 | + matxNotSupported, "Invalid memory type for printing"); |
718 | 734 |
|
719 | | - detail::InternalPrint(op, dims...); |
| 735 | + if (mtype == CU_MEMORYTYPE_DEVICE) { |
| 736 | + detail::DevicePrint(op, dims...); |
| 737 | + } |
| 738 | + else { |
| 739 | + detail::InternalPrint(op, dims...); |
| 740 | + } |
720 | 741 | } |
721 | 742 | else if (kind == MATX_INVALID_MEMORY || HostPrintable(kind)) { |
722 | 743 | detail::InternalPrint(op, dims...); |
723 | 744 | } |
724 | 745 | else if (DevicePrintable(kind)) { |
725 | | - if constexpr (PRINT_ON_DEVICE) { |
726 | | - PrintKernel<<<1, 1>>>(op, dims...); |
727 | | - } |
728 | | - else { |
729 | | - auto tmpv = make_tensor<typename Op::scalar_type>(op.Shape()); |
730 | | - (tmpv = op).run(); |
731 | | - PrintData(tmpv, dims...); |
732 | | - } |
| 746 | + detail::DevicePrint(op, dims...); |
733 | 747 | } |
734 | 748 | } |
735 | 749 | else { |
|
0 commit comments