@@ -412,6 +412,258 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet
412
412
413
413
self .assertEqual (torch_out , torch .tensor (np_out ))
414
414
415
+ @parameterized .expand (
416
+ [
417
+ # both float
418
+ (0.1 , 0.2 , (2 , 1 , 2500 )),
419
+ # Per-wall
420
+ ((6 , ), 0.2 , (2 , 1 , 2500 )),
421
+ (0.1 , (6 , ), (2 , 1 , 2500 )),
422
+ ((6 , ), (6 , ), (2 , 1 , 2500 )),
423
+ # Per-band and per-wall
424
+ ((3 , 6 ), 0.2 , (2 , 3 , 2500 )),
425
+ (0.1 , (5 , 6 ), (2 , 5 , 2500 )),
426
+ ((7 , 6 ), (7 , 6 ), (2 , 7 , 2500 )),
427
+ ]
428
+ )
429
+ def test_ray_tracing_output_shape (self , abs_ , scat_ , expected_shape ):
430
+ if isinstance (abs_ , float ):
431
+ absorption = abs_
432
+ else :
433
+ absorption = torch .rand (abs_ , dtype = self .dtype )
434
+ if isinstance (scat_ , float ):
435
+ scattering = scat_
436
+ else :
437
+ scattering = torch .rand (scat_ , dtype = self .dtype )
438
+
439
+ room_dim = torch .tensor ([3 , 4 , 5 ], dtype = self .dtype )
440
+ mic_array = torch .tensor ([[0 , 0 , 0 ], [1 , 1 , 1 ]], dtype = self .dtype )
441
+ source = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
442
+ num_rays = 100
443
+
444
+ hist = F .ray_tracing (
445
+ room = room_dim ,
446
+ source = source ,
447
+ mic_array = mic_array ,
448
+ num_rays = num_rays ,
449
+ absorption = absorption ,
450
+ scattering = scattering ,
451
+ )
452
+ assert hist .shape == expected_shape
453
+
454
+ def test_ray_tracing_input_errors (self ):
455
+ if self .dtype == torch .float64 :
456
+ import unittest
457
+
458
+ raise unittest .SkipTest ("float64 is not supported yet" )
459
+
460
+ room = torch .tensor ([3. , 4. , 5. ], dtype = self .dtype )
461
+ source = torch .tensor ([0. , 0. , 0. ], dtype = self .dtype )
462
+ mic = torch .tensor ([[1. , 2. , 3. ]], dtype = self .dtype )
463
+
464
+ _ = F .ray_tracing (room = room , source = source , mic_array = mic , num_rays = 10 )
465
+
466
+ for invalid in ([[4 , 5 ]], [4 , 5 , 4 , 5 ]):
467
+ invalid = torch .tensor (invalid , dtype = self .dtype )
468
+ with self .assertRaises (ValueError ) as cm :
469
+ F .ray_tracing (room = invalid , source = source , mic_array = mic , num_rays = 10 )
470
+
471
+ error = str (cm .exception )
472
+ self .assertIn ("`room` must be a 1D Tensor with 3 elements." , error )
473
+ self .assertIn (str (invalid .shape ), error )
474
+
475
+ invalid = torch .tensor ([[[3 , 4 ]]], dtype = self .dtype )
476
+ with self .assertRaises (ValueError ) as cm :
477
+ F .ray_tracing (room = room , source = source , mic_array = invalid , num_rays = 10 )
478
+
479
+ error = str (cm .exception )
480
+ self .assertIn ("`mic_array` must be a 2D Tensor with shape (num_channels, 3)." , error )
481
+ self .assertIn (str (invalid .shape ), error )
482
+
483
+ with self .assertRaises (ValueError ) as cm :
484
+ F .ray_tracing (
485
+ room = room .to (torch .float64 ),
486
+ source = source .to (torch .float32 ),
487
+ mic_array = mic ,
488
+ num_rays = 10 ,
489
+ )
490
+ error = str (cm .exception )
491
+ self .assertIn ("dtype of `room`, `source` and `mic_array` must match." , error )
492
+ self .assertIn ("`room` (torch.float64)" , error )
493
+ self .assertIn ("`source` (torch.float32)" , error )
494
+ self .assertIn ("`mic_array` (torch.float32)" , error )
495
+
496
+ with self .assertRaises (ValueError ) as cm :
497
+ F .ray_tracing (
498
+ room = room ,
499
+ source = source ,
500
+ mic_array = mic ,
501
+ num_rays = 10 ,
502
+ time_thres = 10 ,
503
+ hist_bin_size = 11 ,
504
+ )
505
+ error = str (cm .exception )
506
+ self .assertIn ("`time_thres` must be greater than `hist_bin_size`." , error )
507
+ self .assertIn ("hist_bin_size=11" , error )
508
+ self .assertIn ("time_thres=10" , error )
509
+
510
+ invalid_abs = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
511
+ with self .assertRaises (ValueError ) as cm :
512
+ F .ray_tracing (
513
+ room = room ,
514
+ source = source ,
515
+ mic_array = mic ,
516
+ num_rays = 10 ,
517
+ absorption = invalid_abs ,
518
+ )
519
+ error = str (cm .exception )
520
+ self .assertIn ("The shape of `absorption` must be (6,) when" , error )
521
+ self .assertIn (str (invalid_abs .shape ), error )
522
+
523
+ invalid_scat = torch .tensor ([1 , 2 , 3 ], dtype = self .dtype )
524
+ with self .assertRaises (ValueError ) as cm :
525
+ F .ray_tracing (
526
+ room = room ,
527
+ source = source ,
528
+ mic_array = mic ,
529
+ num_rays = 10 ,
530
+ scattering = invalid_scat ,
531
+ )
532
+ error = str (cm .exception )
533
+ self .assertIn ("The shape of `scattering` must be (6,) when" , error )
534
+ self .assertIn (str (invalid_scat .shape ), error )
535
+
536
+ invalid_abs = torch .tensor ([[1 , 2 , 3 ]], dtype = self .dtype )
537
+ with self .assertRaises (ValueError ) as cm :
538
+ F .ray_tracing (
539
+ room = room ,
540
+ source = source ,
541
+ mic_array = mic ,
542
+ num_rays = 10 ,
543
+ absorption = invalid_abs
544
+ )
545
+ error = str (cm .exception )
546
+ self .assertIn ("The shape of `absorption` must be (NUM_BANDS, 6) when" , error )
547
+ self .assertIn (str (invalid_abs .shape ), error )
548
+
549
+ invalid_scat = torch .tensor ([[1 , 2 , 3 ]], dtype = self .dtype )
550
+ with self .assertRaises (ValueError ) as cm :
551
+ F .ray_tracing (
552
+ room = room ,
553
+ source = source ,
554
+ mic_array = mic ,
555
+ num_rays = 10 ,
556
+ scattering = invalid_scat
557
+ )
558
+ error = str (cm .exception )
559
+ self .assertIn ("The shape of `scattering` must be (NUM_BANDS, 6) when" , error )
560
+ self .assertIn (str (invalid_scat .shape ), error )
561
+
562
+ abs_ = torch .randn ((7 , 6 ), dtype = self .dtype )
563
+ scat = torch .randn ((5 , 6 ), dtype = self .dtype )
564
+ with self .assertRaises (ValueError ) as cm :
565
+ F .ray_tracing (
566
+ room = room ,
567
+ source = source ,
568
+ mic_array = mic ,
569
+ num_rays = 10 ,
570
+ absorption = abs_ ,
571
+ scattering = scat ,
572
+ )
573
+ error = str (cm .exception )
574
+ self .assertIn ("`absorption` and `scattering` must be broadcastable to the same number of bands and walls" , error )
575
+ self .assertIn (f"absorption={ abs_ .shape } " , error )
576
+ self .assertIn (f"scattering={ scat .shape } " , error )
577
+
578
+ # Make sure passing different shapes for absorption or scattering doesn't raise an error
579
+ # float and tensor
580
+ F .ray_tracing (
581
+ room = room ,
582
+ source = source ,
583
+ mic_array = mic ,
584
+ num_rays = 10 ,
585
+ absorption = 0.1 ,
586
+ scattering = torch .randn ((5 , 6 ), dtype = self .dtype ),
587
+ )
588
+ F .ray_tracing (
589
+ room = room ,
590
+ source = source ,
591
+ mic_array = mic ,
592
+ num_rays = 10 ,
593
+ absorption = torch .randn ((7 , 6 ), dtype = self .dtype ),
594
+ scattering = 0.1 ,
595
+ )
596
+ # per-wall only and per-band + per-wall
597
+ F .ray_tracing (
598
+ room = room ,
599
+ source = source ,
600
+ mic_array = mic ,
601
+ num_rays = 10 ,
602
+ absorption = torch .rand (6 , dtype = self .dtype ),
603
+ scattering = torch .rand (7 , 6 , dtype = self .dtype ),
604
+ )
605
+ F .ray_tracing (
606
+ room = room ,
607
+ source = source ,
608
+ mic_array = mic ,
609
+ num_rays = 10 ,
610
+ absorption = torch .rand (7 , 6 , dtype = self .dtype ),
611
+ scattering = torch .rand (6 , dtype = self .dtype ),
612
+ )
613
+
614
+ def test_ray_tracing_per_band_per_wall_absorption (self ):
615
+ """Check that when the value of absorption and scattering are the same
616
+ across walls and frequency bands, the output histograms are:
617
+ - all equal across frequency bands
618
+ - equal to simply passing a float value instead of a (num_bands, D) or
619
+ (D,) tensor.
620
+ """
621
+
622
+ room_dim = torch .tensor ([20 , 25 , 5 ], dtype = self .dtype )
623
+ mic_array = torch .tensor ([[2 , 2 , 0 ], [8 , 8 , 0 ]], dtype = self .dtype )
624
+ source = torch .tensor ([7 , 6 , 0 ], dtype = self .dtype )
625
+ num_rays = 1_000
626
+ ABS , SCAT = 0.1 , 0.2
627
+
628
+ absorption = torch .full (fill_value = ABS , size = (7 , 6 ), dtype = self .dtype )
629
+ scattering = torch .full (fill_value = SCAT , size = (7 , 6 ), dtype = self .dtype )
630
+ hist_per_band_per_wall = F .ray_tracing (
631
+ room = room_dim ,
632
+ source = source ,
633
+ mic_array = mic_array ,
634
+ num_rays = num_rays ,
635
+ absorption = absorption ,
636
+ scattering = scattering ,
637
+ )
638
+ absorption = torch .full (fill_value = ABS , size = (6 ,), dtype = self .dtype )
639
+ scattering = torch .full (fill_value = SCAT , size = (6 ,), dtype = self .dtype )
640
+ hist_per_wall = F .ray_tracing (
641
+ room = room_dim ,
642
+ source = source ,
643
+ mic_array = mic_array ,
644
+ num_rays = num_rays ,
645
+ absorption = absorption ,
646
+ scattering = scattering ,
647
+ )
648
+
649
+ absorption = ABS
650
+ scattering = SCAT
651
+ hist_single = F .ray_tracing (
652
+ room = room_dim ,
653
+ source = source ,
654
+ mic_array = mic_array ,
655
+ num_rays = num_rays ,
656
+ absorption = absorption ,
657
+ scattering = scattering ,
658
+ )
659
+ self .assertEqual (hist_per_band_per_wall .shape , (2 , 7 , 2500 ))
660
+ self .assertEqual (hist_per_wall .shape , (2 , 1 , 2500 ))
661
+ self .assertEqual (hist_single .shape , (2 , 1 , 2500 ))
662
+ torch .testing .assert_close (hist_single , hist_per_wall )
663
+
664
+ hist_single = hist_single .expand (hist_per_band_per_wall .shape )
665
+ torch .testing .assert_close (hist_single , hist_per_band_per_wall )
666
+
415
667
416
668
class Functional64OnlyTestImpl (TestBaseMixin ):
417
669
@nested_params (
0 commit comments