diff --git a/INSTALL.md b/INSTALL.md index c2901a7e11..6060c4bd29 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -40,7 +40,7 @@ Both methods are supported. However, for most users we _strongly_ recommend to b - Boost.Container: header-only - Boost.Test: header-only or (optionally) as a compiled library, *only used for unit testing* - Boost.Range: header-only, *only used for unit testing* -- [BTAS](http://github.com/ValeevGroup/BTAS), tag 6fcb6451bc7ca46a00534a30c51dc5c230c39ac3 . If usable BTAS installation is not found, TiledArray will download and compile +- [BTAS](http://github.com/ValeevGroup/BTAS), tag 561fe1bff7f3374814111a15e28c7a141ab9b67a . If usable BTAS installation is not found, TiledArray will download and compile BTAS from source. *This is the recommended way to compile BTAS for all users*. - [MADNESS](https://github.com/m-a-d-n-e-s-s/madness), tag 91fff76deba20c751d0646c54f2f1c1e07bd6156 . Only the MADworld runtime and BLAS/LAPACK C API component of MADNESS is used by TiledArray. diff --git a/external/versions.cmake b/external/versions.cmake index 40321f91ac..e9cfb45375 100644 --- a/external/versions.cmake +++ b/external/versions.cmake @@ -24,8 +24,8 @@ set(TA_TRACKED_MADNESS_PREVIOUS_TAG 0b44ef319643cb9721fbe17d294987c146e6460e) set(TA_TRACKED_MADNESS_VERSION 0.10.1) set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1) -set(TA_TRACKED_BTAS_TAG 6fcb6451bc7ca46a00534a30c51dc5c230c39ac3) -set(TA_TRACKED_BTAS_PREVIOUS_TAG 474ddc095cbea12a1d28aca5435703dd9f69b166) +set(TA_TRACKED_BTAS_TAG 561fe1bff7f3374814111a15e28c7a141ab9b67a) +set(TA_TRACKED_BTAS_PREVIOUS_TAG 6fcb6451bc7ca46a00534a30c51dc5c230c39ac3) set(TA_TRACKED_LIBRETT_TAG 68abe31a9ec6fd2fd9ffbcd874daa80457f947da) set(TA_TRACKED_LIBRETT_PREVIOUS_TAG 7e27ac766a9038df6aa05613784a54a036c4b796) diff --git a/src/TiledArray/dist_array.h b/src/TiledArray/dist_array.h index ea6a066441..09e99eda86 100644 --- a/src/TiledArray/dist_array.h +++ b/src/TiledArray/dist_array.h @@ -136,6 +136,27 @@ class DistArray : public madness::archive::ParallelSerializableObject { std::is_same_v, Future> || std::is_same_v, value_type>; + /// compute type of DistArray with different Policy and/or Tile + template + using rebind_t = DistArray; + + private: + template + struct rebind_numeric; + template + struct rebind_numeric< + Numeric, std::enable_if_t>> { + using type = + DistArray, Policy>; + }; + + public: + /// compute type of DistArray with Tile's rebound numeric type + /// @note this is SFINAE-disabled if `Tile::rebind_numeric_t` is not + /// defined + template + using rebind_numeric_t = typename rebind_numeric::type; + private: pimpl_type pimpl_; ///< managed ptr to Array implementation bool defer_deleter_to_next_fence_ = @@ -1838,6 +1859,22 @@ DistArray replicated(const DistArray& a) { return result; } +namespace detail { + +template +struct real_t_impl> { + using type = typename DistArray::template rebind_numeric_t< + typename Tile::scalar_type>; +}; + +template +struct complex_t_impl> { + using type = typename DistArray::template rebind_numeric_t< + std::complex>; +}; + +} // namespace detail + } // namespace TiledArray // serialization diff --git a/src/TiledArray/external/btas.h b/src/TiledArray/external/btas.h index 483be905df..7dbd115d4d 100644 --- a/src/TiledArray/external/btas.h +++ b/src/TiledArray/external/btas.h @@ -841,6 +841,20 @@ struct ordinal_traits> { : OrdinalType::ColMajor; }; +template +struct real_t_impl> { + using type = + typename btas::Tensor::template rebind_numeric_t< + typename btas::Tensor::scalar_type>; +}; + +template +struct complex_t_impl> { + using type = + typename btas::Tensor::template rebind_numeric_t< + std::complex::scalar_type>>; +}; + } // namespace detail } // namespace TiledArray diff --git a/src/TiledArray/math/linalg/non-distributed/heig.h b/src/TiledArray/math/linalg/non-distributed/heig.h index 8a7c244bbc..5490b6b757 100644 --- a/src/TiledArray/math/linalg/non-distributed/heig.h +++ b/src/TiledArray/math/linalg/non-distributed/heig.h @@ -52,10 +52,10 @@ namespace TiledArray::math::linalg::non_distributed { */ template auto heig(const Array& A, TiledRange evec_trange = TiledRange()) { - using numeric_type = typename detail::array_traits::numeric_type; + using scalar_type = typename detail::array_traits::scalar_type; World& world = A.world(); auto A_eig = detail::make_matrix(A); - std::vector evals; + std::vector evals; if (world.rank() == 0) { linalg::rank_local::heig(A_eig, evals); } @@ -93,12 +93,12 @@ auto heig(const Array& A, TiledRange evec_trange = TiledRange()) { template auto heig(const ArrayA& A, const ArrayB& B, TiledRange evec_trange = TiledRange()) { - using numeric_type = typename detail::array_traits::numeric_type; + using scalar_type = typename detail::array_traits::scalar_type; (void)detail::array_traits{}; World& world = A.world(); auto A_eig = detail::make_matrix(A); auto B_eig = detail::make_matrix(B); - std::vector evals; + std::vector evals; if (world.rank() == 0) { linalg::rank_local::heig(A_eig, B_eig, evals); } diff --git a/src/TiledArray/math/linalg/non-distributed/svd.h b/src/TiledArray/math/linalg/non-distributed/svd.h index 9c146784ef..e6ea5ef1da 100644 --- a/src/TiledArray/math/linalg/non-distributed/svd.h +++ b/src/TiledArray/math/linalg/non-distributed/svd.h @@ -27,9 +27,9 @@ #include -#include -#include #include +#include +#include namespace TiledArray::math::linalg::non_distributed { @@ -52,13 +52,14 @@ namespace TiledArray::math::linalg::non_distributed { * @param[in] vt_trange TiledRange for resulting right singular vectors * (transposed). * - * @returns A tuple containing the eigenvalues and eigenvectors of input array - * as std::vector and in TA format, respectively. + * @returns A tuple containing the singular values and singular vectors of + * input array as std::vector and in TA format, respectively. */ -template -auto svd(const Array& A, TiledRange u_trange = TiledRange(), TiledRange vt_trange = TiledRange()) { - +template +auto svd(const Array& A, TiledRange u_trange = TiledRange(), + TiledRange vt_trange = TiledRange()) { using T = typename Array::numeric_type; + using TS = typename Array::scalar_type; using Matrix = linalg::rank_local::Matrix; World& world = A.world(); @@ -68,7 +69,7 @@ auto svd(const Array& A, TiledRange u_trange = TiledRange(), TiledRange vt_trang constexpr bool need_u = (Vectors == SVD::LeftVectors) or svd_all_vectors; constexpr bool need_vt = (Vectors == SVD::RightVectors) or svd_all_vectors; - std::vector S; + std::vector S; std::unique_ptr U, VT; if constexpr (need_u) U = std::make_unique(); @@ -82,7 +83,7 @@ auto svd(const Array& A, TiledRange u_trange = TiledRange(), TiledRange vt_trang if (U) world.gop.broadcast_serializable(*U, 0); if (VT) world.gop.broadcast_serializable(*VT, 0); - auto make_array = [&world](auto && ... args) { + auto make_array = [&world](auto&&... args) { return eigen_to_array(world, args...); }; @@ -97,7 +98,6 @@ auto svd(const Array& A, TiledRange u_trange = TiledRange(), TiledRange vt_trang } if constexpr (!need_u && !need_vt) return S; - } } // namespace TiledArray::math::linalg::non_distributed diff --git a/src/TiledArray/math/linalg/rank-local.cpp b/src/TiledArray/math/linalg/rank-local.cpp index a1e2e5538b..74e1aac526 100644 --- a/src/TiledArray/math/linalg/rank-local.cpp +++ b/src/TiledArray/math/linalg/rank-local.cpp @@ -113,19 +113,23 @@ void cholesky_lsolve(Op transpose, Matrix& A, Matrix& X) { } template -void heig(Matrix& A, std::vector& W) { +void heig(Matrix& A, std::vector>& W) { auto jobz = lapack::Job::Vec; auto uplo = lapack::Uplo::Lower; integer n = A.rows(); T* a = A.data(); integer lda = A.rows(); W.resize(n); - T* w = W.data(); - TA_LAPACK(syev, jobz, uplo, n, a, lda, w); + auto* w = W.data(); + if constexpr (TiledArray::detail::is_complex_v) + TA_LAPACK(heev, jobz, uplo, n, a, lda, w); + else + TA_LAPACK(syev, jobz, uplo, n, a, lda, w); } template -void heig(Matrix& A, Matrix& B, std::vector& W) { +void heig(Matrix& A, Matrix& B, + std::vector>& W) { integer itype = 1; auto jobz = lapack::Job::Vec; auto uplo = lapack::Uplo::Lower; @@ -135,12 +139,17 @@ void heig(Matrix& A, Matrix& B, std::vector& W) { T* b = B.data(); integer ldb = B.rows(); W.resize(n); - T* w = W.data(); - TA_LAPACK(sygv, itype, jobz, uplo, n, a, lda, b, ldb, w); + auto* w = W.data(); + if constexpr (TiledArray::detail::is_complex_v) + TA_LAPACK(hegv, itype, jobz, uplo, n, a, lda, b, ldb, w); + else + TA_LAPACK(sygv, itype, jobz, uplo, n, a, lda, b, ldb, w); } template -void svd(Job jobu, Job jobvt, Matrix& A, std::vector& S, Matrix* U, Matrix* VT) { +void svd(Job jobu, Job jobvt, Matrix& A, + std::vector>& S, Matrix* U, + Matrix* VT) { integer m = A.rows(); integer n = A.cols(); integer k = std::min(m, n); @@ -148,40 +157,42 @@ void svd(Job jobu, Job jobvt, Matrix& A, std::vector& S, Matrix* U, Mat integer lda = A.rows(); S.resize(k); - T* s = S.data(); + auto* s = S.data(); - T* u = nullptr; + T* u = nullptr; T* vt = nullptr; integer ldu = 1, ldvt = 1; - if( (jobu == Job::SomeVec or jobu == Job::AllVec) and (not U) ) - TA_LAPACK_ERROR("Requested out-of-place right singular vectors with null U input"); - if( (jobvt == Job::SomeVec or jobvt == Job::AllVec) and (not VT) ) - TA_LAPACK_ERROR("Requested out-of-place left singular vectors with null VT input"); + if ((jobu == Job::SomeVec or jobu == Job::AllVec) and (not U)) + TA_LAPACK_ERROR( + "Requested out-of-place right singular vectors with null U input"); + if ((jobvt == Job::SomeVec or jobvt == Job::AllVec) and (not VT)) + TA_LAPACK_ERROR( + "Requested out-of-place left singular vectors with null VT input"); - if( jobu == Job::SomeVec ) { + if (jobu == Job::SomeVec) { U->resize(m, k); u = U->data(); ldu = m; } - if( jobu == Job::AllVec ) { + if (jobu == Job::AllVec) { U->resize(m, m); u = U->data(); ldu = m; } - if( jobvt == Job::SomeVec ) { + if (jobvt == Job::SomeVec) { VT->resize(k, n); vt = VT->data(); ldvt = k; } - if( jobvt == Job::AllVec ) { + if (jobvt == Job::AllVec) { VT->resize(n, n); vt = VT->data(); ldvt = n; } - + TA_LAPACK(gesvd, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt); } @@ -208,47 +219,51 @@ void lu_inv(Matrix& A) { } template -void householder_qr( Matrix &V, Matrix &R ) { +void householder_qr(Matrix& V, Matrix& R) { integer m = V.rows(); integer n = V.cols(); - integer k = std::min(m,n); - integer ldv = V.rows(); // Col Major + integer k = std::min(m, n); + integer ldv = V.rows(); // Col Major T* v = V.data(); std::vector tau(k); - lapack::geqrf( m, n, v, ldv, tau.data() ); + lapack::geqrf(m, n, v, ldv, tau.data()); // Extract R - if constexpr ( not QOnly ) { + if constexpr (not QOnly) { // Resize R just in case - R.resize(k,n); + R.resize(k, n); R.fill(0.); // Extract Upper triangle into R integer ldr = R.rows(); T* r = R.data(); - lapack::lacpy( lapack::MatrixType::Upper, k, n, v, ldv, r, ldr ); + lapack::lacpy(lapack::MatrixType::Upper, k, n, v, ldv, r, ldr); } // Explicitly form Q // TODO: This is wrong for complex, but it doesn't look like R/C is caught // anywhere else either... - lapack::orgqr( m, n, k, v, ldv, tau.data() ); - + if constexpr (TiledArray::detail::is_complex_v) + lapack::ungqr(m, n, k, v, ldv, tau.data()); + else + lapack::orgqr(m, n, k, v, ldv, tau.data()); } -#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \ - template void cholesky(MATRIX&); \ - template void cholesky_linv(MATRIX&); \ - template void cholesky_solve(MATRIX&, MATRIX&); \ - template void cholesky_lsolve(Op, MATRIX&, MATRIX&); \ - template void heig(MATRIX&, VECTOR&); \ - template void heig(MATRIX&, MATRIX&, VECTOR&); \ - template void svd(Job,Job,MATRIX&, VECTOR&, MATRIX*, MATRIX*); \ - template void lu_solve(MATRIX&, MATRIX&); \ - template void lu_inv(MATRIX&); \ - template void householder_qr(MATRIX&,MATRIX&); \ - template void householder_qr(MATRIX&,MATRIX&); +#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \ + template void cholesky(MATRIX&); \ + template void cholesky_linv(MATRIX&); \ + template void cholesky_solve(MATRIX&, MATRIX&); \ + template void cholesky_lsolve(Op, MATRIX&, MATRIX&); \ + template void heig(MATRIX&, VECTOR&); \ + template void heig(MATRIX&, MATRIX&, VECTOR&); \ + template void svd(Job, Job, MATRIX&, VECTOR&, MATRIX*, MATRIX*); \ + template void lu_solve(MATRIX&, MATRIX&); \ + template void lu_inv(MATRIX&); \ + template void householder_qr(MATRIX&, MATRIX&); \ + template void householder_qr(MATRIX&, MATRIX&); TA_LAPACK_EXPLICIT(Matrix, std::vector); TA_LAPACK_EXPLICIT(Matrix, std::vector); +TA_LAPACK_EXPLICIT(Matrix>, std::vector); +TA_LAPACK_EXPLICIT(Matrix>, std::vector); } // namespace TiledArray::math::linalg::rank_local diff --git a/src/TiledArray/math/linalg/rank-local.h b/src/TiledArray/math/linalg/rank-local.h index 77774c195a..5c46550bd3 100644 --- a/src/TiledArray/math/linalg/rank-local.h +++ b/src/TiledArray/math/linalg/rank-local.h @@ -42,17 +42,20 @@ template void cholesky_lsolve(Op transpose, Matrix &A, Matrix &X); template -void heig(Matrix &A, std::vector &W); +void heig(Matrix &A, std::vector> &W); template -void heig(Matrix &A, Matrix &B, std::vector &W); +void heig(Matrix &A, Matrix &B, + std::vector> &W); template -void svd(Job jobu, Job jobvt, Matrix &A, std::vector &S, Matrix *U, +void svd(Job jobu, Job jobvt, Matrix &A, + std::vector> &S, Matrix *U, Matrix *VT); template -void svd(Matrix &A, std::vector &S, Matrix *U, Matrix *VT) { +void svd(Matrix &A, std::vector> &S, + Matrix *U, Matrix *VT) { svd(U ? Job::SomeVec : Job::NoVec, VT ? Job::SomeVec : Job::NoVec, A, S, U, VT); } diff --git a/src/TiledArray/math/linalg/scalapack/heig.h b/src/TiledArray/math/linalg/scalapack/heig.h index bc9edeaa91..d7e84ae706 100644 --- a/src/TiledArray/math/linalg/scalapack/heig.h +++ b/src/TiledArray/math/linalg/scalapack/heig.h @@ -58,7 +58,7 @@ namespace TiledArray::math::linalg::scalapack { template auto heig(const Array& A, TiledRange evec_trange = TiledRange(), size_t NB = default_block_size()) { - using value_type = typename Array::element_type; + using value_type = typename Array::numeric_type; using real_type = scalapackpp::detail::real_t; auto& world = A.world(); @@ -80,9 +80,8 @@ auto heig(const Array& A, TiledRange evec_trange = TiledRange(), scalapack::BlockCyclicMatrix evecs(world, grid, N, N, NB, NB); auto info = scalapackpp::hereig( - scalapackpp::Job::Vec, blacspp::Uplo::Lower, N, - matrix.local_mat().data(), 1, 1, desc, evals.data(), - evecs.local_mat().data(), 1, 1, desc); + scalapackpp::Job::Vec, blacspp::Uplo::Lower, N, matrix.local_mat().data(), + 1, 1, desc, evals.data(), evecs.local_mat().data(), 1, 1, desc); if (info) TA_EXCEPTION("EVP Failed"); if (evec_trange.rank() == 0) evec_trange = A.trange(); @@ -122,8 +121,8 @@ template auto heig(const ArrayA& A, const ArrayB& B, TiledRange evec_trange = TiledRange(), size_t NB = default_block_size()) { - using value_type = typename ArrayA::element_type; - static_assert(std::is_same_v); + using value_type = typename ArrayA::numeric_type; + static_assert(std::is_same_v); using real_type = scalapackpp::detail::real_t; auto& world = A.world(); @@ -150,9 +149,9 @@ auto heig(const ArrayA& A, const ArrayB& B, scalapack::BlockCyclicMatrix evecs(world, grid, N, N, NB, NB); auto info = scalapackpp::hereig_gen( - scalapackpp::Job::Vec, blacspp::Uplo::Lower, N, - A_sca.local_mat().data(), 1, 1, desc, B_sca.local_mat().data(), 1, 1, - desc, evals.data(), evecs.local_mat().data(), 1, 1, desc); + scalapackpp::Job::Vec, blacspp::Uplo::Lower, N, A_sca.local_mat().data(), + 1, 1, desc, B_sca.local_mat().data(), 1, 1, desc, evals.data(), + evecs.local_mat().data(), 1, 1, desc); if (info) TA_EXCEPTION("EVP Failed"); if (evec_trange.rank() == 0) evec_trange = A.trange(); diff --git a/src/TiledArray/math/solvers/conjgrad.h b/src/TiledArray/math/solvers/conjgrad.h index 91992cf7de..cacfd55d63 100644 --- a/src/TiledArray/math/solvers/conjgrad.h +++ b/src/TiledArray/math/solvers/conjgrad.h @@ -60,7 +60,7 @@ namespace TiledArray::math { // clang-format on template struct ConjugateGradientSolver { - typedef typename D::element_type value_type; + typedef typename D::numeric_type value_type; /// \param a object of type F /// \param b RHS diff --git a/src/TiledArray/math/solvers/diis.h b/src/TiledArray/math/solvers/diis.h index 252d40480b..1407ff327e 100644 --- a/src/TiledArray/math/solvers/diis.h +++ b/src/TiledArray/math/solvers/diis.h @@ -82,7 +82,7 @@ namespace TiledArray::math { template class DIIS { public: - typedef typename D::element_type value_type; + typedef typename D::numeric_type value_type; typedef typename TiledArray::detail::scalar_t scalar_type; typedef Eigen::Matrix diff --git a/src/TiledArray/tensor/complex.h b/src/TiledArray/tensor/complex.h index 33698521a2..cfa330101d 100644 --- a/src/TiledArray/tensor/complex.h +++ b/src/TiledArray/tensor/complex.h @@ -274,6 +274,32 @@ inline auto abs(const ComplexConjugate& a) { inline int abs(const ComplexConjugate& a) { return 1; } +template >> +TILEDARRAY_FORCE_INLINE auto operator*(const L l, const std::complex r) { + return static_cast(l) * r; +} + +template >> +TILEDARRAY_FORCE_INLINE auto operator*(const std::complex l, const R r) { + return l * static_cast(r); +} + +template +TILEDARRAY_FORCE_INLINE + std::enable_if_t, std::complex> + operator*(const L l, const std::complex r) { + return std::complex(l, 0.) * r; +} + +template +TILEDARRAY_FORCE_INLINE + std::enable_if_t, std::complex> + operator*(const std::complex l, const R r) { + return l * std::complex(r, 0.); +} + } // namespace detail } // namespace TiledArray diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index f25bba68f7..fcb5ffbe7a 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -108,6 +108,32 @@ class Tensor { detail::is_tensor_of_tensor::value; }; + public: + /// compute type of Tensor with different element type + template ::template rebind_alloc> + using rebind_t = Tensor; + + template + struct rebind_numeric; + template + struct rebind_numeric::value>> { + using VU = typename V::template rebind_numeric::type; + using type = Tensor::template rebind_alloc>; + }; + template + struct rebind_numeric::value>> { + using type = Tensor< + U, typename std::allocator_traits::template rebind_alloc>; + }; + + /// compute type of Tensor with different numeric type + template + using rebind_numeric_t = typename rebind_numeric::type; + + private: using default_construct = bool; Tensor(const range_type& range, size_t batch_size, bool default_construct) @@ -1412,8 +1438,10 @@ class Tensor { template >::type* = nullptr> Tensor scale(const Scalar factor) const { - return unary( - [factor](const numeric_type a) -> numeric_type { return a * factor; }); + return unary([factor](const numeric_type a) -> numeric_type { + using namespace TiledArray::detail; + return a * factor; + }); } /// Construct a scaled and permuted copy of this tensor @@ -1429,7 +1457,10 @@ class Tensor { detail::is_permutation_v>> Tensor scale(const Scalar factor, const Perm& perm) const { return unary( - [factor](const numeric_type a) -> numeric_type { return a * factor; }, + [factor](const numeric_type a) -> numeric_type { + using namespace TiledArray::detail; + return a * factor; + }, perm); } @@ -2672,6 +2703,22 @@ struct transform> { }; } // namespace detail +namespace detail { + +template +struct real_t_impl> { + using type = typename Tensor::template rebind_numeric_t< + typename Tensor::scalar_type>; +}; + +template +struct complex_t_impl> { + using type = typename Tensor::template rebind_numeric_t< + std::complex::scalar_type>>; +}; + +} // namespace detail + #ifndef TILEDARRAY_HEADER_ONLY extern template class Tensor; diff --git a/src/TiledArray/tensor/tensor_interface.h b/src/TiledArray/tensor/tensor_interface.h index 76413a51a3..bc5e9abab2 100644 --- a/src/TiledArray/tensor/tensor_interface.h +++ b/src/TiledArray/tensor/tensor_interface.h @@ -191,7 +191,9 @@ class TensorInterface { template ::value>::type* = nullptr> TensorInterface_& operator=(const T1& other) { - TA_ASSERT(data_ != other.data()); + if constexpr (std::is_same_v>) { + TA_ASSERT(data_ != other.data()); + } detail::inplace_tensor_op([](numeric_type& MADNESS_RESTRICT result, const numeric_t arg) { result = arg; }, diff --git a/src/TiledArray/tile.h b/src/TiledArray/tile.h index 57366dbe60..b8242fbf19 100644 --- a/src/TiledArray/tile.h +++ b/src/TiledArray/tile.h @@ -95,6 +95,31 @@ class Tile { using scalar_type = typename TiledArray::detail::scalar_type< tensor_type>::type; ///< the scalar type that supports T + private: + template + struct rebind; + template + struct rebind>> { + using type = Tile>; + }; + + template + struct rebind_numeric; + template + struct rebind_numeric< + Numeric, std::enable_if_t>> { + using type = Tile>; + }; + + public: + /// compute type of Tile with different element type + template + using rebind_t = typename rebind::type; + + /// compute type of Tile with different numeric type + template + using rebind_numeric_t = typename rebind_numeric::type; + private: std::shared_ptr pimpl_; @@ -1648,6 +1673,22 @@ bool operator!=(const Tile& t1, const Tile& t2) { return !(t1 == t2); } +namespace detail { + +template +struct real_t_impl> { + using type = typename Tile::template rebind_numeric_t< + typename Tile::scalar_type>; +}; + +template +struct complex_t_impl> { + using type = typename Tile::template rebind_numeric_t< + std::complex::scalar_type>>; +}; + +} // namespace detail + } // namespace TiledArray #endif // TILEDARRAY_TILE_H__INCLUDED diff --git a/src/TiledArray/type_traits.h b/src/TiledArray/type_traits.h index ece535d929..47c90f0130 100644 --- a/src/TiledArray/type_traits.h +++ b/src/TiledArray/type_traits.h @@ -632,6 +632,42 @@ struct is_complex> : public std::true_type {}; template constexpr const bool is_complex_v = is_complex::value; +template +struct complex_t_impl; + +template +struct complex_t_impl> { + using type = std::complex; +}; + +template +struct complex_t_impl>> { + using type = std::complex; +}; + +/// evaluates to std::complex if T is real, else T +/// @note specialize complex_t_impl to customize the behavior for type T +template +using complex_t = typename complex_t_impl::type; + +template +struct real_t_impl; + +template +struct real_t_impl> { + using type = T; +}; + +template +struct real_t_impl>> { + using type = T; +}; + +/// evaluates to U if T is std::complex, or if T is real then evaluates to T +/// @note specialize real_t_impl to customize the behavior for type T +template +using real_t = typename real_t_impl::type; + template struct is_numeric : public std::is_arithmetic {}; @@ -760,6 +796,31 @@ struct scalar_type>::type> template using scalar_t = typename TiledArray::detail::scalar_type::type; +/// is true type if `T::rebind_t` is defined +template +struct has_rebind : std::false_type {}; +template +struct has_rebind>> + : std::true_type {}; + +/// alias to has_rebind::value +template +inline constexpr bool has_rebind_v = has_rebind::value; + +/// is true type if `T::rebind_numeric_t` is defined +template +struct has_rebind_numeric : std::false_type {}; +template +struct has_rebind_numeric< + T, Numeric, std::void_t>> + : std::true_type {}; + +/// alias to has_rebind_numeric::value +template +inline constexpr bool has_rebind_numeric_v = + has_rebind_numeric::value; + template struct is_strictly_ordered_helper { using Yes = char; diff --git a/tests/dist_array.cpp b/tests/dist_array.cpp index 4f2e1dbe9b..061c5fdd17 100644 --- a/tests/dist_array.cpp +++ b/tests/dist_array.cpp @@ -513,7 +513,7 @@ BOOST_AUTO_TEST_CASE(make_replicated) { BOOST_REQUIRE_NO_THROW(a.make_replicated()); // check for cda7b8a33b85f9ebe92bc369d6a362c94f1eae40 bug - for (const auto &tile : a) { + for (const auto& tile : a) { BOOST_CHECK(tile.get().size() != 0); } @@ -532,7 +532,6 @@ BOOST_AUTO_TEST_CASE(make_replicated) { it != tile.get().end(); ++it) BOOST_CHECK_EQUAL(*it, distributed_pmap->owner(i) + 1); } - } BOOST_AUTO_TEST_CASE(serialization_by_tile) { @@ -710,4 +709,38 @@ BOOST_AUTO_TEST_CASE(issue_225) { std::remove(archive_file_name); } +BOOST_AUTO_TEST_CASE(rebind) { + static_assert( + std::is_same_v, TArrayD>); + static_assert( + std::is_same_v, + TArrayD>); + static_assert( + std::is_same_v, TSpArrayD>); + static_assert( + std::is_same_v, + TSpArrayD>); + static_assert(std::is_same_v, TArrayD>); + static_assert( + std::is_same_v, TArrayZ>); + static_assert( + std::is_same_v, TSpArrayD>); + static_assert( + std::is_same_v, TSpArrayZ>); + + // DistArray of Tensors + using SpArrayTD = DistArray, SparsePolicy>; + using SpArrayTZ = DistArray, SparsePolicy>; + static_assert(std::is_same_v, + TSpArrayZ>); + static_assert( + std::is_same_v< + typename SpArrayTD::template rebind_numeric_t>, + SpArrayTZ>); + static_assert( + std::is_same_v, SpArrayTD>); + static_assert( + std::is_same_v, SpArrayTZ>); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/tests/tensor.cpp b/tests/tensor.cpp index b329b5af44..1281e5d164 100644 --- a/tests/tensor.cpp +++ b/tests/tensor.cpp @@ -724,4 +724,14 @@ BOOST_AUTO_TEST_CASE(block) { #endif } +BOOST_AUTO_TEST_CASE(rebind) { + static_assert( + std::is_same_v>, TensorZ>); + static_assert( + std::is_same_v>, TensorZ>); + static_assert( + std::is_same_v, TensorZ>); + static_assert(std::is_same_v, TensorD>); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/tests/tensor_of_tensor.cpp b/tests/tensor_of_tensor.cpp index 0f4683d174..21d136b67c 100644 --- a/tests/tensor_of_tensor.cpp +++ b/tests/tensor_of_tensor.cpp @@ -1234,4 +1234,19 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(serialization, ITensor, itensor_types) { cend(a_roundtrip)); } +BOOST_AUTO_TEST_CASE_TEMPLATE(rebind, ITensor, itensor_types) { + using ITensorD = typename ITensor::template rebind_t; + using ITensorZ = typename ITensor::template rebind_t>; + static_assert( + std::is_same_v::template rebind_t, + TensorD>); + static_assert(std::is_same_v< + typename Tensor::template rebind_numeric_t, + Tensor>); + static_assert(std::is_same_v>, + Tensor>); + static_assert(std::is_same_v>, + Tensor>); +} + BOOST_AUTO_TEST_SUITE_END()