Skip to content

[WIP] some einsum tests fail #391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ auto einsum(expressions::TsrExpr<Array_> A, expressions::TsrExpr<Array_> B,
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
}
Tensor tile(TiledArray::Range{batch}, typename Tensor::value_type());
Tensor tile(TiledArray::Range{batch}, typename Tensor::value_type(0));
for (Index i : tiles) {
// skip this unless both input tiles exist
const auto pahi_inv = apply_inverse(pa, h + i);
Expand All @@ -179,7 +179,7 @@ auto einsum(expressions::TsrExpr<Array_> A, expressions::TsrExpr<Array_> B,
bi = bi.reshape(shape, batch);
for (size_t k = 0; k < batch; ++k) {
auto hk = ai.batch(k).dot(bi.batch(k));
tile[k] = hk;
tile[k] += hk;
}
}
auto pc = C.permutation;
Expand Down
72 changes: 47 additions & 25 deletions src/TiledArray/util/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
#ifndef TILEDARRAY_RANDOM_H__INCLUDED
#define TILEDARRAY_RANDOM_H__INCLUDED

#include <complex> // for std::complex
#include <cstdlib> // for std::rand
#include <type_traits> // for true_type, false_type, and enable_if
#include <complex> // for std::complex
#include <cstdlib> // for std::rand
#include <type_traits> // for true_type, false_type, and enable_if

namespace TiledArray::detail {

Expand All @@ -37,28 +37,28 @@ namespace TiledArray::detail {
/// types. Specific types can be enabled by specializing CanMakeRandom.
///
/// \tparam ValueType The type of random value we are attempting to generate.
template<typename ValueType>
struct CanMakeRandom : std::false_type{};
template <typename ValueType>
struct CanMakeRandom : std::false_type {};

/// Enables generating random int values
template<>
struct CanMakeRandom<int> : std::true_type{};
template <>
struct CanMakeRandom<int> : std::true_type {};

/// Enables generating random float values
template<>
struct CanMakeRandom<float> : std::true_type{};
template <>
struct CanMakeRandom<float> : std::true_type {};

/// Enables generating random double values
template<>
struct CanMakeRandom<double> : std::true_type{};
template <>
struct CanMakeRandom<double> : std::true_type {};

/// Enables generating random std::complex<float> values
template<>
struct CanMakeRandom<std::complex<float>> : std::true_type{};
template <>
struct CanMakeRandom<std::complex<float>> : std::true_type {};

/// Enables generating random std::complex<double> values
template<>
struct CanMakeRandom<std::complex<double>> : std::true_type{};
template <>
struct CanMakeRandom<std::complex<double>> : std::true_type {};

/// Variable for whether or not we can make a random value of type ValueType
///
Expand All @@ -67,13 +67,13 @@ struct CanMakeRandom<std::complex<double>> : std::true_type{};
/// example `can_make_random_v<T>` is shorthand for `CanMakeRandom<T>::value`.
///
/// \tparam ValueType the type of random value we are attempting to make.
template<typename ValueType>
template <typename ValueType>
static constexpr auto can_make_random_v = CanMakeRandom<ValueType>::value;

/// Enables a function only when we can generate a random value of type `T`
///
/// \tparam T The type of random value we are attempting to generate.
template<typename T>
template <typename T>
using enable_if_can_make_random_t = std::enable_if_t<can_make_random_v<T>>;

//------------------------------------------------------------------------------
Expand All @@ -83,16 +83,36 @@ using enable_if_can_make_random_t = std::enable_if_t<can_make_random_v<T>>;
/// Struct wrapping the process of generating a random value of type `ValueType`
///
/// MakeRandom contains a single static member function `generate_value`, which
/// can be called to generate a random value of type `ValueType` between 0
/// and 1. Users can specialize the MakeRandom class to control how random
/// generates a random value using `std::rand()`. The default implementation is
/// only provided for fundamental types:
/// - for a floating-point type this returns a random value in [-1,1].
/// - for a signed integral type this returns a random value in [-4,4].
/// - for an unsigned integral type this returns a random value in [0,8].
/// Users can specialize the MakeRandom class to control how random
/// values of other types are formed.
///
/// \tparam ValueType The type of random value to generate
template<typename ValueType>
template <typename ValueType>
struct MakeRandom {
/// Generates a random value of type ValueType
static ValueType generate_value() {
return static_cast<ValueType>(static_cast<double>(std::rand()) / RAND_MAX);
static_assert(std::is_floating_point_v<ValueType> ||
std::is_integral_v<ValueType>);
if constexpr (std::is_floating_point_v<ValueType>)
return (2 * static_cast<ValueType>(std::rand()) / RAND_MAX) - 1;
else if constexpr (std::is_integral_v<ValueType>) {
static_assert(RAND_MAX == 2147483647);
static_assert(RAND_MAX % 2 == 1);
constexpr std::int64_t RAND_MAX_DIVBY_9 =
(static_cast<std::int64_t>(RAND_MAX) + 8) / 9;
const ValueType v = static_cast<ValueType>(
static_cast<std::int64_t>(std::rand()) / RAND_MAX_DIVBY_9);
if constexpr (std::is_signed_v<ValueType>) {
return v - 4;
} else {
return v;
}
}
}
};

Expand All @@ -105,18 +125,20 @@ struct MakeRandom {
///
/// \tparam ScalarType The type used to hold the real and imaginary components
/// of the complex value.
template<typename ScalarType>
template <typename ScalarType>
struct MakeRandom<std::complex<ScalarType>> {

/// Generates a random complex number.
static auto generate_value() {
static_assert(
std::is_floating_point_v<ScalarType>); // std::complex is only defined
// for fundamental
// floating-point types
const ScalarType real = MakeRandom<ScalarType>::generate_value();
const ScalarType imag = MakeRandom<ScalarType>::generate_value();
return std::complex<ScalarType>(real, imag);
}
};

} // namespace TiledArray::detail

} // namespace TiledArray::detail

#endif
24 changes: 22 additions & 2 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,10 +744,30 @@ BOOST_AUTO_TEST_SUITE_END()

#include "TiledArray/einsum/eigen.h"

template <typename... Tensors>
using common_scalar_t = std::common_type_t<typename Tensors::Scalar...>;

template <typename T>
auto abs_comparison_threshold_default() {
if constexpr (std::is_integral_v<T>) {
return 0;
} else {
return 1e-4;
}
}

template <typename TA, typename TB>
bool isApprox(const Eigen::TensorBase<TA, Eigen::ReadOnlyAccessors>& A,
const Eigen::TensorBase<TB, Eigen::ReadOnlyAccessors>& B) {
Eigen::Tensor<bool, 0> r = (derived(A) == derived(B)).all();
const Eigen::TensorBase<TB, Eigen::ReadOnlyAccessors>& B,
common_scalar_t<TA, TB> abs_comparison_threshold =
abs_comparison_threshold_default<common_scalar_t<TA, TB>>()) {
Eigen::Tensor<bool, 0> r;
if constexpr (std::is_integral_v<typename TA::Scalar> &&
std::is_integral_v<typename TB::Scalar>) {
r = (derived(A) == derived(B)).all();
} else { // soft floating-point comparison
r = ((derived(A) - derived(B)).abs() <= abs_comparison_threshold).all();
}
return r.coeffRef();
}

Expand Down