@@ -494,15 +494,28 @@ BOOST_AUTO_TEST_SUITE_END()
494
494
// TiledArray einsum expressions
495
495
BOOST_AUTO_TEST_SUITE(einsum_tiledarray)
496
496
497
- template<typename T = Tensor<int>, typename ... Args>
497
+ using TiledArray::SparsePolicy;
498
+ using TiledArray::DensePolicy;
499
+
500
+ template <typename Policy, typename T = Tensor<int >, typename ... Args>
498
501
auto random (Args ... args) {
499
502
TiledArray::TiledRange tr{ {0 , args}... };
500
503
auto & world = TiledArray::get_default_world ();
501
- TiledArray::DistArray<T,TiledArray::SparsePolicy > t (world,tr);
504
+ TiledArray::DistArray<T,Policy > t (world,tr);
502
505
t.fill_random ();
503
506
return t;
504
507
}
505
508
509
+ template <typename T = Tensor<int >, typename ... Args>
510
+ auto sparse_zero (Args ... args) {
511
+ TiledArray::TiledRange tr{ {0 , args}... };
512
+ auto & world = TiledArray::get_default_world ();
513
+ TiledArray::SparsePolicy::shape_type shape (0 .0f , tr);
514
+ TiledArray::DistArray<T,TiledArray::SparsePolicy> t (world,tr,shape);
515
+ t.fill (0 );
516
+ return t;
517
+ }
518
+
506
519
template <int NA, int NB, int NC, typename T, typename Policy>
507
520
void einsum_tiledarray_check (
508
521
TiledArray::DistArray<T,Policy> &&A,
@@ -523,85 +536,124 @@ void einsum_tiledarray_check(
523
536
array_to_eigen_tensor<Tensor<U,NB>>(B)
524
537
);
525
538
auto result = array_to_eigen_tensor<TC>(C);
539
+ // std::cout << "e=" << result << std::endl;
526
540
BOOST_CHECK (isApprox (result, reference));
527
541
}
528
542
529
543
BOOST_AUTO_TEST_CASE (einsum_tiledarray_ak_bk_ab) {
530
544
einsum_tiledarray_check<2 ,2 ,2 >(
531
- random (11 ,7 ),
532
- random (13 ,7 ),
545
+ random<SparsePolicy> (11 ,7 ),
546
+ random<SparsePolicy> (13 ,7 ),
533
547
" ak,bk->ab"
534
548
);
535
549
}
536
550
537
551
BOOST_AUTO_TEST_CASE (einsum_tiledarray_ka_bk_ba) {
538
552
einsum_tiledarray_check<2 ,2 ,2 >(
539
- random (7 ,11 ),
540
- random (13 ,7 ),
553
+ random<SparsePolicy> (7 ,11 ),
554
+ random<SparsePolicy> (13 ,7 ),
541
555
" ka,bk->ba"
542
556
);
543
557
}
544
558
545
559
BOOST_AUTO_TEST_CASE (einsum_tiledarray_abi_cdi_cdab) {
546
560
einsum_tiledarray_check<3 ,3 ,4 >(
547
- random (21 ,22 ,3 ),
548
- random (24 ,25 ,3 ),
561
+ random<SparsePolicy> (21 ,22 ,3 ),
562
+ random<SparsePolicy> (24 ,25 ,3 ),
549
563
" abi,cdi->cdab"
550
564
);
551
565
}
552
566
553
567
BOOST_AUTO_TEST_CASE (einsum_tiledarray_icd_ai_abcd) {
554
568
einsum_tiledarray_check<3 ,3 ,4 >(
555
- random (3 ,12 ,13 ),
556
- random (14 ,15 ,3 ),
569
+ random<SparsePolicy> (3 ,12 ,13 ),
570
+ random<SparsePolicy> (14 ,15 ,3 ),
557
571
" icd,bai->abcd"
558
572
);
559
573
}
560
574
561
575
BOOST_AUTO_TEST_CASE (einsum_tiledarray_cdji_ibja_abcd) {
562
576
einsum_tiledarray_check<4 ,4 ,4 >(
563
- random (14 ,15 ,3 ,5 ),
564
- random (5 ,12 ,3 ,13 ),
577
+ random<SparsePolicy> (14 ,15 ,3 ,5 ),
578
+ random<SparsePolicy> (5 ,12 ,3 ,13 ),
565
579
" cdji,ibja->abcd"
566
580
);
567
581
}
568
582
569
583
BOOST_AUTO_TEST_CASE (einsum_tiledarray_hai_hbi_hab) {
570
584
einsum_tiledarray_check<3 ,3 ,3 >(
571
- random (7 ,14 ,3 ),
572
- random (7 ,15 ,3 ),
585
+ random<SparsePolicy>(7 ,14 ,3 ),
586
+ random<SparsePolicy>(7 ,15 ,3 ),
587
+ " hai,hbi->hab"
588
+ );
589
+ einsum_tiledarray_check<3 ,3 ,3 >(
590
+ sparse_zero (7 ,14 ,3 ),
591
+ sparse_zero (7 ,15 ,3 ),
573
592
" hai,hbi->hab"
574
593
);
575
594
}
576
595
577
596
BOOST_AUTO_TEST_CASE (einsum_tiledarray_iah_hib_bha) {
578
597
einsum_tiledarray_check<3 ,3 ,3 >(
579
- random (7 ,14 ,3 ),
580
- random (3 ,7 ,15 ),
598
+ random<SparsePolicy>(7 ,14 ,3 ),
599
+ random<SparsePolicy>(3 ,7 ,15 ),
600
+ " iah,hib->bha"
601
+ );
602
+ einsum_tiledarray_check<3 ,3 ,3 >(
603
+ sparse_zero (7 ,14 ,3 ),
604
+ sparse_zero (3 ,7 ,15 ),
581
605
" iah,hib->bha"
582
606
);
583
607
}
584
608
585
609
BOOST_AUTO_TEST_CASE (einsum_tiledarray_iah_hib_abh) {
586
610
einsum_tiledarray_check<3 ,3 ,3 >(
587
- random (7 ,14 ,3 ),
588
- random (3 ,7 ,15 ),
611
+ random<SparsePolicy> (7 ,14 ,3 ),
612
+ random<SparsePolicy> (3 ,7 ,15 ),
589
613
" iah,hib->abh"
590
614
);
615
+ einsum_tiledarray_check<3 ,3 ,3 >(
616
+ sparse_zero (7 ,14 ,3 ),
617
+ sparse_zero (3 ,7 ,15 ),
618
+ " iah,hib->abh"
619
+ );
620
+ }
621
+
622
+ BOOST_AUTO_TEST_CASE (einsum_tiledarray_hai_hibc_habc) {
623
+ einsum_tiledarray_check<3 ,4 ,4 >(
624
+ random<SparsePolicy>(9 ,3 ,11 ),
625
+ random<SparsePolicy>(9 ,11 ,5 ,7 ),
626
+ " hai,hibc->habc"
627
+ );
628
+ einsum_tiledarray_check<3 ,4 ,4 >(
629
+ sparse_zero (9 ,3 ,11 ),
630
+ sparse_zero (9 ,11 ,5 ,7 ),
631
+ " hai,hibc->habc"
632
+ );
591
633
}
592
634
593
635
BOOST_AUTO_TEST_CASE (einsum_tiledarray_hi_hi_h) {
594
636
einsum_tiledarray_check<2 ,2 ,1 >(
595
- random (7 ,14 ),
596
- random (7 ,14 ),
637
+ random<SparsePolicy>(7 ,14 ),
638
+ random<SparsePolicy>(7 ,14 ),
639
+ " hi,hi->h"
640
+ );
641
+ einsum_tiledarray_check<2 ,2 ,1 >(
642
+ sparse_zero (7 ,14 ),
643
+ sparse_zero (7 ,14 ),
597
644
" hi,hi->h"
598
645
);
599
646
}
600
647
601
648
BOOST_AUTO_TEST_CASE (einsum_tiledarray_hji_jih_hj) {
602
649
einsum_tiledarray_check<3 ,3 ,2 >(
603
- random (14 ,7 ,5 ),
604
- random (7 ,5 ,14 ),
650
+ random<SparsePolicy>(14 ,7 ,5 ),
651
+ random<SparsePolicy>(7 ,5 ,14 ),
652
+ " hji,jih->hj"
653
+ );
654
+ einsum_tiledarray_check<3 ,3 ,2 >(
655
+ sparse_zero (14 ,7 ,5 ),
656
+ sparse_zero (7 ,5 ,14 ),
605
657
" hji,jih->hj"
606
658
);
607
659
}
0 commit comments