Skip to content

Commit 5aa547c

Browse files
committed
Einsum: handle sparse DistArray
1 parent e967606 commit 5aa547c

File tree

2 files changed

+107
-28
lines changed

2 files changed

+107
-28
lines changed

src/TiledArray/einsum/tiledarray.h

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TILEDARRAY_EINSUM_H__INCLUDED
33

44
#include "TiledArray/fwd.h"
5+
#include "TiledArray/dist_array.h"
56
#include "TiledArray/expressions/fwd.h"
67
#include "TiledArray/einsum/index.h"
78
#include "TiledArray/einsum/range.h"
@@ -53,6 +54,8 @@ auto einsum(
5354
{
5455

5556
using Array = std::remove_cv_t<Array_>;
57+
using Tensor = typename Array::value_type;
58+
using Shape = typename Array::shape_type;
5659

5760
auto a = std::get<0>(Einsum::idx(A));
5861
auto b = std::get<0>(Einsum::idx(B));
@@ -103,6 +106,7 @@ auto einsum(
103106
TiledRange ei_tiled_range;
104107
Array ei;
105108
std::string expr;
109+
std::vector< std::pair<Einsum::Index<size_t>,Tensor> > local_tiles;
106110
bool own(Einsum::Index<size_t> h) const {
107111
for (Einsum::Index<size_t> ei : tiles) {
108112
auto idx = apply_inverse(permutation, h+ei);
@@ -149,7 +153,6 @@ auto einsum(
149153
}
150154

151155
using Index = Einsum::Index<size_t>;
152-
using Tensor = typename Array::value_type;
153156

154157
if constexpr(std::tuple_size<decltype(cs)>::value > 1) {
155158
TA_ASSERT(e);
@@ -169,7 +172,7 @@ auto einsum(
169172
for (size_t i = 0; i < h.size(); ++i) {
170173
batch *= H.batch[i].at(h[i]);
171174
}
172-
Tensor tile(TiledArray::Range{batch});
175+
Tensor tile(TiledArray::Range{batch}, typename Tensor::value_type());
173176
for (Index i : tiles) {
174177
// skip this unless both input tiles exist
175178
const auto pahi_inv = apply_inverse(pa,h+i);
@@ -208,6 +211,7 @@ auto einsum(
208211
}
209212

210213
std::vector< std::shared_ptr<World> > worlds;
214+
std::vector< std::tuple<Index,Tensor> > local_tiles;
211215

212216
// iterates over tiles of hadamard indices
213217
for (Index h : H.tiles) {
@@ -222,21 +226,29 @@ auto einsum(
222226
batch *= H.batch[i].at(h[i]);
223227
}
224228
for (auto &term : AB) {
225-
term.ei = Array(*owners, term.ei_tiled_range);
229+
term.local_tiles.clear();
226230
const Permutation &P = term.permutation;
227231
for (Index ei : term.tiles) {
228232
auto idx = apply_inverse(P, h+ei);
229233
if (!term.array.is_local(idx)) continue;
234+
if (term.array.is_zero(idx)) continue;
230235
auto tile = term.array.find(idx).get();
231236
if (P) tile = tile.permute(P);
232-
auto shape = term.ei.trange().tile(ei);
237+
auto shape = term.ei_tiled_range.tile(ei);
233238
tile = tile.reshape(shape, batch);
234-
term.ei.set(ei, tile);
239+
term.local_tiles.push_back({ei, tile});
235240
}
241+
term.ei = TiledArray::make_array<Array>(
242+
*owners,
243+
term.ei_tiled_range,
244+
term.local_tiles.begin(),
245+
term.local_tiles.end()
246+
);
236247
}
237248
C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners);
238249
for (Index e : C.tiles) {
239250
if (!C.ei.is_local(e)) continue;
251+
if (C.ei.is_zero(e)) continue;
240252
auto tile = C.ei.find(e).get();
241253
assert(tile.batch_size() == batch);
242254
const Permutation &P = C.permutation;
@@ -245,14 +257,29 @@ auto einsum(
245257
shape = apply_inverse(P, shape);
246258
tile = tile.reshape(shape);
247259
if (P) tile = tile.permute(P);
248-
C.array.set(c, tile);
260+
local_tiles.push_back({c, tile});
249261
}
250262
// mark for lazy deletion
251263
A.ei = Array();
252264
B.ei = Array();
253265
C.ei = Array();
254266
}
255267

268+
if constexpr (!Shape::is_dense()) {
269+
TiledRange tiled_range = TiledRange(range_map[c]);
270+
std::vector< std::pair<Index,float> > tile_norms;
271+
for (auto& [index,tile] : local_tiles) {
272+
tile_norms.push_back({index,tile.norm()});
273+
}
274+
Shape shape(world, tile_norms, tiled_range);
275+
C.array = Array(world, TiledRange(range_map[c]), shape);
276+
}
277+
278+
for (auto& [index,tile] : local_tiles) {
279+
if (C.array.is_zero(index)) continue;
280+
C.array.set(index, tile);
281+
}
282+
256283
for (auto &w : worlds) {
257284
w->gop.fence();
258285
}

tests/einsum.cpp

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -494,15 +494,28 @@ BOOST_AUTO_TEST_SUITE_END()
494494
// TiledArray einsum expressions
495495
BOOST_AUTO_TEST_SUITE(einsum_tiledarray)
496496

497-
template<typename T = Tensor<int>, typename ... Args>
497+
using TiledArray::SparsePolicy;
498+
using TiledArray::DensePolicy;
499+
500+
template<typename Policy, typename T = Tensor<int>, typename ... Args>
498501
auto random(Args ... args) {
499502
TiledArray::TiledRange tr{ {0, args}... };
500503
auto& world = TiledArray::get_default_world();
501-
TiledArray::DistArray<T,TiledArray::SparsePolicy> t(world,tr);
504+
TiledArray::DistArray<T,Policy> t(world,tr);
502505
t.fill_random();
503506
return t;
504507
}
505508

509+
template<typename T = Tensor<int>, typename ... Args>
510+
auto sparse_zero(Args ... args) {
511+
TiledArray::TiledRange tr{ {0, args}... };
512+
auto& world = TiledArray::get_default_world();
513+
TiledArray::SparsePolicy::shape_type shape(0.0f, tr);
514+
TiledArray::DistArray<T,TiledArray::SparsePolicy> t(world,tr,shape);
515+
t.fill(0);
516+
return t;
517+
}
518+
506519
template<int NA, int NB, int NC, typename T, typename Policy>
507520
void einsum_tiledarray_check(
508521
TiledArray::DistArray<T,Policy> &&A,
@@ -523,85 +536,124 @@ void einsum_tiledarray_check(
523536
array_to_eigen_tensor<Tensor<U,NB>>(B)
524537
);
525538
auto result = array_to_eigen_tensor<TC>(C);
539+
//std::cout << "e=" << result << std::endl;
526540
BOOST_CHECK(isApprox(result, reference));
527541
}
528542

529543
BOOST_AUTO_TEST_CASE(einsum_tiledarray_ak_bk_ab) {
530544
einsum_tiledarray_check<2,2,2>(
531-
random(11,7),
532-
random(13,7),
545+
random<SparsePolicy>(11,7),
546+
random<SparsePolicy>(13,7),
533547
"ak,bk->ab"
534548
);
535549
}
536550

537551
BOOST_AUTO_TEST_CASE(einsum_tiledarray_ka_bk_ba) {
538552
einsum_tiledarray_check<2,2,2>(
539-
random(7,11),
540-
random(13,7),
553+
random<SparsePolicy>(7,11),
554+
random<SparsePolicy>(13,7),
541555
"ka,bk->ba"
542556
);
543557
}
544558

545559
BOOST_AUTO_TEST_CASE(einsum_tiledarray_abi_cdi_cdab) {
546560
einsum_tiledarray_check<3,3,4>(
547-
random(21,22,3),
548-
random(24,25,3),
561+
random<SparsePolicy>(21,22,3),
562+
random<SparsePolicy>(24,25,3),
549563
"abi,cdi->cdab"
550564
);
551565
}
552566

553567
BOOST_AUTO_TEST_CASE(einsum_tiledarray_icd_ai_abcd) {
554568
einsum_tiledarray_check<3,3,4>(
555-
random(3,12,13),
556-
random(14,15,3),
569+
random<SparsePolicy>(3,12,13),
570+
random<SparsePolicy>(14,15,3),
557571
"icd,bai->abcd"
558572
);
559573
}
560574

561575
BOOST_AUTO_TEST_CASE(einsum_tiledarray_cdji_ibja_abcd) {
562576
einsum_tiledarray_check<4,4,4>(
563-
random(14,15,3,5),
564-
random(5,12,3,13),
577+
random<SparsePolicy>(14,15,3,5),
578+
random<SparsePolicy>(5,12,3,13),
565579
"cdji,ibja->abcd"
566580
);
567581
}
568582

569583
BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hbi_hab) {
570584
einsum_tiledarray_check<3,3,3>(
571-
random(7,14,3),
572-
random(7,15,3),
585+
random<SparsePolicy>(7,14,3),
586+
random<SparsePolicy>(7,15,3),
587+
"hai,hbi->hab"
588+
);
589+
einsum_tiledarray_check<3,3,3>(
590+
sparse_zero(7,14,3),
591+
sparse_zero(7,15,3),
573592
"hai,hbi->hab"
574593
);
575594
}
576595

577596
BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_bha) {
578597
einsum_tiledarray_check<3,3,3>(
579-
random(7,14,3),
580-
random(3,7,15),
598+
random<SparsePolicy>(7,14,3),
599+
random<SparsePolicy>(3,7,15),
600+
"iah,hib->bha"
601+
);
602+
einsum_tiledarray_check<3,3,3>(
603+
sparse_zero(7,14,3),
604+
sparse_zero(3,7,15),
581605
"iah,hib->bha"
582606
);
583607
}
584608

585609
BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_abh) {
586610
einsum_tiledarray_check<3,3,3>(
587-
random(7,14,3),
588-
random(3,7,15),
611+
random<SparsePolicy>(7,14,3),
612+
random<SparsePolicy>(3,7,15),
589613
"iah,hib->abh"
590614
);
615+
einsum_tiledarray_check<3,3,3>(
616+
sparse_zero(7,14,3),
617+
sparse_zero(3,7,15),
618+
"iah,hib->abh"
619+
);
620+
}
621+
622+
BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hibc_habc) {
623+
einsum_tiledarray_check<3,4,4>(
624+
random<SparsePolicy>(9,3,11),
625+
random<SparsePolicy>(9,11,5,7),
626+
"hai,hibc->habc"
627+
);
628+
einsum_tiledarray_check<3,4,4>(
629+
sparse_zero(9,3,11),
630+
sparse_zero(9,11,5,7),
631+
"hai,hibc->habc"
632+
);
591633
}
592634

593635
BOOST_AUTO_TEST_CASE(einsum_tiledarray_hi_hi_h) {
594636
einsum_tiledarray_check<2,2,1>(
595-
random(7,14),
596-
random(7,14),
637+
random<SparsePolicy>(7,14),
638+
random<SparsePolicy>(7,14),
639+
"hi,hi->h"
640+
);
641+
einsum_tiledarray_check<2,2,1>(
642+
sparse_zero(7,14),
643+
sparse_zero(7,14),
597644
"hi,hi->h"
598645
);
599646
}
600647

601648
BOOST_AUTO_TEST_CASE(einsum_tiledarray_hji_jih_hj) {
602649
einsum_tiledarray_check<3,3,2>(
603-
random(14,7,5),
604-
random(7,5,14),
650+
random<SparsePolicy>(14,7,5),
651+
random<SparsePolicy>(7,5,14),
652+
"hji,jih->hj"
653+
);
654+
einsum_tiledarray_check<3,3,2>(
655+
sparse_zero(14,7,5),
656+
sparse_zero(7,5,14),
605657
"hji,jih->hj"
606658
);
607659
}

0 commit comments

Comments
 (0)