Skip to content

Add QR Implementations #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/versions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ set(TA_TRACKED_UMPIRE_PREVIOUS_TAG v6.0.0)
#set(TA_TRACKED_BLACSPP_TAG 20cfd414c5b719be1c958f4a2d57abef06df83b6 )
#set(TA_TRACKED_BLACSPP_PREVIOUS_TAG da4ada57e578cf944325a7152164306742551596 )

set(TA_TRACKED_SCALAPACKPP_TAG 043f85d7f31ec6009740ab466bcb5008af7b0814 )
set(TA_TRACKED_SCALAPACKPP_TAG bf17a7246af38d34523bd0099b01d9961d06d311 )
set(TA_TRACKED_SCALAPACKPP_PREVIOUS_TAG 043f85d7f31ec6009740ab466bcb5008af7b0814 )

set(TA_TRACKED_RANGEV3_TAG 2e0591c57fce2aca6073ad6e4fdc50d841827864)
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ TiledArray/math/linalg/util.h
TiledArray/math/linalg/cholesky.h
TiledArray/math/linalg/heig.h
TiledArray/math/linalg/lu.h
TiledArray/math/linalg/qr.h
TiledArray/math/linalg/svd.h
TiledArray/math/linalg/scalapack/util.h
TiledArray/math/linalg/scalapack/block_cyclic.h
TiledArray/math/linalg/scalapack/cholesky.h
TiledArray/math/linalg/scalapack/heig.h
TiledArray/math/linalg/scalapack/lu.h
TiledArray/math/linalg/scalapack/qr.h
TiledArray/math/linalg/scalapack/svd.h
TiledArray/conversions/btas.h
TiledArray/conversions/clone.h
Expand Down
1 change: 1 addition & 0 deletions src/TiledArray/math/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <TiledArray/math/linalg/cholesky.h>
#include <TiledArray/math/linalg/heig.h>
#include <TiledArray/math/linalg/lu.h>
#include <TiledArray/math/linalg/qr.h>
#include <TiledArray/math/linalg/svd.h>

#endif // TILEDARRAY_MATH_LINALG_LINALG_H__INCLUDED
42 changes: 42 additions & 0 deletions src/TiledArray/math/linalg/non-distributed/qr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#ifndef TILEDARRAY_MATH_LINALG_NON_DISTRIBUTED_QR_H__INCLUDED
#define TILEDARRAY_MATH_LINALG_NON_DISTRIBUTED_QR_H__INCLUDED

#include <TiledArray/config.h>

#include <TiledArray/math/linalg/util.h>
#include <TiledArray/math/linalg/rank-local.h>
#include <TiledArray/conversions/eigen.h>

namespace TiledArray::math::linalg::non_distributed {

template <bool QOnly, typename ArrayV>
auto householder_qr( const ArrayV& V, TiledRange q_trange = TiledRange(),
TiledRange r_trange = TiledRange() ) {

(void)detail::array_traits<ArrayV>{};
auto& world = V.world();
auto V_eig = detail::make_matrix(V);
decltype(V_eig) R_eig;
if( !world.rank() ) {
linalg::rank_local::householder_qr<QOnly>( V_eig, R_eig );
}
world.gop.broadcast_serializable( V_eig, 0 );
if(q_trange.rank() == 0) q_trange = V.trange();
auto Q = eigen_to_array<ArrayV>( world, q_trange, V_eig );
if constexpr (not QOnly) {
world.gop.broadcast_serializable( R_eig, 0 );
if (r_trange.rank() == 0) {
// Generate a TRange based on column tiling of V
auto col_tiling = V.trange().dim(1);
r_trange = TiledRange( {col_tiling, col_tiling} );
}
auto R = eigen_to_array<ArrayV>( world, r_trange, R_eig );
return std::make_tuple( Q, R );
} else {
return Q;
}
}

} // namespace TiledArray::math::linalg::non_distributed

