@@ -412,6 +412,269 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet
412
412
413
413
self .assertEqual (torch_out , torch .tensor (np_out ))
414
414
415
+ @parameterized .expand (
416
+ [
417
+ # both float
418
+ (0.1 , 0.2 , (2 , 1 , 2500 )),
419
+ # Per-wall
420
+ ((6 ,), 0.2 , (2 , 1 , 2500 )),
421
+ (0.1 , (6 ,), (2 , 1 , 2500 )),
422
+ ((6 ,), (6 ,), (2 , 1 , 2500 )),
423
+ # Per-band and per-wall
424
+ ((3 , 6 ), 0.2 , (2 , 3 , 2500 )),
425
+ (0.1 , (5 , 6 ), (2 , 5 , 2500 )),
426
+ ((7 , 6 ), (7 , 6 ), (2 , 7 , 2500 )),
427
+ ]
428
+ )
429
+ def test_ray_tracing_output_shape (self , abs_ , scat_ , expected_shape ):
430
+ if isinstance (abs_ , float ):
431
+ absorption = abs_
432
+ else :
433
+ absorption = torch .rand (abs_ , dtype = self .dtype )
434
+ if isinstance (scat_ , float ):
435
+ scattering = scat_
436
+ else :
437
+ scattering = torch .rand (scat_ , dtype = self .dtype )
438
+
439
+ room_dim = torch .tensor ([3 , 4 , 5 ], dtype = self .dtype )
440
+ mic_array = torch .tensor ([[0 , 0 , 0 ], [1 , 1 , 1 ]], dtype = self .dtype )
441
+ source = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
442
+ num_rays = 100
443
+
444
+ hist = F .ray_tracing (
445
+ room = room_dim ,
446
+ source = source ,
447
+ mic_array = mic_array ,
448
+ num_rays = num_rays ,
449
+ absorption = absorption ,
450
+ scattering = scattering ,
451
+ )
452
+ assert hist .shape == expected_shape
453
+
454
+ def test_ray_tracing_input_errors (self ):
455
+ room = torch .tensor ([3.0 , 4.0 , 5.0 ], dtype = self .dtype )
456
+ source = torch .tensor ([0.0 , 0.0 , 0.0 ], dtype = self .dtype )
457
+ mic = torch .tensor ([[1.0 , 2.0 , 3.0 ]], dtype = self .dtype )
458
+
459
+ # baseline. This should not raise
460
+ _ = F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 )
461
+
462
+ # invlaid room shape
463
+ for invalid in ([[4 , 5 ]], [4 , 5 , 4 , 5 ]):
464
+ invalid = torch .tensor (invalid , dtype = self .dtype )
465
+ with self .assertRaises (ValueError ) as cm :
466
+ F .ray_tracing (room = invalid , source = source , mic_array = mic , num_rays = 10 )
467
+
468
+ error = str (cm .exception )
469
+ self .assertIn ("`room` must be a 1D Tensor with 3 elements." , error )
470
+ self .assertIn (str (invalid .shape ), error )
471
+
472
+ # invalid microphone shape
473
+ invalid = torch .tensor ([[[3 , 4 ]]], dtype = self .dtype )
474
+ with self .assertRaises (ValueError ) as cm :
475
+ F .ray_tracing (room = room , source = source , mic_array = invalid , num_rays = 10 )
476
+
477
+ error = str (cm .exception )
478
+ self .assertIn ("`mic_array` must be a 2D Tensor with shape (num_channels, 3)." , error )
479
+ self .assertIn (str (invalid .shape ), error )
480
+
481
+ # incompatible dtypes
482
+ with self .assertRaises (ValueError ) as cm :
483
+ F .ray_tracing (
484
+ room = room .to (torch .float64 ),
485
+ source = source .to (torch .float32 ),
486
+ mic_array = mic .to (torch .float32 ),
487
+ num_rays = 10 ,
488
+ )
489
+ error = str (cm .exception )
490
+ self .assertIn ("dtype of `room`, `source` and `mic_array` must match." , error )
491
+ self .assertIn ("`room` (torch.float64)" , error )
492
+ self .assertIn ("`source` (torch.float32)" , error )
493
+ self .assertIn ("`mic_array` (torch.float32)" , error )
494
+
495
+ # invalid time configuration
496
+ with self .assertRaises (ValueError ) as cm :
497
+ F .ray_tracing (
498
+ room = room ,
499
+ source = source ,
500
+ mic_array = mic ,
501
+ num_rays = 10 ,
502
+ time_thres = 10 ,
503
+ hist_bin_size = 11 ,
504
+ )
505
+ error = str (cm .exception )
506
+ self .assertIn ("`time_thres` must be greater than `hist_bin_size`." , error )
507
+ self .assertIn ("hist_bin_size=11" , error )
508
+ self .assertIn ("time_thres=10" , error )
509
+
510
+ # invalid absorption shape 1D
511
+ invalid_abs = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
512
+ with self .assertRaises (ValueError ) as cm :
513
+ F .ray_tracing (
514
+ room = room ,
515
+ source = source ,
516
+ mic_array = mic ,
517
+ num_rays = 10 ,
518
+ absorption = invalid_abs ,
519
+ )
520
+ error = str (cm .exception )
521
+ self .assertIn ("The shape of `absorption` must be (6,) when" , error )
522
+ self .assertIn (str (invalid_abs .shape ), error )
523
+
524
+ # invalid absorption shape 2D
525
+ invalid_abs = torch .tensor ([[1 , 2 , 3 ]], dtype = self .dtype )
526
+ with self .assertRaises (ValueError ) as cm :
527
+ F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 , absorption = invalid_abs )
528
+ error = str (cm .exception )
529
+ self .assertIn ("The shape of `absorption` must be (NUM_BANDS, 6) when" , error )
530
+ self .assertIn (str (invalid_abs .shape ), error )
531
+
532
+ # invalid scattering shape 1D
533
+ invalid_scat = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
534
+ with self .assertRaises (ValueError ) as cm :
535
+ F .ray_tracing (
536
+ room = room ,
537
+ source = source ,
538
+ mic_array = mic ,
539
+ num_rays = 10 ,
540
+ scattering = invalid_scat ,
541
+ )
542
+ error = str (cm .exception )
543
+ self .assertIn ("The shape of `scattering` must be (6,) when" , error )
544
+ self .assertIn (str (invalid_scat .shape ), error )
545
+
546
+ # invalid scattering shape 2D
547
+ invalid_scat = torch .tensor ([[1 , 2 , 3 ]], dtype = self .dtype )
548
+ with self .assertRaises (ValueError ) as cm :
549
+ F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 , scattering = invalid_scat )
550
+ error = str (cm .exception )
551
+ self .assertIn ("The shape of `scattering` must be (NUM_BANDS, 6) when" , error )
552
+ self .assertIn (str (invalid_scat .shape ), error )
553
+
554
+ # Invalid absorption value
555
+ for invalid_val in [- 1. , torch .tensor ([i - 1. for i in range (6 )])]:
556
+ with self .assertRaises (ValueError ) as cm :
557
+ F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 , absorption = invalid_val )
558
+
559
+ error = str (cm .exception )
560
+ self .assertIn ("`absorption` must be non-negative`" )
561
+
562
+ # Invalid scattering value
563
+ for invalid_val in [- 1. , torch .tensor ([i - 1. for i in range (6 )])]:
564
+ with self .assertRaises (ValueError ) as cm :
565
+ F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 , scattering = invalid_val )
566
+
567
+ error = str (cm .exception )
568
+ self .assertIn ("`scattering` must be non-negative`" )
569
+
570
+ # incompatible scattering and absorption
571
+ abs_ = torch .zeros ((7 , 6 ), dtype = self .dtype )
572
+ scat = torch .zeros ((5 , 6 ), dtype = self .dtype )
573
+ with self .assertRaises (ValueError ) as cm :
574
+ F .ray_tracing (
575
+ room = room ,
576
+ source = source ,
577
+ mic_array = mic ,
578
+ num_rays = 10 ,
579
+ absorption = abs_ ,
580
+ scattering = scat ,
581
+ )
582
+ error = str (cm .exception )
583
+ self .assertIn (
584
+ "`absorption` and `scattering` must be broadcastable to the same number of bands and walls" , error
585
+ )
586
+ self .assertIn (f"absorption={ abs_ .shape } " , error )
587
+ self .assertIn (f"scattering={ scat .shape } " , error )
588
+
589
+ # Make sure passing different shapes for absorption or scattering doesn't raise an error
590
+ # float and tensor
591
+ F .ray_tracing (
592
+ room = room ,
593
+ source = source ,
594
+ mic_array = mic ,
595
+ num_rays = 10 ,
596
+ absorption = 0.1 ,
597
+ scattering = torch .rand ((5 , 6 ), dtype = self .dtype ),
598
+ )
599
+ F .ray_tracing (
600
+ room = room ,
601
+ source = source ,
602
+ mic_array = mic ,
603
+ num_rays = 10 ,
604
+ absorption = torch .rand ((7 , 6 ), dtype = self .dtype ),
605
+ scattering = 0.1 ,
606
+ )
607
+ # per-wall only and per-band + per-wall
608
+ F .ray_tracing (
609
+ room = room ,
610
+ source = source ,
611
+ mic_array = mic ,
612
+ num_rays = 10 ,
613
+ absorption = torch .rand (6 , dtype = self .dtype ),
614
+ scattering = torch .rand (7 , 6 , dtype = self .dtype ),
615
+ )
616
+ F .ray_tracing (
617
+ room = room ,
618
+ source = source ,
619
+ mic_array = mic ,
620
+ num_rays = 10 ,
621
+ absorption = torch .rand (7 , 6 , dtype = self .dtype ),
622
+ scattering = torch .rand (6 , dtype = self .dtype ),
623
+ )
624
+
625
+ def test_ray_tracing_per_band_per_wall_absorption (self ):
626
+ """Check that when the value of absorption and scattering are the same
627
+ across walls and frequency bands, the output histograms are:
628
+ - all equal across frequency bands
629
+ - equal to simply passing a float value instead of a (num_bands, D) or
630
+ (D,) tensor.
631
+ """
632
+
633
+ room_dim = torch .tensor ([20 , 25 , 5 ], dtype = self .dtype )
634
+ mic_array = torch .tensor ([[2 , 2 , 0 ], [8 , 8 , 0 ]], dtype = self .dtype )
635
+ source = torch .tensor ([7 , 6 , 0 ], dtype = self .dtype )
636
+ num_rays = 1_000
637
+ ABS , SCAT = 0.1 , 0.2
638
+
639
+ absorption = torch .full (fill_value = ABS , size = (7 , 6 ), dtype = self .dtype )
640
+ scattering = torch .full (fill_value = SCAT , size = (7 , 6 ), dtype = self .dtype )
641
+ hist_per_band_per_wall = F .ray_tracing (
642
+ room = room_dim ,
643
+ source = source ,
644
+ mic_array = mic_array ,
645
+ num_rays = num_rays ,
646
+ absorption = absorption ,
647
+ scattering = scattering ,
648
+ )
649
+ absorption = torch .full (fill_value = ABS , size = (6 ,), dtype = self .dtype )
650
+ scattering = torch .full (fill_value = SCAT , size = (6 ,), dtype = self .dtype )
651
+ hist_per_wall = F .ray_tracing (
652
+ room = room_dim ,
653
+ source = source ,
654
+ mic_array = mic_array ,
655
+ num_rays = num_rays ,
656
+ absorption = absorption ,
657
+ scattering = scattering ,
658
+ )
659
+
660
+ absorption = ABS
661
+ scattering = SCAT
662
+ hist_single = F .ray_tracing (
663
+ room = room_dim ,
664
+ source = source ,
665
+ mic_array = mic_array ,
666
+ num_rays = num_rays ,
667
+ absorption = absorption ,
668
+ scattering = scattering ,
669
+ )
670
+ self .assertEqual (hist_per_band_per_wall .shape , (2 , 7 , 2500 ))
671
+ self .assertEqual (hist_per_wall .shape , (2 , 1 , 2500 ))
672
+ self .assertEqual (hist_single .shape , (2 , 1 , 2500 ))
673
+ torch .testing .assert_close (hist_single , hist_per_wall )
674
+
675
+ hist_single = hist_single .expand (hist_per_band_per_wall .shape )
676
+ torch .testing .assert_close (hist_single , hist_per_band_per_wall )
677
+
415
678
416
679
class Functional64OnlyTestImpl (TestBaseMixin ):
417
680
@nested_params (
0 commit comments