@@ -460,3 +460,226 @@ def _debug_plot():
460
460
except AssertionError :
461
461
_debug_plot ()
462
462
raise
463
+
464
+ @parameterized .expand (
465
+ [
466
+ (0.1 , 0.2 , (2 , 1 , 2500 )), # both float
467
+ # Per-wall
468
+ (torch .rand (4 ), 0.2 , (2 , 1 , 2500 )),
469
+ (0.1 , torch .rand (4 ), (2 , 1 , 2500 )),
470
+ (torch .rand (4 ), torch .rand (4 ), (2 , 1 , 2500 )),
471
+ # Per-band and per-wall
472
+ (torch .rand (6 , 4 ), 0.2 , (2 , 6 , 2500 )),
473
+ (0.1 , torch .rand (6 , 4 ), (2 , 6 , 2500 )),
474
+ (torch .rand (6 , 4 ), torch .rand (6 , 4 ), (2 , 6 , 2500 )),
475
+ ]
476
+ )
477
+ def test_ray_tracing_output_shape (self , absorption , scattering , expected_shape ):
478
+ room_dim = torch .tensor ([20 , 25 ], dtype = self .dtype )
479
+ mic_array = torch .tensor ([[2 , 2 ], [8 , 8 ]], dtype = self .dtype )
480
+ source = torch .tensor ([7 , 6 ], dtype = self .dtype )
481
+ num_rays = 100
482
+
483
+ hist = F .ray_tracing (
484
+ room = room_dim ,
485
+ source = source ,
486
+ mic_array = mic_array ,
487
+ num_rays = num_rays ,
488
+ absorption = absorption ,
489
+ scattering = scattering ,
490
+ )
491
+
492
+ assert hist .shape == expected_shape
493
+
494
+ def test_ray_tracing_input_errors (self ):
495
+ with self .assertRaisesRegex (ValueError , "room must be a 1D tensor" ):
496
+ F .ray_tracing (
497
+ room = torch .tensor ([[4 , 5 ]]), source = torch .tensor ([0 , 0 ]), mic_array = torch .tensor ([[3 , 4 ]]), num_rays = 10
498
+ )
499
+ with self .assertRaisesRegex (ValueError , "room must be a 1D tensor" ):
500
+ F .ray_tracing (
501
+ room = torch .tensor ([4 , 5 , 4 , 5 ]),
502
+ source = torch .tensor ([0 , 0 ]),
503
+ mic_array = torch .tensor ([[3 , 4 ]]),
504
+ num_rays = 10 ,
505
+ )
506
+ with self .assertRaisesRegex (ValueError , r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor" ):
507
+ F .ray_tracing (
508
+ room = torch .tensor ([4 , 5 ]), source = torch .tensor ([0 , 0 ]), mic_array = torch .tensor ([[[3 , 4 ]]]), num_rays = 10
509
+ )
510
+ with self .assertRaisesRegex (ValueError , "room must be of float32 or float64 dtype" ):
511
+ F .ray_tracing (
512
+ room = torch .tensor ([4 , 5 ]).to (torch .int ),
513
+ source = torch .tensor ([0 , 0 ]),
514
+ mic_array = torch .tensor ([3 , 4 ]),
515
+ num_rays = 10 ,
516
+ )
517
+ with self .assertRaisesRegex (ValueError , "dtype of room, source and mic_array must be the same" ):
518
+ F .ray_tracing (
519
+ room = torch .tensor ([4 , 5 ]).to (torch .float64 ),
520
+ source = torch .tensor ([0 , 0 ]).to (torch .float32 ),
521
+ mic_array = torch .tensor ([3 , 4 ]),
522
+ num_rays = 10 ,
523
+ )
524
+ with self .assertRaisesRegex (ValueError , "Room dimension D must match with source and mic_array" ):
525
+ F .ray_tracing (
526
+ room = torch .tensor ([4 , 5 , 10 ], dtype = torch .float ),
527
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
528
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
529
+ num_rays = 10 ,
530
+ )
531
+ with self .assertRaisesRegex (ValueError , "Room dimension D must match with source and mic_array" ):
532
+ F .ray_tracing (
533
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
534
+ source = torch .tensor ([0 , 0 , 0 ], dtype = torch .float ),
535
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
536
+ num_rays = 10 ,
537
+ )
538
+ with self .assertRaisesRegex (ValueError , "Room dimension D must match with source and mic_array" ):
539
+ F .ray_tracing (
540
+ room = torch .tensor ([4 , 5 , 10 ], dtype = torch .float ),
541
+ source = torch .tensor ([0 , 0 , 0 ], dtype = torch .float ),
542
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
543
+ num_rays = 10 ,
544
+ )
545
+ with self .assertRaisesRegex (ValueError , "time_thres=10 must be at least greater than hist_bin_size=11" ):
546
+ F .ray_tracing (
547
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
548
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
549
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
550
+ num_rays = 10 ,
551
+ time_thres = 10 ,
552
+ hist_bin_size = 11 ,
553
+ )
554
+ with self .assertRaisesRegex (ValueError , "The shape of absorption must be" ):
555
+ F .ray_tracing (
556
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
557
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
558
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
559
+ num_rays = 10 ,
560
+ absorption = torch .rand (5 , dtype = torch .float ),
561
+ )
562
+ with self .assertRaisesRegex (ValueError , "The shape of scattering must be" ):
563
+ F .ray_tracing (
564
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
565
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
566
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
567
+ num_rays = 10 ,
568
+ scattering = torch .rand (5 , 5 , dtype = torch .float ),
569
+ )
570
+ with self .assertRaisesRegex (ValueError , "The shape of absorption must be" ):
571
+ F .ray_tracing (
572
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
573
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
574
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
575
+ num_rays = 10 ,
576
+ absorption = torch .rand (5 , 5 , dtype = torch .float ),
577
+ )
578
+ with self .assertRaisesRegex (ValueError , "The shape of scattering must be" ):
579
+ F .ray_tracing (
580
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
581
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
582
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
583
+ num_rays = 10 ,
584
+ scattering = torch .rand (5 , dtype = torch .float ),
585
+ )
586
+ with self .assertRaisesRegex (
587
+ ValueError , "absorption and scattering must have the same number of bands and walls"
588
+ ):
589
+ F .ray_tracing (
590
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
591
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
592
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
593
+ num_rays = 10 ,
594
+ absorption = torch .rand (6 , 4 , dtype = torch .float ),
595
+ scattering = torch .rand (5 , 4 , dtype = torch .float ),
596
+ )
597
+
598
+ # Make sure passing different shapes for absorption or scattering doesn't raise an error
599
+ # float and tensor
600
+ F .ray_tracing (
601
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
602
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
603
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
604
+ num_rays = 10 ,
605
+ absorption = 0.1 ,
606
+ scattering = torch .rand (5 , 4 , dtype = torch .float ),
607
+ )
608
+ F .ray_tracing (
609
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
610
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
611
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
612
+ num_rays = 10 ,
613
+ absorption = torch .rand (5 , 4 , dtype = torch .float ),
614
+ scattering = 0.1 ,
615
+ )
616
+ # per-wall only and per-band + per-wall
617
+ F .ray_tracing (
618
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
619
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
620
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
621
+ num_rays = 10 ,
622
+ absorption = torch .rand (4 , dtype = torch .float ),
623
+ scattering = torch .rand (6 , 4 , dtype = torch .float ),
624
+ )
625
+ F .ray_tracing (
626
+ room = torch .tensor ([4 , 5 ], dtype = torch .float ),
627
+ source = torch .tensor ([0 , 0 ], dtype = torch .float ),
628
+ mic_array = torch .tensor ([3 , 4 ], dtype = torch .float ),
629
+ num_rays = 10 ,
630
+ absorption = torch .rand (6 , 4 , dtype = torch .float ),
631
+ scattering = torch .rand (4 , dtype = torch .float ),
632
+ )
633
+
634
+ def test_ray_tracing_per_band_per_wall_absorption (self ):
635
+ """Check that when the value of absorption and scattering are the same
636
+ across walls and frequency bands, the output histograms are:
637
+ - all equal across frequency bands
638
+ - equal to simply passing a float value instead of a (num_bands, D) or
639
+ (D,) tensor.
640
+ """
641
+
642
+ room_dim = torch .tensor ([20 , 25 ], dtype = self .dtype )
643
+ mic_array = torch .tensor ([[2 , 2 ], [8 , 8 ]], dtype = self .dtype )
644
+ source = torch .tensor ([7 , 6 ], dtype = self .dtype )
645
+ num_rays = 1_000
646
+ ABS , SCAT = 0.1 , 0.2
647
+
648
+ absorption = torch .full (fill_value = ABS , size = (6 , 4 ), dtype = self .dtype )
649
+ scattering = torch .full (fill_value = SCAT , size = (6 , 4 ), dtype = self .dtype )
650
+ hist_per_band_per_wall = F .ray_tracing (
651
+ room = room_dim ,
652
+ source = source ,
653
+ mic_array = mic_array ,
654
+ num_rays = num_rays ,
655
+ absorption = absorption ,
656
+ scattering = scattering ,
657
+ )
658
+ absorption = torch .full (fill_value = ABS , size = (4 ,), dtype = self .dtype )
659
+ scattering = torch .full (fill_value = SCAT , size = (4 ,), dtype = self .dtype )
660
+ hist_per_wall = F .ray_tracing (
661
+ room = room_dim ,
662
+ source = source ,
663
+ mic_array = mic_array ,
664
+ num_rays = num_rays ,
665
+ absorption = absorption ,
666
+ scattering = scattering ,
667
+ )
668
+
669
+ absorption = ABS
670
+ scattering = SCAT
671
+ hist_single = F .ray_tracing (
672
+ room = room_dim ,
673
+ source = source ,
674
+ mic_array = mic_array ,
675
+ num_rays = num_rays ,
676
+ absorption = absorption ,
677
+ scattering = scattering ,
678
+ )
679
+ assert hist_per_band_per_wall .shape == (2 , 6 , 2500 )
680
+ assert hist_per_wall .shape == (2 , 1 , 2500 )
681
+ assert hist_single .shape == (2 , 1 , 2500 )
682
+ torch .testing .assert_close (hist_single , hist_per_wall )
683
+
684
+ hist_single = hist_single .expand (2 , 6 , 2500 )
685
+ torch .testing .assert_close (hist_single , hist_per_band_per_wall )
0 commit comments