Skip to content

Commit d1f0722

Browse files
committed
Fixed print function to work on device in certain cases
1 parent 28b9790 commit d1f0722

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

include/matx/core/tensor_utils.h

Lines changed: 27 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,8 @@
4040
#include "matx/core/make_tensor.h"
4141
#include "matx/kernels/utility.cuh"
4242

43+
/// Compile-time switch selecting how detail::DevicePrint prints an operator.
/// When true, printing is done by launching PrintKernel; when false, the
/// operator is first evaluated into a temporary tensor that is then handed to
/// PrintData. Declared `inline` (C++17) so every translation unit including
/// this header shares one definition instead of a per-TU internal-linkage copy.
inline constexpr bool PRINT_ON_DEVICE = false; ///< print() uses printf on device
44+
4345
namespace matx
4446
{
4547
/**
@@ -667,10 +669,23 @@ namespace detail {
667669
}
668670
}
669671
}
670-
} // end namespace detail
671-
672-
static constexpr bool PRINT_ON_DEVICE = false; ///< print() uses printf on device
673672

673+
template <typename Op,
674+
typename... Args,
675+
std::enable_if_t<((std::is_integral_v<Args>)&&...) &&
676+
(Op::Rank() == 0 || sizeof...(Args) > 0),
677+
bool> = true>
678+
void DevicePrint(const Op &op, Args... dims) {
679+
if constexpr (PRINT_ON_DEVICE) {
680+
PrintKernel<<<1, 1>>>(op, dims...);
681+
}
682+
else {
683+
auto tmpv = make_tensor<typename Op::scalar_type>(op.Shape());
684+
(tmpv = op).run();
685+
PrintData(tmpv, dims...);
686+
}
687+
}
688+
} // end namespace detail
674689

675690
/**
676691
* @brief Print a tensor's values to stdout
@@ -714,22 +729,21 @@ void PrintData(const Op &op, Args... dims) {
714729
data,
715730
reinterpret_cast<CUdeviceptr>(op.Data()));
716731
MATX_ASSERT_STR_EXP(ret, CUDA_SUCCESS, matxCudaError, "Failed to get memory type");
717-
MATX_ASSERT_STR(mtype == CU_MEMORYTYPE_HOST || mtype == 0, matxNotSupported, "Invalid memory type for printing");
732+
MATX_ASSERT_STR(mtype == CU_MEMORYTYPE_HOST || mtype == 0 || mtype == CU_MEMORYTYPE_DEVICE,
733+
matxNotSupported, "Invalid memory type for printing");
718734

719-
detail::InternalPrint(op, dims...);
735+
if (mtype == CU_MEMORYTYPE_DEVICE) {
736+
detail::DevicePrint(op, dims...);
737+
}
738+
else {
739+
detail::InternalPrint(op, dims...);
740+
}
720741
}
721742
else if (kind == MATX_INVALID_MEMORY || HostPrintable(kind)) {
722743
detail::InternalPrint(op, dims...);
723744
}
724745
else if (DevicePrintable(kind)) {
725-
if constexpr (PRINT_ON_DEVICE) {
726-
PrintKernel<<<1, 1>>>(op, dims...);
727-
}
728-
else {
729-
auto tmpv = make_tensor<typename Op::scalar_type>(op.Shape());
730-
(tmpv = op).run();
731-
PrintData(tmpv, dims...);
732-
}
746+
detail::DevicePrint(op, dims...);
733747
}
734748
}
735749
else {

0 commit comments

Comments
 (0)