Skip to content

Commit fbf5404

Browse files
authored
Merge pull request #391 from ValeevGroup/evaleev/fix/einsum
[WIP] some einsum tests fail
2 parents 9c186b6 + e2009bf commit fbf5404

File tree

3 files changed

+71
-29
lines changed

3 files changed

+71
-29
lines changed

src/TiledArray/einsum/tiledarray.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ auto einsum(expressions::TsrExpr<Array_> A, expressions::TsrExpr<Array_> B,
163163
for (size_t i = 0; i < h.size(); ++i) {
164164
batch *= H.batch[i].at(h[i]);
165165
}
166-
Tensor tile(TiledArray::Range{batch}, typename Tensor::value_type());
166+
Tensor tile(TiledArray::Range{batch}, typename Tensor::value_type(0));
167167
for (Index i : tiles) {
168168
// skip this unless both input tiles exist
169169
const auto pahi_inv = apply_inverse(pa, h + i);
@@ -179,7 +179,7 @@ auto einsum(expressions::TsrExpr<Array_> A, expressions::TsrExpr<Array_> B,
179179
bi = bi.reshape(shape, batch);
180180
for (size_t k = 0; k < batch; ++k) {
181181
auto hk = ai.batch(k).dot(bi.batch(k));
182-
tile[k] = hk;
182+
tile[k] += hk;
183183
}
184184
}
185185
auto pc = C.permutation;

src/TiledArray/util/random.h

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
#ifndef TILEDARRAY_RANDOM_H__INCLUDED
2121
#define TILEDARRAY_RANDOM_H__INCLUDED
2222

23-
#include <complex> // for std::complex
24-
#include <cstdlib> // for std::rand
25-
#include <type_traits> // for true_type, false_type, and enable_if
23+
#include <complex> // for std::complex
24+
#include <cstdlib> // for std::rand
25+
#include <type_traits> // for true_type, false_type, and enable_if
2626

2727
namespace TiledArray::detail {
2828

@@ -37,28 +37,28 @@ namespace TiledArray::detail {
3737
/// types. Specific types can be enabled by specializing CanMakeRandom.
3838
///
3939
/// \tparam ValueType The type of random value we are attempting to generate.
40-
template<typename ValueType>
41-
struct CanMakeRandom : std::false_type{};
40+
template <typename ValueType>
41+
struct CanMakeRandom : std::false_type {};
4242

4343
/// Enables generating random int values
44-
template<>
45-
struct CanMakeRandom<int> : std::true_type{};
44+
template <>
45+
struct CanMakeRandom<int> : std::true_type {};
4646

4747
/// Enables generating random float values
48-
template<>
49-
struct CanMakeRandom<float> : std::true_type{};
48+
template <>
49+
struct CanMakeRandom<float> : std::true_type {};
5050

5151
/// Enables generating random double values
52-
template<>
53-
struct CanMakeRandom<double> : std::true_type{};
52+
template <>
53+
struct CanMakeRandom<double> : std::true_type {};
5454

5555
/// Enables generating random std::complex<float> values
56-
template<>
57-
struct CanMakeRandom<std::complex<float>> : std::true_type{};
56+
template <>
57+
struct CanMakeRandom<std::complex<float>> : std::true_type {};
5858

5959
/// Enables generating random std::complex<double> values
60-
template<>
61-
struct CanMakeRandom<std::complex<double>> : std::true_type{};
60+
template <>
61+
struct CanMakeRandom<std::complex<double>> : std::true_type {};
6262

6363
/// Variable for whether or not we can make a random value of type ValueType
6464
///
@@ -67,13 +67,13 @@ struct CanMakeRandom<std::complex<double>> : std::true_type{};
6767
/// example `can_make_random_v<T>` is shorthand for `CanMakeRandom<T>::value`.
6868
///
6969
/// \tparam ValueType the type of random value we are attempting to make.
70-
template<typename ValueType>
70+
template <typename ValueType>
7171
static constexpr auto can_make_random_v = CanMakeRandom<ValueType>::value;
7272

7373
/// Enables a function only when we can generate a random value of type `T`
7474
///
7575
/// \tparam T The type of random value we are attempting to generate.
76-
template<typename T>
76+
template <typename T>
7777
using enable_if_can_make_random_t = std::enable_if_t<can_make_random_v<T>>;
7878

7979
//------------------------------------------------------------------------------
@@ -83,16 +83,36 @@ using enable_if_can_make_random_t = std::enable_if_t<can_make_random_v<T>>;
8383
/// Struct wrapping the process of generating a random value of type `ValueType`
8484
///
8585
/// MakeRandom contains a single static member function `generate_value`, which
86-
/// can be called to generate a random value of type `ValueType` between 0
87-
/// and 1. Users can specialize the MakeRandom class to control how random
86+
/// generates a random value using `std::rand()`. The default implementation is
87+
/// only provided for fundamental types:
88+
/// - for a floating-point type this returns a random value in [-1,1].
89+
/// - for a signed integral type this returns a random value in [-4,4].
90+
/// - for an unsigned integral type this returns a random value in [0,8].
91+
/// Users can specialize the MakeRandom class to control how random
8892
/// values of other types are formed.
8993
///
9094
/// \tparam ValueType The type of random value to generate
91-
template<typename ValueType>
95+
template <typename ValueType>
9296
struct MakeRandom {
9397
/// Generates a random value of type ValueType
9498
static ValueType generate_value() {
95-
return static_cast<ValueType>(static_cast<double>(std::rand()) / RAND_MAX);
99+
static_assert(std::is_floating_point_v<ValueType> ||
100+
std::is_integral_v<ValueType>);
101+
if constexpr (std::is_floating_point_v<ValueType>)
102+
return (2 * static_cast<ValueType>(std::rand()) / RAND_MAX) - 1;
103+
else if constexpr (std::is_integral_v<ValueType>) {
104+
static_assert(RAND_MAX == 2147483647);
105+
static_assert(RAND_MAX % 2 == 1);
106+
constexpr std::int64_t RAND_MAX_DIVBY_9 =
107+
(static_cast<std::int64_t>(RAND_MAX) + 8) / 9;
108+
const ValueType v = static_cast<ValueType>(
109+
static_cast<std::int64_t>(std::rand()) / RAND_MAX_DIVBY_9);
110+
if constexpr (std::is_signed_v<ValueType>) {
111+
return v - 4;
112+
} else {
113+
return v;
114+
}
115+
}
96116
}
97117
};
98118

@@ -105,18 +125,20 @@ struct MakeRandom {
105125
///
106126
/// \tparam ScalarType The type used to hold the real and imaginary components
107127
/// of the complex value.
108-
template<typename ScalarType>
128+
template <typename ScalarType>
109129
struct MakeRandom<std::complex<ScalarType>> {
110-
111130
/// Generates a random complex number.
112131
static auto generate_value() {
132+
static_assert(
133+
std::is_floating_point_v<ScalarType>); // std::complex is only defined
134+
// for fundamental
135+
// floating-point types
113136
const ScalarType real = MakeRandom<ScalarType>::generate_value();
114137
const ScalarType imag = MakeRandom<ScalarType>::generate_value();
115138
return std::complex<ScalarType>(real, imag);
116139
}
117140
};
118141

119-
} // namespace TiledArray::detail
120-
142+
} // namespace TiledArray::detail
121143

122144
#endif

tests/einsum.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -744,10 +744,30 @@ BOOST_AUTO_TEST_SUITE_END()
744744

745745
#include "TiledArray/einsum/eigen.h"
746746

747+
template <typename... Tensors>
748+
using common_scalar_t = std::common_type_t<typename Tensors::Scalar...>;
749+
750+
template <typename T>
751+
auto abs_comparison_threshold_default() {
752+
if constexpr (std::is_integral_v<T>) {
753+
return 0;
754+
} else {
755+
return 1e-4;
756+
}
757+
}
758+
747759
template <typename TA, typename TB>
748760
bool isApprox(const Eigen::TensorBase<TA, Eigen::ReadOnlyAccessors>& A,
749-
const Eigen::TensorBase<TB, Eigen::ReadOnlyAccessors>& B) {
750-
Eigen::Tensor<bool, 0> r = (derived(A) == derived(B)).all();
761+
const Eigen::TensorBase<TB, Eigen::ReadOnlyAccessors>& B,
762+
common_scalar_t<TA, TB> abs_comparison_threshold =
763+
abs_comparison_threshold_default<common_scalar_t<TA, TB>>()) {
764+
Eigen::Tensor<bool, 0> r;
765+
if constexpr (std::is_integral_v<typename TA::Scalar> &&
766+
std::is_integral_v<typename TB::Scalar>) {
767+
r = (derived(A) == derived(B)).all();
768+
} else { // soft floating-point comparison
769+
r = ((derived(A) - derived(B)).abs() <= abs_comparison_threshold).all();
770+
}
751771
return r.coeffRef();
752772
}
753773

0 commit comments

Comments
 (0)