@@ -412,6 +412,255 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet
412
412
413
413
self .assertEqual (torch_out , torch .tensor (np_out ))
414
414
415
+ @parameterized .expand (
416
+ [
417
+ # both float
418
+ (0.1 , 0.2 , (2 , 1 , 2500 )),
419
+ # Per-wall
420
+ ((6 ,), 0.2 , (2 , 1 , 2500 )),
421
+ (0.1 , (6 ,), (2 , 1 , 2500 )),
422
+ ((6 ,), (6 ,), (2 , 1 , 2500 )),
423
+ # Per-band and per-wall
424
+ ((3 , 6 ), 0.2 , (2 , 3 , 2500 )),
425
+ (0.1 , (5 , 6 ), (2 , 5 , 2500 )),
426
+ ((7 , 6 ), (7 , 6 ), (2 , 7 , 2500 )),
427
+ ]
428
+ )
429
+ def test_ray_tracing_output_shape (self , abs_ , scat_ , expected_shape ):
430
+ if isinstance (abs_ , float ):
431
+ absorption = abs_
432
+ else :
433
+ absorption = torch .rand (abs_ , dtype = self .dtype )
434
+ if isinstance (scat_ , float ):
435
+ scattering = scat_
436
+ else :
437
+ scattering = torch .rand (scat_ , dtype = self .dtype )
438
+
439
+ room_dim = torch .tensor ([3 , 4 , 5 ], dtype = self .dtype )
440
+ mic_array = torch .tensor ([[0 , 0 , 0 ], [1 , 1 , 1 ]], dtype = self .dtype )
441
+ source = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
442
+ num_rays = 100
443
+
444
+ hist = F .ray_tracing (
445
+ room = room_dim ,
446
+ source = source ,
447
+ mic_array = mic_array ,
448
+ num_rays = num_rays ,
449
+ absorption = absorption ,
450
+ scattering = scattering ,
451
+ )
452
+ assert hist .shape == expected_shape
453
+
454
+ def test_ray_tracing_input_errors (self ):
455
+ room = torch .tensor ([3.0 , 4.0 , 5.0 ], dtype = self .dtype )
456
+ source = torch .tensor ([0.0 , 0.0 , 0.0 ], dtype = self .dtype )
457
+ mic = torch .tensor ([[1.0 , 2.0 , 3.0 ]], dtype = self .dtype )
458
+
459
+ # baseline. This should not raise
460
+ _ = F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 )
461
+
462
+ # invlaid room shape
463
+ for invalid in ([[4 , 5 ]], [4 , 5 , 4 , 5 ]):
464
+ invalid = torch .tensor (invalid , dtype = self .dtype )
465
+ with self .assertRaises (ValueError ) as cm :
466
+ F .ray_tracing (room = invalid , source = source , mic_array = mic , num_rays = 10 )
467
+
468
+ error = str (cm .exception )
469
+ self .assertIn ("`room` must be a 1D Tensor with 3 elements." , error )
470
+ self .assertIn (str (invalid .shape ), error )
471
+
472
+ # invalid microphone shape
473
+ invalid = torch .tensor ([[[3 , 4 ]]], dtype = self .dtype )
474
+ with self .assertRaises (ValueError ) as cm :
475
+ F .ray_tracing (room = room , source = source , mic_array = invalid , num_rays = 10 )
476
+
477
+ error = str (cm .exception )
478
+ self .assertIn ("`mic_array` must be a 2D Tensor with shape (num_channels, 3)." , error )
479
+ self .assertIn (str (invalid .shape ), error )
480
+
481
+ # incompatible dtypes
482
+ with self .assertRaises (ValueError ) as cm :
483
+ F .ray_tracing (
484
+ room = room .to (torch .float64 ),
485
+ source = source .to (torch .float32 ),
486
+ mic_array = mic ,
487
+ num_rays = 10 ,
488
+ )
489
+ error = str (cm .exception )
490
+ self .assertIn ("dtype of `room`, `source` and `mic_array` must match." , error )
491
+ self .assertIn ("`room` (torch.float64)" , error )
492
+ self .assertIn ("`source` (torch.float32)" , error )
493
+ self .assertIn ("`mic_array` (torch.float32)" , error )
494
+
495
+ # invalid time configuration
496
+ with self .assertRaises (ValueError ) as cm :
497
+ F .ray_tracing (
498
+ room = room ,
499
+ source = source ,
500
+ mic_array = mic ,
501
+ num_rays = 10 ,
502
+ time_thres = 10 ,
503
+ hist_bin_size = 11 ,
504
+ )
505
+ error = str (cm .exception )
506
+ self .assertIn ("`time_thres` must be greater than `hist_bin_size`." , error )
507
+ self .assertIn ("hist_bin_size=11" , error )
508
+ self .assertIn ("time_thres=10" , error )
509
+
510
+ # invalid absorption shape 1D
511
+ invalid_abs = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
512
+ with self .assertRaises (ValueError ) as cm :
513
+ F .ray_tracing (
514
+ room = room ,
515
+ source = source ,
516
+ mic_array = mic ,
517
+ num_rays = 10 ,
518
+ absorption = invalid_abs ,
519
+ )
520
+ error = str (cm .exception )
521
+ self .assertIn ("The shape of `absorption` must be (6,) when" , error )
522
+ self .assertIn (str (invalid_abs .shape ), error )
523
+
524
+ # invalid absorption shape 2D
525
+ invalid_abs = torch .tensor ([[1 , 2 , 3 ]], dtype = self .dtype )
526
+ with self .assertRaises (ValueError ) as cm :
527
+ F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 , absorption = invalid_abs )
528
+ error = str (cm .exception )
529
+ self .assertIn ("The shape of `absorption` must be (NUM_BANDS, 6) when" , error )
530
+ self .assertIn (str (invalid_abs .shape ), error )
531
+
532
+ # invalid scattering shape 1D
533
+ invalid_scat = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
534
+ with self .assertRaises (ValueError ) as cm :
535
+ F .ray_tracing (
536
+ room = room ,
537
+ source = source ,
538
+ mic_array = mic ,
539
+ num_rays = 10 ,
540
+ scattering = invalid_scat ,
541
+ )
542
+ error = str (cm .exception )
543
+ self .assertIn ("The shape of `scattering` must be (6,) when" , error )
544
+ self .assertIn (str (invalid_scat .shape ), error )
545
+
546
+ # invalid scattering shape 2D
547
+ invalid_scat = torch .tensor ([[1 , 2 , 3 ]], dtype = self .dtype )
548
+ with self .assertRaises (ValueError ) as cm :
549
+ F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 , scattering = invalid_scat )
550
+ error = str (cm .exception )
551
+ self .assertIn ("The shape of `scattering` must be (NUM_BANDS, 6) when" , error )
552
+ self .assertIn (str (invalid_scat .shape ), error )
553
+
554
+ # TODO: Invalid absorption/scattering value
555
+
556
+ # incompatible scattering and absorption
557
+ abs_ = torch .zeros ((7 , 6 ), dtype = self .dtype )
558
+ scat = torch .zeros ((5 , 6 ), dtype = self .dtype )
559
+ with self .assertRaises (ValueError ) as cm :
560
+ F .ray_tracing (
561
+ room = room ,
562
+ source = source ,
563
+ mic_array = mic ,
564
+ num_rays = 10 ,
565
+ absorption = abs_ ,
566
+ scattering = scat ,
567
+ )
568
+ error = str (cm .exception )
569
+ self .assertIn (
570
+ "`absorption` and `scattering` must be broadcastable to the same number of bands and walls" , error
571
+ )
572
+ self .assertIn (f"absorption={ abs_ .shape } " , error )
573
+ self .assertIn (f"scattering={ scat .shape } " , error )
574
+
575
+ # Make sure passing different shapes for absorption or scattering doesn't raise an error
576
+ # float and tensor
577
+ F .ray_tracing (
578
+ room = room ,
579
+ source = source ,
580
+ mic_array = mic ,
581
+ num_rays = 10 ,
582
+ absorption = 0.1 ,
583
+ scattering = torch .randn ((5 , 6 ), dtype = self .dtype ),
584
+ )
585
+ F .ray_tracing (
586
+ room = room ,
587
+ source = source ,
588
+ mic_array = mic ,
589
+ num_rays = 10 ,
590
+ absorption = torch .randn ((7 , 6 ), dtype = self .dtype ),
591
+ scattering = 0.1 ,
592
+ )
593
+ # per-wall only and per-band + per-wall
594
+ F .ray_tracing (
595
+ room = room ,
596
+ source = source ,
597
+ mic_array = mic ,
598
+ num_rays = 10 ,
599
+ absorption = torch .rand (6 , dtype = self .dtype ),
600
+ scattering = torch .rand (7 , 6 , dtype = self .dtype ),
601
+ )
602
+ F .ray_tracing (
603
+ room = room ,
604
+ source = source ,
605
+ mic_array = mic ,
606
+ num_rays = 10 ,
607
+ absorption = torch .rand (7 , 6 , dtype = self .dtype ),
608
+ scattering = torch .rand (6 , dtype = self .dtype ),
609
+ )
610
+
611
+ def test_ray_tracing_per_band_per_wall_absorption (self ):
612
+ """Check that when the value of absorption and scattering are the same
613
+ across walls and frequency bands, the output histograms are:
614
+ - all equal across frequency bands
615
+ - equal to simply passing a float value instead of a (num_bands, D) or
616
+ (D,) tensor.
617
+ """
618
+
619
+ room_dim = torch .tensor ([20 , 25 , 5 ], dtype = self .dtype )
620
+ mic_array = torch .tensor ([[2 , 2 , 0 ], [8 , 8 , 0 ]], dtype = self .dtype )
621
+ source = torch .tensor ([7 , 6 , 0 ], dtype = self .dtype )
622
+ num_rays = 1_000
623
+ ABS , SCAT = 0.1 , 0.2
624
+
625
+ absorption = torch .full (fill_value = ABS , size = (7 , 6 ), dtype = self .dtype )
626
+ scattering = torch .full (fill_value = SCAT , size = (7 , 6 ), dtype = self .dtype )
627
+ hist_per_band_per_wall = F .ray_tracing (
628
+ room = room_dim ,
629
+ source = source ,
630
+ mic_array = mic_array ,
631
+ num_rays = num_rays ,
632
+ absorption = absorption ,
633
+ scattering = scattering ,
634
+ )
635
+ absorption = torch .full (fill_value = ABS , size = (6 ,), dtype = self .dtype )
636
+ scattering = torch .full (fill_value = SCAT , size = (6 ,), dtype = self .dtype )
637
+ hist_per_wall = F .ray_tracing (
638
+ room = room_dim ,
639
+ source = source ,
640
+ mic_array = mic_array ,
641
+ num_rays = num_rays ,
642
+ absorption = absorption ,
643
+ scattering = scattering ,
644
+ )
645
+
646
+ absorption = ABS
647
+ scattering = SCAT
648
+ hist_single = F .ray_tracing (
649
+ room = room_dim ,
650
+ source = source ,
651
+ mic_array = mic_array ,
652
+ num_rays = num_rays ,
653
+ absorption = absorption ,
654
+ scattering = scattering ,
655
+ )
656
+ self .assertEqual (hist_per_band_per_wall .shape , (2 , 7 , 2500 ))
657
+ self .assertEqual (hist_per_wall .shape , (2 , 1 , 2500 ))
658
+ self .assertEqual (hist_single .shape , (2 , 1 , 2500 ))
659
+ torch .testing .assert_close (hist_single , hist_per_wall )
660
+
661
+ hist_single = hist_single .expand (hist_per_band_per_wall .shape )
662
+ torch .testing .assert_close (hist_single , hist_per_band_per_wall )
663
+
415
664
416
665
class Functional64OnlyTestImpl (TestBaseMixin ):
417
666
@nested_params (
0 commit comments