Skip to content

Commit 0be21fb

Browse files
committed
fix unit tests
1 parent e7907d8 commit 0be21fb

File tree

3 files changed

+117
-148
lines changed

3 files changed

+117
-148
lines changed

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 78 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -551,19 +551,24 @@ def test_simulate_rir_ism_multi_band(self, channel):
551551
[
552552
(0.1, 0.2, (2, 1, 2500)), # both float
553553
# Per-wall
554-
(torch.rand(4), 0.2, (2, 1, 2500)),
555-
(0.1, torch.rand(4), (2, 1, 2500)),
556-
(torch.rand(4), torch.rand(4), (2, 1, 2500)),
554+
(torch.rand(6), 0.2, (2, 1, 2500)),
555+
(0.1, torch.rand(6), (2, 1, 2500)),
556+
(torch.rand(6), torch.rand(6), (2, 1, 2500)),
557557
# Per-band and per-wall
558-
(torch.rand(6, 4), 0.2, (2, 6, 2500)),
559-
(0.1, torch.rand(6, 4), (2, 6, 2500)),
560-
(torch.rand(6, 4), torch.rand(6, 4), (2, 6, 2500)),
558+
(torch.rand(4, 6), 0.2, (2, 4, 2500)),
559+
(0.1, torch.rand(4, 6), (2, 4, 2500)),
560+
(torch.rand(4, 6), torch.rand(4, 6), (2, 4, 2500)),
561561
]
562562
)
563563
def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape):
564-
room_dim = torch.tensor([20, 25], dtype=self.dtype)
565-
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
566-
source = torch.tensor([7, 6], dtype=self.dtype)
564+
room_dim = torch.tensor([20, 25, 30], dtype=self.dtype)
565+
mic_array = torch.tensor([[2, 2, 2], [8, 8, 8]], dtype=self.dtype)
566+
source = torch.tensor([7, 6, 5], dtype=self.dtype)
567+
if isinstance(absorption, torch.Tensor):
568+
absorption = absorption.to(self.dtype)
569+
if isinstance(scattering, torch.Tensor):
570+
scattering = scattering.to(self.dtype)
571+
567572
num_rays = 100
568573

569574
hist = F.ray_tracing(
@@ -578,143 +583,125 @@ def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape):
578583
assert hist.shape == expected_shape
579584

