9 changes: 9 additions & 0 deletions include/matx/core/tensor_utils.h
@@ -892,4 +892,13 @@ void print(const Op &oper, Args... dims) {
   std::apply([&](auto &&...args) { print(oper, args...); }, tp);
 }
 
+template <typename Op>
+auto OpToTensor(Op &&op, cudaStream_t stream) {
+  if constexpr (!is_tensor_view_v<Op>) {
+    return make_tensor<typename remove_cvref<Op>::scalar_type>(op.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
+  } else {
+    return op;
+  }
+}
+
 }
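
The new OpToTensor helper lets transform entry points accept arbitrary MatX operator expressions rather than only tensor views: a tensor view passes through unchanged, while any other operator yields a freshly allocated stream-ordered device tensor of the same shape for the caller to evaluate into. A minimal sketch of the call pattern the solver functions below repeat; the wrapper name materialize is hypothetical and not part of this PR:

// Hypothetical helper illustrating the materialize-then-copy idiom in this PR.
template <typename Op>
auto materialize(Op &&op, cudaStream_t stream) {
  // Tensor views come back unchanged; other operators yield a new
  // async device tensor with the same shape.
  auto t = matx::OpToTensor(op, stream);
  // isSameView() is false when a new tensor was allocated, so the
  // expression is evaluated into it exactly once.
  if (!t.isSameView(op)) {
    (t = op).run(stream);
  }
  return t;
}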
2 changes: 1 addition & 1 deletion include/matx/transforms/fft.h
@@ -841,7 +841,7 @@ __MATX_INLINE__ void fft(OutputTensor o, const InputTensor i,
   auto in_t = getCufft1DSupportedTensor(i, stream);
 
   if(!in_t.isSameView(i)) {
-      (in_t = i).run(stream);
+    (in_t = i).run(stream);
   }
 
   // TODO should combine this function with above...
137 changes: 98 additions & 39 deletions include/matx/transforms/solver.h
@@ -1110,31 +1110,37 @@ void chol(OutputTensor &out, const ATensor &a,
 
   using T1 = typename OutputTensor::scalar_type;
 
+  auto a_new = OpToTensor(a, stream);
+
+  if(!a_new.isSameView(a)) {
+    (a_new = a).run(stream);
+  }
+
   /* Temporary WAR
      cuSolver doesn't support row-major layouts. Since we want to make the
      library appear as though everything is row-major, we take a performance hit
      to transpose in and out of the function. Eventually this may be fixed in
      cuSolver.
   */
   T1 *tp;
-  matxAlloc(reinterpret_cast<void **>(&tp), a.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
+  matxAlloc(reinterpret_cast<void **>(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a, stream);
+  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream);
 
   // Get parameters required by these tensors
-  auto params = detail::matxDnCholSolverPlan_t<OutputTensor, ATensor>::GetCholParams(tv, uplo);
+  auto params = detail::matxDnCholSolverPlan_t<OutputTensor, decltype(a_new)>::GetCholParams(tv, uplo);
   params.uplo = uplo;
 
   // Get cache or new inverse plan if it doesn't exist
   auto ret = detail::dnchol_cache.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxDnCholSolverPlan_t<OutputTensor, ATensor>{tv, uplo};
+    auto tmp = new detail::matxDnCholSolverPlan_t<OutputTensor, decltype(a_new)>{tv, uplo};
     detail::dnchol_cache.Insert(params, static_cast<void *>(tmp));
     tmp->Exec(tv, tv, stream, uplo);
   }
   else {
     auto chol_type =
-        static_cast<detail::matxDnCholSolverPlan_t<OutputTensor, ATensor> *>(ret.value());
+        static_cast<detail::matxDnCholSolverPlan_t<OutputTensor, decltype(a_new)> *>(ret.value());
     chol_type->Exec(tv, tv, stream, uplo);
   }
 
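With the materialization above, chol() can take an operator expression directly. A hedged usage sketch: the sizes, stream, and fill-mode argument below are illustrative assumptions, not taken from the PR:

cudaStream_t stream = 0;
const int N = 16;
auto A = matx::make_tensor<float>({N, N});   // assume A holds an SPD matrix
auto L = matx::make_tensor<float>({N, N});
// The second argument is an operator expression, not a tensor view, so
// chol() now materializes it through OpToTensor before the cuSolver call.
matx::chol(L, A * 2.0f, stream, CUBLAS_FILL_MODE_LOWER);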
@@ -1175,32 +1181,42 @@ void lu(OutputTensor &out, PivotTensor &piv,
 
   using T1 = typename OutputTensor::scalar_type;
 
+  auto piv_new = OpToTensor(piv, stream);
+  auto a_new = OpToTensor(a, stream);
+
+  if(!piv_new.isSameView(piv)) {
+    (piv_new = piv).run(stream);
+  }
+  if(!a_new.isSameView(a)) {
+    (a_new = a).run(stream);
+  }
+
   /* Temporary WAR
      cuSolver doesn't support row-major layouts. Since we want to make the
      library appear as though everything is row-major, we take a performance hit
      to transpose in and out of the function. Eventually this may be fixed in
      cuSolver.
   */
   T1 *tp;
-  matxAlloc(reinterpret_cast<void **>(&tp), a.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
+  matxAlloc(reinterpret_cast<void **>(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a, stream);
+  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream);
   auto tvt = tv.PermuteMatrix();
 
   // Get parameters required by these tensors
-  auto params = detail::matxDnLUSolverPlan_t<OutputTensor, PivotTensor, ATensor>::GetLUParams(piv, tvt);
+  auto params = detail::matxDnLUSolverPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)>::GetLUParams(piv_new, tvt);
 
   // Get cache or new LU plan if it doesn't exist
   auto ret = detail::dnlu_cache.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxDnLUSolverPlan_t<OutputTensor, PivotTensor, ATensor>{piv, tvt};
+    auto tmp = new detail::matxDnLUSolverPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)>{piv_new, tvt};
 
     detail::dnlu_cache.Insert(params, static_cast<void *>(tmp));
-    tmp->Exec(tvt, piv, tvt, stream);
+    tmp->Exec(tvt, piv_new, tvt, stream);
   }
   else {
-    auto lu_type = static_cast<detail::matxDnLUSolverPlan_t<OutputTensor, PivotTensor, ATensor> *>(ret.value());
-    lu_type->Exec(tvt, piv, tvt, stream);
+    auto lu_type = static_cast<detail::matxDnLUSolverPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)> *>(ret.value());
+    lu_type->Exec(tvt, piv_new, tvt, stream);
   }
 
   /* Temporary WAR
@@ -1238,20 +1254,26 @@ void det(OutputTensor &out, const InputTensor &a,
   static_assert(OutputTensor::Rank() == InputTensor::Rank() - 2, "Output tensor rank must be 2 less than input for det()");
   constexpr int RANK = InputTensor::Rank();
 
+  auto a_new = OpToTensor(a, stream);
+
+  if(!a_new.isSameView(a)) {
+    (a_new = a).run(stream);
+  }
+
   // Get parameters required by these tensors
   std::array<index_t, RANK - 1> s;
 
   // Set batching dimensions of piv
   for (int i = 0; i < RANK - 2; i++) {
-    s[i] = a.Size(i);
+    s[i] = a_new.Size(i);
   }
 
-  s[RANK - 2] = std::min(a.Size(RANK - 1), a.Size(RANK - 2));
+  s[RANK - 2] = std::min(a_new.Size(RANK - 1), a_new.Size(RANK - 2));
 
   auto piv = make_tensor<int64_t>(s);
-  auto ac = make_tensor<typename OutputTensor::scalar_type>(a.Shape());
+  auto ac = make_tensor<typename OutputTensor::scalar_type>(a_new.Shape());
 
-  lu(ac, piv, a, stream);
+  lu(ac, piv, a_new, stream);
   prod(out, diag(ac), stream);
 }
 
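det() computes the determinant by LU-factorizing a copy of the input and reducing the diagonal of the packed factors with prod(), so per the static_assert the output rank must be two less than the input rank (batch dimensions pass through). A hedged sketch under assumed shapes:

cudaStream_t stream = 0;
const int batch = 8, n = 4;
auto A = matx::make_tensor<float>({batch, n, n});
auto d = matx::make_tensor<float>({batch});  // rank = input rank - 2, per the static_assert
// The scaled input is an operator; det() routes it through OpToTensor,
// then factors a copy with lu() and reduces the diagonal with prod().
matx::det(d, A * 3.0f, stream);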
@@ -1287,32 +1309,42 @@ void cusolver_qr(OutTensor &out, TauTensor &tau,
 
   using T1 = typename OutTensor::scalar_type;
 
+  auto tau_new = OpToTensor(tau, stream);
+  auto a_new = OpToTensor(a, stream);
+
+  if(!tau_new.isSameView(tau)) {
+    (tau_new = tau).run(stream);
+  }
+  if(!a_new.isSameView(a)) {
+    (a_new = a).run(stream);
+  }
+
   /* Temporary WAR
      cuSolver doesn't support row-major layouts. Since we want to make the
      library appear as though everything is row-major, we take a performance hit
      to transpose in and out of the function. Eventually this may be fixed in
      cuSolver.
   */
   T1 *tp;
-  matxAlloc(reinterpret_cast<void **>(&tp), a.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
+  matxAlloc(reinterpret_cast<void **>(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a, stream);
+  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream);
   auto tvt = tv.PermuteMatrix();
 
   // Get parameters required by these tensors
-  auto params = detail::matxDnQRSolverPlan_t<OutTensor, TauTensor, ATensor>::GetQRParams(tau, tvt);
+  auto params = detail::matxDnQRSolverPlan_t<OutTensor, decltype(tau_new), decltype(a_new)>::GetQRParams(tau_new, tvt);
 
   // Get cache or new QR plan if it doesn't exist
   auto ret = detail::dnqr_cache.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxDnQRSolverPlan_t<OutTensor, TauTensor, ATensor>{tau, tvt};
+    auto tmp = new detail::matxDnQRSolverPlan_t<OutTensor, decltype(tau_new), decltype(a_new)>{tau_new, tvt};
 
     detail::dnqr_cache.Insert(params, static_cast<void *>(tmp));
-    tmp->Exec(tvt, tau, tvt, stream);
+    tmp->Exec(tvt, tau_new, tvt, stream);
   }
   else {
-    auto qr_type = static_cast<detail::matxDnQRSolverPlan_t<OutTensor, TauTensor, ATensor> *>(ret.value());
-    qr_type->Exec(tvt, tau, tvt, stream);
+    auto qr_type = static_cast<detail::matxDnQRSolverPlan_t<OutTensor, decltype(tau_new), decltype(a_new)> *>(ret.value());
+    qr_type->Exec(tvt, tau_new, tvt, stream);
   }
 
   /* Temporary WAR
@@ -1361,34 +1393,51 @@ void svd(UTensor &u, STensor &s,
 
   using T1 = typename ATensor::scalar_type;
 
+  auto u_new = OpToTensor(u, stream);
+  auto s_new = OpToTensor(s, stream);
+  auto v_new = OpToTensor(v, stream);
+  auto a_new = OpToTensor(a, stream);
+
+  if(!u_new.isSameView(u)) {
+    (u_new = u).run(stream);
+  }
+  if(!s_new.isSameView(s)) {
+    (s_new = s).run(stream);
+  }
+  if(!v_new.isSameView(v)) {
+    (v_new = v).run(stream);
+  }
+  if(!a_new.isSameView(a)) {
+    (a_new = a).run(stream);
+  }
+
   /* Temporary WAR
      cuSolver doesn't support row-major layouts. Since we want to make the
      library appear as though everything is row-major, we take a performance hit
      to transpose in and out of the function. Eventually this may be fixed in
      cuSolver.
   */
   T1 *tp;
-  matxAlloc(reinterpret_cast<void **>(&tp), a.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
-            stream);
-  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a, stream);
+  matxAlloc(reinterpret_cast<void **>(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream);
+  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream);
   auto tvt = tv.PermuteMatrix();
 
   // Get parameters required by these tensors
-  auto params = detail::matxDnSVDSolverPlan_t<UTensor, STensor, VTensor, decltype(tvt)>::GetSVDParams(
-      u, s, v, tvt, jobu, jobvt);
+  auto params = detail::matxDnSVDSolverPlan_t<decltype(u_new), decltype(s_new), decltype(v_new), decltype(tvt)>::GetSVDParams(
+      u_new, s_new, v_new, tvt, jobu, jobvt);
 
   // Get cache or new QR plan if it doesn't exist
   auto ret = detail::dnsvd_cache.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxDnSVDSolverPlan_t{u, s, v, tvt, jobu, jobvt};
+    auto tmp = new detail::matxDnSVDSolverPlan_t{u_new, s_new, v_new, tvt, jobu, jobvt};
 
     detail::dnsvd_cache.Insert(params, static_cast<void *>(tmp));
-    tmp->Exec(u, s, v, tvt, jobu, jobvt, stream);
+    tmp->Exec(u_new, s_new, v_new, tvt, jobu, jobvt, stream);
   }
   else {
     auto svd_type =
-        static_cast<detail::matxDnSVDSolverPlan_t<UTensor, STensor, VTensor, decltype(tvt)> *>(ret.value());
-    svd_type->Exec(u, s, v, tvt, jobu, jobvt, stream);
+        static_cast<detail::matxDnSVDSolverPlan_t<decltype(u_new), decltype(s_new), decltype(v_new), decltype(tvt)> *>(ret.value());
+    svd_type->Exec(u_new, s_new, v_new, tvt, jobu, jobvt, stream);
   }
 }
 
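svd() applies the same treatment to all four arguments, so u, s, and v may themselves be operator outputs. A hedged sketch assuming full-size factors and the (u, s, v, a) argument order shown in this hunk; m, n, and the stream are illustrative:

cudaStream_t stream = 0;
const int m = 8, n = 4;
auto A = matx::make_tensor<float>({m, n});
auto U = matx::make_tensor<float>({m, m});
auto S = matx::make_tensor<float>({std::min(m, n)});
auto V = matx::make_tensor<float>({n, n});
// All four arguments pass through OpToTensor, so an expression input works;
// the WAR above then transposes it into column-major storage for cuSolver.
matx::svd(U, S, V, A * 2.0f, stream);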
@@ -1437,27 +1486,37 @@ void eig([[maybe_unused]] OutputTensor &out, WTensor &w,
   */
   using T1 = typename OutputTensor::scalar_type;
 
+  auto w_new = OpToTensor(w, stream);
+  auto a_new = OpToTensor(a, stream);
+
+  if(!w_new.isSameView(w)) {
+    (w_new = w).run(stream);
+  }
+  if(!a_new.isSameView(a)) {
+    (a_new = a).run(stream);
+  }
+
   T1 *tp;
-  matxAlloc(reinterpret_cast<void **>(&tp), a.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
+  matxAlloc(reinterpret_cast<void **>(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a, stream);
+  auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream);
 
   // Get parameters required by these tensors
   auto params =
-      detail::matxDnEigSolverPlan_t<OutputTensor, WTensor, ATensor>::GetEigParams(w, tv, jobz, uplo);
+      detail::matxDnEigSolverPlan_t<OutputTensor, decltype(w_new), decltype(a_new)>::GetEigParams(w_new, tv, jobz, uplo);
 
   // Get cache or new eigen plan if it doesn't exist
   auto ret = detail::dneig_cache.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxDnEigSolverPlan_t<OutputTensor, WTensor, ATensor>{w, tv, jobz, uplo};
+    auto tmp = new detail::matxDnEigSolverPlan_t<OutputTensor, decltype(w_new), decltype(a_new)>{w_new, tv, jobz, uplo};
 
     detail::dneig_cache.Insert(params, static_cast<void *>(tmp));
-    tmp->Exec(tv, w, tv, jobz, uplo, stream);
+    tmp->Exec(tv, w_new, tv, jobz, uplo, stream);
   }
   else {
     auto eig_type =
-        static_cast<detail::matxDnEigSolverPlan_t<OutputTensor, WTensor, ATensor> *>(ret.value());
-    eig_type->Exec(tv, w, tv, jobz, uplo, stream);
+        static_cast<detail::matxDnEigSolverPlan_t<OutputTensor, decltype(w_new), decltype(a_new)> *>(ret.value());
+    eig_type->Exec(tv, w_new, tv, jobz, uplo, stream);
   }
 
   /* Copy and free async buffer for transpose */
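
eig() follows the same pattern for the eigenvalue output w and the input a. A hedged sketch; the symmetric-input assumption and the (out, w, a) argument order follow this hunk's signature, with n and the stream illustrative:

cudaStream_t stream = 0;
const int n = 16;
auto A    = matx::make_tensor<float>({n, n});  // assume A is filled symmetric
auto vecs = matx::make_tensor<float>({n, n});  // eigenvectors (out)
auto vals = matx::make_tensor<float>({n});     // eigenvalues (w)
// A + A is an operator expression; eig() materializes both w and a via OpToTensor.
matx::eig(vecs, vals, A + A, stream);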