Skip to content

Commit 28d1b5b

Browse files
committed
Debug
1 parent 39ff448 commit 28d1b5b

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

include/matx/transforms/chol/chol_cuda.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
103103
MATX_STATIC_ASSERT_STR(!is_half_v<T1>, matxInvalidType, "Cholesky solver does not support half precision");
104104
MATX_STATIC_ASSERT_STR((std::is_same_v<T1, typename OutTensor_t::value_type>), matxInavlidType, "Input and Output types must match");
105105

106-
params = GetCholParams(a, uplo);
106+
params = GetCholParams(a, uplo, exec);
107107

108108
this->GetWorkspaceSize();
109109
this->AllocateWorkspace(params.batch_size, false, exec);

include/matx/transforms/eig/eig_cuda.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
160160
params.jobz = jobz;
161161
params.uplo = uplo;
162162
params.exec = exec;
163+
printf("setting stream to %p\n", exec.getStream());
163164
params.dtype = TypeToInt<T1>();
164165

165166
return params;
@@ -273,6 +274,7 @@ struct DnEigCUDAParamsKeyHash {
273274
struct DnEigCUDAParamsKeyEq {
274275
bool operator()(const DnEigCUDAParams_t &l, const DnEigCUDAParams_t &t) const noexcept
275276
{
277+
printf("%ld %ld %zu %zu %d %d %p %p\n", l.n , t.n, l.batch_size , t.batch_size , (int)l.dtype , (int)t.dtype , l.exec.getStream() , t.exec.getStream());
276278
return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype && l.exec.getStream() == t.exec.getStream();
277279
}
278280
};

include/matx/transforms/svd/svd_cuda.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -724,8 +724,8 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
724724

725725
static DnSVDCUDAParams_t
726726
GetSVDParams(UTensor &u, STensor &s,
727-
VtTensor &vt, const ATensor &a, const cudaExecutor &exec,
728-
const char jobz = 'A')
727+
VtTensor &vt, const ATensor &a,
728+
const char jobz, const cudaExecutor &exec)
729729
{
730730
DnSVDCUDAParams_t params;
731731
params.batch_size = GetNumBatches(a);
@@ -997,7 +997,7 @@ void svd_impl(UTensor &&u, STensor &&s,
997997

998998
// Get parameters required by these tensors
999999
auto params = detail::matxDnSVDCUDAPlan_t<decltype(u_in), decltype(s_new), decltype(vt_in), decltype(at_col_maj)>::
1000-
GetSVDParams(u_in, s_new, vt_in, at_col_maj, exec, job_cusolver);
1000+
GetSVDParams(u_in, s_new, vt_in, at_col_maj, job_cusolver, exec);
10011001

10021002
// Get cache or new SVD plan if it doesn't exist
10031003
using cache_val_type = detail::matxDnSVDCUDAPlan_t<decltype(u_in), decltype(s_new), decltype(vt_in), decltype(at_col_maj)>;
@@ -1034,7 +1034,7 @@ void svd_impl(UTensor &&u, STensor &&s,
10341034

10351035
// Get parameters required by these tensors
10361036
auto params = detail::matxDnSVDCUDAPlan_t<decltype(u_col_maj), decltype(s_new), decltype(vt_col_maj), decltype(tvt)>::
1037-
GetSVDParams(u_col_maj, s_new, vt_col_maj, tvt, exec, job_cusolver);
1037+
GetSVDParams(u_col_maj, s_new, vt_col_maj, tvt, job_cusolver, exec);
10381038

10391039
// Get cache or new SVD plan if it doesn't exist
10401040
using cache_val_type = detail::matxDnSVDCUDAPlan_t<decltype(u_col_maj), decltype(s_new), decltype(vt_col_maj), decltype(tvt)>;

0 commit comments

Comments
 (0)