580585
def test_ray_tracing_input_errors(self):
581-
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
586+
with self.assertRaisesRegex(ValueError, "room must be a 1D Tensor."):
582587
F.ray_tracing(
583588
room=torch.tensor([[4, 5]]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10
584589
)
585-
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
590+
with self.assertRaisesRegex(ValueError, "The shape of room must be"):
586591
F.ray_tracing(
587592
room=torch.tensor([4, 5, 4, 5]),
588593
source=torch.tensor([0, 0]),
589594
mic_array=torch.tensor([[3, 4]]),
590595
num_rays=10,
591596
)
592-
with self.assertRaisesRegex(ValueError, r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor"):
597+
with self.assertRaisesRegex(ValueError, "The second dimension of mic_array must be 3"):
593598
F.ray_tracing(
594-
room=torch.tensor([4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[[3, 4]]]), num_rays=10
599+
room=torch.tensor([4, 5, 6]),
600+
source=torch.tensor([0, 0, 0]),
601+
mic_array=torch.tensor([[3, 4]]),
602+
num_rays=10,
595603
)
596604
with self.assertRaisesRegex(ValueError, "room must be of float32 or float64 dtype"):
597605
F.ray_tracing(
598-
room=torch.tensor([4, 5]).to(torch.int),
599-
source=torch.tensor([0, 0]),
600-
mic_array=torch.tensor([3, 4]),
606+
room=torch.tensor([4, 5, 6]).to(torch.int),
607+
source=torch.tensor([0, 0, 0]),
608+
mic_array=torch.tensor([[3, 4, 5]]),
601609
num_rays=10,
602610
)
603611
with self.assertRaisesRegex(ValueError, "dtype of room, source and mic_array must be the same"):
604612
F.ray_tracing(
605-
room=torch.tensor([4, 5]).to(torch.float64),
606-
source=torch.tensor([0, 0]).to(torch.float32),
607-
mic_array=torch.tensor([3, 4]),
608-
num_rays=10,
609-
)
610-
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
611-
F.ray_tracing(
612-
room=torch.tensor([4, 5, 10], dtype=torch.float),
613-
source=torch.tensor([0, 0], dtype=torch.float),
614-
mic_array=torch.tensor([3, 4], dtype=torch.float),
615-
num_rays=10,
616-
)
617-
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
618-
F.ray_tracing(
619-
room=torch.tensor([4, 5], dtype=torch.float),
620-
source=torch.tensor([0, 0, 0], dtype=torch.float),
621-
mic_array=torch.tensor([3, 4], dtype=torch.float),
622-
num_rays=10,
623-
)
624-
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
625-
F.ray_tracing(
626-
room=torch.tensor([4, 5, 10], dtype=torch.float),
627-
source=torch.tensor([0, 0, 0], dtype=torch.float),
628-
mic_array=torch.tensor([3, 4], dtype=torch.float),
613+
room=torch.tensor([4, 5, 6]).to(torch.float64),
614+
source=torch.tensor([0, 0, 0]).to(torch.float32),
615+
mic_array=torch.tensor([[3, 4, 5]]),
629616
num_rays=10,
630617
)
631618
with self.assertRaisesRegex(ValueError, "time_thres=10 must be at least greater than hist_bin_size=11"):
632619
F.ray_tracing(
633-
room=torch.tensor([4, 5], dtype=torch.float),
634-
source=torch.tensor([0, 0], dtype=torch.float),
635-
mic_array=torch.tensor([3, 4], dtype=torch.float),
620+
room=torch.tensor([4, 5, 6], dtype=torch.float),
621+
source=torch.tensor([0, 0, 0], dtype=torch.float),
622+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
636623
num_rays=10,
637624
time_thres=10,
638625
hist_bin_size=11,
639626
)
640-
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
627+
with self.assertRaisesRegex(ValueError, "The shape of coefficient must be"):
641628
F.ray_tracing(
642-
room=torch.tensor([4, 5], dtype=torch.float),
643-
source=torch.tensor([0, 0], dtype=torch.float),
644-
mic_array=torch.tensor([3, 4], dtype=torch.float),
629+
room=torch.tensor([4, 5, 6], dtype=torch.float),
630+
source=torch.tensor([0, 0, 0], dtype=torch.float),
631+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
645632
num_rays=10,
646633
absorption=torch.rand(5, dtype=torch.float),
647634
)
648-
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
635+
with self.assertRaisesRegex(ValueError, "The shape of coefficient must be"):
649636
F.ray_tracing(
650-
room=torch.tensor([4, 5], dtype=torch.float),
651-
source=torch.tensor([0, 0], dtype=torch.float),
652-
mic_array=torch.tensor([3, 4], dtype=torch.float),
637+
room=torch.tensor([4, 5, 6], dtype=torch.float),
638+
source=torch.tensor([0, 0, 0], dtype=torch.float),
639+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
653640
num_rays=10,
654641
scattering=torch.rand(5, 5, dtype=torch.float),
655642
)
656-
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
643+
with self.assertRaisesRegex(ValueError, "The shape of coefficient must be"):
657644
F.ray_tracing(
658-
room=torch.tensor([4, 5], dtype=torch.float),
659-
source=torch.tensor([0, 0], dtype=torch.float),
660-
mic_array=torch.tensor([3, 4], dtype=torch.float),
645+
room=torch.tensor([4, 5, 6], dtype=torch.float),
646+
source=torch.tensor([0, 0, 0], dtype=torch.float),
647+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
661648
num_rays=10,
662649
absorption=torch.rand(5, 5, dtype=torch.float),
663650
)
664-
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
651+
with self.assertRaisesRegex(ValueError, "The shape of coefficient must be"):
665652
F.ray_tracing(
666-
room=torch.tensor([4, 5], dtype=torch.float),
667-
source=torch.tensor([0, 0], dtype=torch.float),
668-
mic_array=torch.tensor([3, 4], dtype=torch.float),
653+
room=torch.tensor([4, 5, 6], dtype=torch.float),
654+
source=torch.tensor([0, 0, 0], dtype=torch.float),
655+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
669656
num_rays=10,
670657
scattering=torch.rand(5, dtype=torch.float),
671658
)
672659
with self.assertRaisesRegex(
673660
ValueError, "absorption and scattering must have the same number of bands and walls"
674661
):
675662
F.ray_tracing(
676-
room=torch.tensor([4, 5], dtype=torch.float),
677-
source=torch.tensor([0, 0], dtype=torch.float),
678-
mic_array=torch.tensor([3, 4], dtype=torch.float),
663+
room=torch.tensor([4, 5, 6], dtype=torch.float),
664+
source=torch.tensor([0, 0, 0], dtype=torch.float),
665+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
679666
num_rays=10,
680-
absorption=torch.rand(6, 4, dtype=torch.float),
681-
scattering=torch.rand(5, 4, dtype=torch.float),
667+
absorption=torch.rand(6, 6, dtype=torch.float),
668+
scattering=torch.rand(5, 6, dtype=torch.float),
682669
)
683670

684671
# Make sure passing different shapes for absorption or scattering doesn't raise an error
685672
# float and tensor
686673
F.ray_tracing(
687-
room=torch.tensor([4, 5], dtype=torch.float),
688-
source=torch.tensor([0, 0], dtype=torch.float),
689-
mic_array=torch.tensor([3, 4], dtype=torch.float),
674+
room=torch.tensor([4, 5, 6], dtype=torch.float),
675+
source=torch.tensor([0, 0, 0], dtype=torch.float),
676+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
690677
num_rays=10,
691678
absorption=0.1,
692-
scattering=torch.rand(5, 4, dtype=torch.float),
679+
scattering=torch.rand(5, 6, dtype=torch.float),
693680
)
694681
F.ray_tracing(
695-
room=torch.tensor([4, 5], dtype=torch.float),
696-
source=torch.tensor([0, 0], dtype=torch.float),
697-
mic_array=torch.tensor([3, 4], dtype=torch.float),
682+
room=torch.tensor([4, 5, 6], dtype=torch.float),
683+
source=torch.tensor([0, 0, 0], dtype=torch.float),
684+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
698685
num_rays=10,
699-
absorption=torch.rand(5, 4, dtype=torch.float),
686+
absorption=torch.rand(5, 6, dtype=torch.float),
700687
scattering=0.1,
701688
)
702689
# per-wall only and per-band + per-wall
703690
F.ray_tracing(
704-
room=torch.tensor([4, 5], dtype=torch.float),
705-
source=torch.tensor([0, 0], dtype=torch.float),
706-
mic_array=torch.tensor([3, 4], dtype=torch.float),
691+
room=torch.tensor([4, 5, 6], dtype=torch.float),
692+
source=torch.tensor([0, 0, 0], dtype=torch.float),
693+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
707694
num_rays=10,
708-
absorption=torch.rand(4, dtype=torch.float),
709-
scattering=torch.rand(6, 4, dtype=torch.float),
695+
absorption=torch.rand(6, dtype=torch.float),
696+
scattering=torch.rand(6, 6, dtype=torch.float),
710697
)
711698
F.ray_tracing(
712-
room=torch.tensor([4, 5], dtype=torch.float),
713-
source=torch.tensor([0, 0], dtype=torch.float),
714-
mic_array=torch.tensor([3, 4], dtype=torch.float),
699+
room=torch.tensor([4, 5, 6], dtype=torch.float),
700+
source=torch.tensor([0, 0, 0], dtype=torch.float),
701+
mic_array=torch.tensor([[3, 4, 5]], dtype=torch.float),
715702
num_rays=10,
716-
absorption=torch.rand(6, 4, dtype=torch.float),
717-
scattering=torch.rand(4, dtype=torch.float),
703+
absorption=torch.rand(6, 6, dtype=torch.float),
704+
scattering=torch.rand(6, dtype=torch.float),
718705
)
719706

720707
def test_ray_tracing_per_band_per_wall_absorption(self):
@@ -725,14 +712,14 @@ def test_ray_tracing_per_band_per_wall_absorption(self):
725712
(D,) tensor.
726713
"""
727714

728-
room_dim = torch.tensor([20, 25], dtype=self.dtype)
729-
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
730-
source = torch.tensor([7, 6], dtype=self.dtype)
715+
room_dim = torch.tensor([20, 25, 30], dtype=self.dtype)
716+
mic_array = torch.tensor([[2, 2, 2], [8, 8, 8]], dtype=self.dtype)
717+
source = torch.tensor([7, 6, 5], dtype=self.dtype)
731718
num_rays = 1_000
732719
ABS, SCAT = 0.1, 0.2
733720

734-
absorption = torch.full(fill_value=ABS, size=(6, 4), dtype=self.dtype)
735-
scattering = torch.full(fill_value=SCAT, size=(6, 4), dtype=self.dtype)
721+
absorption = torch.full(fill_value=ABS, size=(4, 6), dtype=self.dtype)
722+
scattering = torch.full(fill_value=SCAT, size=(4, 6), dtype=self.dtype)
736723
hist_per_band_per_wall = F.ray_tracing(
737724
room=room_dim,
738725
source=source,
@@ -741,8 +728,8 @@ def test_ray_tracing_per_band_per_wall_absorption(self):
741728
absorption=absorption,
742729
scattering=scattering,
743730
)
744-
absorption = torch.full(fill_value=ABS, size=(4,), dtype=self.dtype)
745-
scattering = torch.full(fill_value=SCAT, size=(4,), dtype=self.dtype)
731+
absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype)
732+
scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype)
746733
hist_per_wall = F.ray_tracing(
747734
room=room_dim,
748735
source=source,
@@ -762,22 +749,20 @@ def test_ray_tracing_per_band_per_wall_absorption(self):
762749
absorption=absorption,
763750
scattering=scattering,
764751
)
765-
assert hist_per_band_per_wall.shape == (2, 6, 2500)
752+
assert hist_per_band_per_wall.shape == (2, 4, 2500)
766753
assert hist_per_wall.shape == (2, 1, 2500)
767754
assert hist_single.shape == (2, 1, 2500)
768755
torch.testing.assert_close(hist_single, hist_per_wall)
769756

770-
hist_single = hist_single.expand(2, 6, 2500)
757+
hist_single = hist_single.expand(2, 4, 2500)
771758
torch.testing.assert_close(hist_single, hist_per_band_per_wall)
772759

773760
@parameterized.expand(
774761
[
775-
([20, 25], [2, 2], [[8, 8], [7, 6]], 10_000), # 2D with 2 mics
776762
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic
777763
]
778764
)
779765
def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):
780-
781766
walls = ["west", "east", "south", "north"]
782767
if len(room_dim) == 3:
783768
walls += ["floor", "ceiling"]

test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,15 @@ def test_simulate_rir_ism_multi_band(self, channel):
115115

116116
@parameterized.expand(
117117
[
118-
([20, 25], [2, 2], [[8, 8], [7, 6]], 1_000), # 2D with 2 mics
119118
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic
120119
]
121120
)
122121
def test_ray_tracing(self, room_dim, source, mic_array, num_rays):
123-
num_walls = 4 if len(room_dim) == 2 else 6
122+
num_walls = 6
124123
num_bands = 3
125124

126-
absorption = torch.rand(num_bands, num_walls, dtype=torch.float32)
127-
scattering = torch.rand(num_bands, num_walls, dtype=torch.float32)
125+
absorption = torch.rand(num_bands, num_walls, dtype=self.dtype)
126+
scattering = torch.rand(num_bands, num_walls, dtype=self.dtype)
128127

129128
energy_thres = 1e-7
130129
time_thres = 10.0

0 commit comments

Comments
 (0)