#endif
53 changes: 53 additions & 0 deletions src/TiledArray/math/linalg/qr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#ifndef TILEDARRAY_MATH_LINALG_QR_H__INCLUDED
#define TILEDARRAY_MATH_LINALG_QR_H__INCLUDED

#include <TiledArray/config.h>
#if TILEDARRAY_HAS_SCALAPACK
#include <TiledArray/math/linalg/scalapack/qr.h>
#endif
#include <TiledArray/math/linalg/non-distributed/qr.h>
#include <TiledArray/util/threads.h>

#include <TiledArray/math/linalg/cholesky.h>

namespace TiledArray::math::linalg {

template <bool QOnly, typename ArrayV>
auto householder_qr( const ArrayV& V, TiledRange q_trange = TiledRange(),
TiledRange r_trange = TiledRange() ) {
TA_MAX_THREADS;
#if TILEDARRAY_HAS_SCALAPACK
if (V.world().size() > 1 && V.elements_range().volume() > 10000000) {
return scalapack::householder_qr<QOnly>( V, q_trange, r_trange );
}
#endif
return non_distributed::householder_qr<QOnly>( V, q_trange, r_trange );
}

template <bool QOnly, typename ArrayV>
auto cholesky_qr( const ArrayV& V, TiledRange r_trange = TiledRange() ) {
TA_MAX_THREADS;
// Form Grammian
ArrayV G; G("i,j") = V("k,i").conj() * V("k,j");

// Obtain Cholesky L and its inverse
auto [L, Linv] = cholesky_linv<true>( G, r_trange );

// Q = V * L**-H
ArrayV Q; Q("i,j") = V("i,k") * Linv("j,k").conj();

if constexpr (not QOnly) {
// R = L**H
ArrayV R; R("i,j") = L("j,i");
return std::make_tuple( Q, R );
} else return Q;

}

} // namespace TiledArray::math::linalg

namespace TiledArray {
using TiledArray::math::linalg::householder_qr;
using TiledArray::math::linalg::cholesky_qr;
} // namespace TiledArray
#endif
32 changes: 31 additions & 1 deletion src/TiledArray/math/linalg/rank-local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,34 @@ void lu_inv(Matrix<T>& A) {
TA_LAPACK(getri, n, a, lda, ipiv.data());
}

template <bool QOnly, typename T>
void householder_qr( Matrix<T> &V, Matrix<T> &R ) {
integer m = V.rows();
integer n = V.cols();
integer k = std::min(m,n);
integer ldv = V.rows(); // Col Major
T* v = V.data();
std::vector<T> tau(k);
lapack::geqrf( m, n, v, ldv, tau.data() );

// Extract R
if constexpr ( not QOnly ) {
// Resize R just in case
R.resize(k,n);
R.fill(0.);
// Extract Upper triangle into R
integer ldr = R.rows();
T* r = R.data();
lapack::lacpy( lapack::MatrixType::Upper, k, n, v, ldv, r, ldr );
}

// Explicitly form Q
// TODO: This is wrong for complex, but it doesn't look like R/C is caught
// anywhere else either...
lapack::orgqr( m, n, k, v, ldv, tau.data() );

}

#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \
template void cholesky(MATRIX&); \
template void cholesky_linv(MATRIX&); \
Expand All @@ -216,7 +244,9 @@ void lu_inv(Matrix<T>& A) {
template void heig(MATRIX&, MATRIX&, VECTOR&); \
template void svd(Job,Job,MATRIX&, VECTOR&, MATRIX*, MATRIX*); \
template void lu_solve(MATRIX&, MATRIX&); \
template void lu_inv(MATRIX&);
template void lu_inv(MATRIX&); \
template void householder_qr<true>(MATRIX&,MATRIX&); \
template void householder_qr<false>(MATRIX&,MATRIX&);

TA_LAPACK_EXPLICIT(Matrix<double>, std::vector<double>);
TA_LAPACK_EXPLICIT(Matrix<float>, std::vector<float>);
Expand Down
3 changes: 3 additions & 0 deletions src/TiledArray/math/linalg/rank-local.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ void lu_solve(Matrix<T> &A, Matrix<T> &B);
template <typename T>
void lu_inv(Matrix<T> &A);

template <bool QOnly,typename T>
void householder_qr( Matrix<T> &V, Matrix<T> &R );

} // namespace TiledArray::math::linalg::rank_local

