@@ -644,8 +644,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
644644 MATX_STATIC_ASSERT_STR (!is_complex_v<T3>, matxInvalidType, " S type must be real" );
645645 MATX_STATIC_ASSERT_STR ((std::is_same_v<typename inner_op_type_t <T1>::type, T3>), matxInvalidType, " A and S inner types must match" );
646646
647- params = GetSVDParams (u, s, vt, a, jobz);
648- params.exec = exec;
647+ params = GetSVDParams (u, s, vt, a, jobz, exec);
649648 params.method = method;
650649
651650 if (params.method == SVDMethod::GESVDJ_BATCHED) {
@@ -725,8 +724,8 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
725724
726725 static DnSVDCUDAParams_t
727726 GetSVDParams (UTensor &u, STensor &s,
728- VtTensor &vt, const ATensor &a,
729- const char jobz = ' A ' )
727+ VtTensor &vt, const ATensor &a,
728+ const char jobz, const cudaExecutor &exec )
730729 {
731730 DnSVDCUDAParams_t params;
732731 params.batch_size = GetNumBatches (a);
@@ -738,6 +737,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
738737 params.S = s.Data ();
739738 params.jobz = jobz;
740739 params.dtype = TypeToInt<T1>();
740+ params.exec = exec;
741741 return params;
742742 }
743743
@@ -997,7 +997,7 @@ void svd_impl(UTensor &&u, STensor &&s,
997997
998998 // Get parameters required by these tensors
999999 auto params = detail::matxDnSVDCUDAPlan_t<decltype (u_in), decltype (s_new), decltype (vt_in), decltype (at_col_maj)>::
1000- GetSVDParams (u_in, s_new, vt_in, at_col_maj, job_cusolver);
1000+ GetSVDParams (u_in, s_new, vt_in, at_col_maj, job_cusolver, exec );
10011001
10021002 // Get cache or new SVD plan if it doesn't exist
10031003 using cache_val_type = detail::matxDnSVDCUDAPlan_t<decltype (u_in), decltype (s_new), decltype (vt_in), decltype (at_col_maj)>;
@@ -1034,15 +1034,15 @@ void svd_impl(UTensor &&u, STensor &&s,
10341034
10351035 // Get parameters required by these tensors
10361036 auto params = detail::matxDnSVDCUDAPlan_t<decltype (u_col_maj), decltype (s_new), decltype (vt_col_maj), decltype (tvt)>::
1037- GetSVDParams (u_col_maj, s_new, vt_col_maj, tvt, job_cusolver);
1037+ GetSVDParams (u_col_maj, s_new, vt_col_maj, tvt, job_cusolver, exec );
10381038
10391039 // Get cache or new SVD plan if it doesn't exist
10401040 using cache_val_type = detail::matxDnSVDCUDAPlan_t<decltype (u_col_maj), decltype (s_new), decltype (vt_col_maj), decltype (tvt)>;
10411041 detail::GetCache ().LookupAndExec <detail::svd_cuda_cache_t >(
10421042 detail::GetCacheIdFromType<detail::svd_cuda_cache_t >(),
10431043 params,
10441044 [&]() {
1045- return std::make_shared<cache_val_type>(u_col_maj, s_new, vt_col_maj, tvt, method, stream , job_cusolver);
1045+ return std::make_shared<cache_val_type>(u_col_maj, s_new, vt_col_maj, tvt, method, exec , job_cusolver);
10461046 },
10471047 [&](std::shared_ptr<cache_val_type> ctype) {
10481048 ctype->Exec (u_col_maj, s_new, vt_col_maj, tvt, exec, job_cusolver);
0 commit comments