#endif // TILEDARRAY_MATH_LINALG_RANK_LOCAL_H__INCLUDED
85 changes: 85 additions & 0 deletions src/TiledArray/math/linalg/scalapack/qr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#ifndef TILEDARRAY_MATH_LINALG_SCALAPACK_QR_H__INCLUDED
#define TILEDARRAY_MATH_LINALG_SCALAPACK_QR_H__INCLUDED

#include <TiledArray/config.h>
#if TILEDARRAY_HAS_SCALAPACK

#include <TiledArray/math/linalg/scalapack/util.h>

#include <scalapackpp/factorizations/geqrf.hpp>
#include <scalapackpp/householder/generate_q_householder.hpp>
#include <scalapackpp/lacpy.hpp>

namespace TiledArray::math::linalg::scalapack {

template <bool QOnly, typename ArrayV>
auto householder_qr( const ArrayV& V, TiledRange q_trange = TiledRange(),
TiledRange r_trange = TiledRange(),
size_t NB = default_block_size(),
size_t MB = default_block_size()) {

using value_type = typename ArrayV::element_type;

auto& world = V.world();
auto world_comm = world.mpi.comm().Get_mpi_comm();
blacspp::Grid grid = blacspp::Grid::square_grid(world_comm);

world.gop.fence(); // stage ScaLAPACK execution
auto V_sca = scalapack::array_to_block_cyclic(V, grid, MB, NB);
world.gop.fence(); // stage ScaLAPACK execution

auto [M, N] = V_sca.dims();
auto K = std::min(M,N);
auto [V_Mloc, V_Nloc] = V_sca.dist().get_local_dims(M, N);
auto desc_v = V_sca.dist().descinit_noerror(M, N, V_Mloc);


std::vector<value_type>
TAU_local( scalapackpp::local_col_from_desc( K, desc_v ) );

// Perform QR factorization -> Obtain reflectors + R in UT
auto info = scalapackpp::pgeqrf( M, N, V_sca.local_mat().data(), 1, 1, desc_v, TAU_local.data() );
if(info) TA_EXCEPTION("GEQRF FAILED");

ArrayV R; // Uninitialized R matrix

if constexpr (not QOnly) {
BlockCyclicMatrix<value_type> R_sca( world, grid, K, N, MB, NB );
auto [R_Mloc, R_Nloc] = R_sca.dist().get_local_dims(K, N);
auto desc_r = R_sca.dist().descinit_noerror(K, N, R_Mloc);

// Extract R from the upper triangle of V
R_sca.local_mat().fill(0.);
scalapackpp::placpy( scalapackpp::Uplo::Upper, K, N,
V_sca.local_mat().data(), 1, 1, desc_v,
R_sca.local_mat().data(), 1, 1, desc_r );

if (r_trange.rank() == 0) {
// Generate a TRange based on column tiling of V
auto col_tiling = V.trange().dim(1);
r_trange = TiledRange( {col_tiling, col_tiling} );
}

world.gop.fence();
R = scalapack::block_cyclic_to_array<ArrayV>( R_sca, r_trange );
world.gop.fence();
}

// Generate Q
info = scalapackpp::generate_q_householder( M, N, K, V_sca.local_mat().data(), 1, 1, desc_v,
TAU_local.data() );
if(info) TA_EXCEPTION("GENQ FAILED");

if(q_trange.rank() == 0) q_trange = V.trange();
world.gop.fence();
auto Q = scalapack::block_cyclic_to_array<ArrayV>( V_sca, q_trange );
world.gop.fence();

if constexpr (QOnly) return Q;
else return std::make_tuple( Q, R );
}

} // namespace TiledArray::math::linalg::scalapack

#endif // TILEDARRAY_HAS_SCALAPACK
#endif // HEADER GUARD
Loading