Skip to content

Commit f999adb

Browse files
committed
WIP: Add Ray Tracing (#3604)
Summary: Revamped version of #3234 (which was also revamp of #2850) Differential Revision: D49197174
1 parent d8e6ec5 commit f999adb

File tree

7 files changed

+853
-24
lines changed

7 files changed

+853
-24
lines changed

docs/source/prototype.functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ Room Impulse Response Simulation
3535
:toctree: generated
3636
:nosignatures:
3737

38+
ray_tracing
3839
simulate_rir_ism

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
if _mod_utils.is_module_available("pyroomacoustics"):
44
import pyroomacoustics as pra
55

6+
import numpy as np
67
import torch
78
import torchaudio.prototype.functional as F
89
from parameterized import param, parameterized
@@ -545,3 +546,303 @@ def test_simulate_rir_ism_multi_band(self, channel):
545546
expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
546547
actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
547548
self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)
549+
550+
@parameterized.expand(
551+
[
552+
(0.1, 0.2, (2, 1, 2500)), # both float
553+
# Per-wall
554+
(torch.rand(4), 0.2, (2, 1, 2500)),
555+
(0.1, torch.rand(4), (2, 1, 2500)),
556+
(torch.rand(4), torch.rand(4), (2, 1, 2500)),
557+
# Per-band and per-wall
558+
(torch.rand(6, 4), 0.2, (2, 6, 2500)),
559+
(0.1, torch.rand(6, 4), (2, 6, 2500)),
560+
(torch.rand(6, 4), torch.rand(6, 4), (2, 6, 2500)),
561+
]
562+
)
563+
def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape):
564+
room_dim = torch.tensor([20, 25], dtype=self.dtype)
565+
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
566+
source = torch.tensor([7, 6], dtype=self.dtype)
567+
num_rays = 100
568+
569+
hist = F.ray_tracing(
570+
room=room_dim,
571+
source=source,
572+
mic_array=mic_array,
573+
num_rays=num_rays,
574+
absorption=absorption,
575+
scattering=scattering,
576+
)
577+
578+
assert hist.shape == expected_shape
579+
580+
def test_ray_tracing_input_errors(self):
581+
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
582+
F.ray_tracing(
583+
room=torch.tensor([[4, 5]]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10
584+
)
585+
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
586+
F.ray_tracing(
587+
room=torch.tensor([4, 5, 4, 5]),
588+
source=torch.tensor([0, 0]),
589+
mic_array=torch.tensor([[3, 4]]),
590+
num_rays=10,
591+
)
592+
with self.assertRaisesRegex(ValueError, r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor"):
593+
F.ray_tracing(
594+
room=torch.tensor([4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[[3, 4]]]), num_rays=10
595+
)
596+
with self.assertRaisesRegex(ValueError, "room must be of float32 or float64 dtype"):
597+
F.ray_tracing(
598+
room=torch.tensor([4, 5]).to(torch.int),
599+
source=torch.tensor([0, 0]),
600+
mic_array=torch.tensor([3, 4]),
601+
num_rays=10,
602+
)
603+
with self.assertRaisesRegex(ValueError, "dtype of room, source and mic_array must be the same"):
604+
F.ray_tracing(
605+
room=torch.tensor([4, 5]).to(torch.float64),
606+
source=torch.tensor([0, 0]).to(torch.float32),
607+
mic_array=torch.tensor([3, 4]),
608+
num_rays=10,
609+
)
610+
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
611+
F.ray_tracing(
612+
room=torch.tensor([4, 5, 10], dtype=torch.float),
613+
source=torch.tensor([0, 0], dtype=torch.float),
614+
mic_array=torch.tensor([3, 4], dtype=torch.float),
615+
num_rays=10,
616+
)
617+
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
618+
F.ray_tracing(
619+
room=torch.tensor([4, 5], dtype=torch.float),
620+
source=torch.tensor([0, 0, 0], dtype=torch.float),
621+
mic_array=torch.tensor([3, 4], dtype=torch.float),
622+
num_rays=10,
623+
)
624+
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
625+
F.ray_tracing(
626+
room=torch.tensor([4, 5, 10], dtype=torch.float),
627+
source=torch.tensor([0, 0, 0], dtype=torch.float),
628+
mic_array=torch.tensor([3, 4], dtype=torch.float),
629+
num_rays=10,
630+
)
631+
with self.assertRaisesRegex(ValueError, "time_thres=10 must be at least greater than hist_bin_size=11"):
632+
F.ray_tracing(
633+
room=torch.tensor([4, 5], dtype=torch.float),
634+
source=torch.tensor([0, 0], dtype=torch.float),
635+
mic_array=torch.tensor([3, 4], dtype=torch.float),
636+
num_rays=10,
637+
time_thres=10,
638+
hist_bin_size=11,
639+
)
640+
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
641+
F.ray_tracing(
642+
room=torch.tensor([4, 5], dtype=torch.float),
643+
source=torch.tensor([0, 0], dtype=torch.float),
644+
mic_array=torch.tensor([3, 4], dtype=torch.float),
645+
num_rays=10,
646+
absorption=torch.rand(5, dtype=torch.float),
647+
)
648+
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
649+
F.ray_tracing(
650+
room=torch.tensor([4, 5], dtype=torch.float),
651+
source=torch.tensor([0, 0], dtype=torch.float),
652+
mic_array=torch.tensor([3, 4], dtype=torch.float),
653+
num_rays=10,
654+
scattering=torch.rand(5, 5, dtype=torch.float),
655+
)
656+
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
657+
F.ray_tracing(
658+
room=torch.tensor([4, 5], dtype=torch.float),
659+
source=torch.tensor([0, 0], dtype=torch.float),
660+
mic_array=torch.tensor([3, 4], dtype=torch.float),
661+
num_rays=10,
662+
absorption=torch.rand(5, 5, dtype=torch.float),
663+
)
664+
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
665+
F.ray_tracing(
666+
room=torch.tensor([4, 5], dtype=torch.float),
667+
source=torch.tensor([0, 0], dtype=torch.float),
668+
mic_array=torch.tensor([3, 4], dtype=torch.float),
669+
num_rays=10,
670+
scattering=torch.rand(5, dtype=torch.float),
671+
)
672+
with self.assertRaisesRegex(
673+
ValueError, "absorption and scattering must have the same number of bands and walls"
674+
):
675+
F.ray_tracing(
676+
room=torch.tensor([4, 5], dtype=torch.float),
677+
source=torch.tensor([0, 0], dtype=torch.float),
678+
mic_array=torch.tensor([3, 4], dtype=torch.float),
679+
num_rays=10,
680+
absorption=torch.rand(6, 4, dtype=torch.float),
681+
scattering=torch.rand(5, 4, dtype=torch.float),
682+
)
683+
684+
# Make sure passing different shapes for absorption or scattering doesn't raise an error
685+
# float and tensor
686+
F.ray_tracing(
687+
room=torch.tensor([4, 5], dtype=torch.float),
688+
source=torch.tensor([0, 0], dtype=torch.float),
689+
mic_array=torch.tensor([3, 4], dtype=torch.float),
690+
num_rays=10,
691+
absorption=0.1,
692+
scattering=torch.rand(5, 4, dtype=torch.float),
693+
)
694+
F.ray_tracing(
695+
room=torch.tensor([4, 5], dtype=torch.float),
696+
source=torch.tensor([0, 0], dtype=torch.float),
697+
mic_array=torch.tensor([3, 4], dtype=torch.float),
698+
num_rays=10,
699+
absorption=torch.rand(5, 4, dtype=torch.float),
700+
scattering=0.1,
701+
)
702+
# per-wall only and per-band + per-wall
703+
F.ray_tracing(
704+
room=torch.tensor([4, 5], dtype=torch.float),
705+
source=torch.tensor([0, 0], dtype=torch.float),
706+
mic_array=torch.tensor([3, 4], dtype=torch.float),
707+
num_rays=10,
708+
absorption=torch.rand(4, dtype=torch.float),
709+
scattering=torch.rand(6, 4, dtype=torch.float),
710+
)
711+
F.ray_tracing(
712+
room=torch.tensor([4, 5], dtype=torch.float),
713+
source=torch.tensor([0, 0], dtype=torch.float),
714+
mic_array=torch.tensor([3, 4], dtype=torch.float),
715+
num_rays=10,
716+
absorption=torch.rand(6, 4, dtype=torch.float),
717+
scattering=torch.rand(4, dtype=torch.float),
718+
)
719+
720+
def test_ray_tracing_per_band_per_wall_absorption(self):
721+
"""Check that when the value of absorption and scattering are the same
722+
across walls and frequency bands, the output histograms are:
723+
- all equal across frequency bands
724+
- equal to simply passing a float value instead of a (num_bands, D) or
725+
(D,) tensor.
726+
"""
727+
728+
room_dim = torch.tensor([20, 25], dtype=self.dtype)
729+
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
730+
source = torch.tensor([7, 6], dtype=self.dtype)
731+
num_rays = 1_000
732+
ABS, SCAT = 0.1, 0.2
733+
734+
absorption = torch.full(fill_value=ABS, size=(6, 4), dtype=self.dtype)
735+
scattering = torch.full(fill_value=SCAT, size=(6, 4), dtype=self.dtype)
736+
hist_per_band_per_wall = F.ray_tracing(
737+
room=room_dim,
738+
source=source,
739+
mic_array=mic_array,
740+
num_rays=num_rays,
741+
absorption=absorption,
742+
scattering=scattering,
743+
)
744+
absorption = torch.full(fill_value=ABS, size=(4,), dtype=self.dtype)
745+
scattering = torch.full(fill_value=SCAT, size=(4,), dtype=self.dtype)
746+
hist_per_wall = F.ray_tracing(
747+
room=room_dim,
748+
source=source,
749+
mic_array=mic_array,
750+
num_rays=num_rays,
751+
absorption=absorption,
752+
scattering=scattering,
753+
)
754+
755+
absorption = ABS
756+
scattering = SCAT
757+
hist_single = F.ray_tracing(
758+
room=room_dim,
759+
source=source,
760+
mic_array=mic_array,
761+
num_rays=num_rays,
762+
absorption=absorption,
763+
scattering=scattering,
764+
)
765+
assert hist_per_band_per_wall.shape == (2, 6, 2500)
766+
assert hist_per_wall.shape == (2, 1, 2500)
767+
assert hist_single.shape == (2, 1, 2500)
768+
torch.testing.assert_close(hist_single, hist_per_wall)
769+
770+
hist_single = hist_single.expand(2, 6, 2500)
771+
torch.testing.assert_close(hist_single, hist_per_band_per_wall)
772+
773+
@parameterized.expand(
774+
[
775+
([20, 25], [2, 2], [[8, 8], [7, 6]], 10_000), # 2D with 2 mics
776+
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic
777+
]
778+
)
779+
def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):
780+
781+
walls = ["west", "east", "south", "north"]
782+
if len(room_dim) == 3:
783+
walls += ["floor", "ceiling"]
784+
num_walls = len(walls)
785+
num_bands = 6 # Note: in ray tracing, we don't need to restrict the number of bands to 7
786+
787+
absorption = torch.rand(num_bands, num_walls, dtype=self.dtype)
788+
scattering = torch.rand(num_bands, num_walls, dtype=self.dtype)
789+
energy_thres = 1e-7
790+
time_thres = 10.0
791+
hist_bin_size = 0.004
792+
mic_radius = 0.5
793+
sound_speed = 343.0
794+
795+
room_dim = torch.tensor(room_dim, dtype=self.dtype)
796+
source = torch.tensor(source, dtype=self.dtype)
797+
mic_array = torch.tensor(mic_array, dtype=self.dtype)
798+
799+
room = pra.ShoeBox(
800+
room_dim.tolist(),
801+
ray_tracing=True,
802+
materials={
803+
walls[i]: pra.Material(
804+
energy_absorption={
805+
"coeffs": absorption[:, i].reshape(-1).detach().numpy(),
806+
"center_freqs": 125 * 2 ** np.arange(num_bands),
807+
},
808+
scattering={
809+
"coeffs": scattering[:, i].reshape(-1).detach().numpy(),
810+
"center_freqs": 125 * 2 ** np.arange(num_bands),
811+
},
812+
)
813+
for i in range(num_walls)
814+
},
815+
air_absorption=False,
816+
max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
817+
)
818+
room.add_microphone_array(mic_array.T.tolist())
819+
room.add_source(source.tolist())
820+
room.set_ray_tracing(
821+
n_rays=num_rays,
822+
energy_thres=energy_thres,
823+
time_thres=time_thres,
824+
hist_bin_size=hist_bin_size,
825+
receiver_radius=mic_radius,
826+
)
827+
room.set_sound_speed(sound_speed)
828+
829+
room.compute_rir()
830+
hist_pra = torch.tensor(np.array(room.rt_histograms))[:, 0, 0]
831+
832+
hist = F.ray_tracing(
833+
room=room_dim,
834+
source=source,
835+
mic_array=mic_array,
836+
num_rays=num_rays,
837+
absorption=absorption,
838+
scattering=scattering,
839+
sound_speed=sound_speed,
840+
mic_radius=mic_radius,
841+
energy_thres=energy_thres,
842+
time_thres=time_thres,
843+
hist_bin_size=hist_bin_size,
844+
)
845+
846+
assert hist.ndim == 3
847+
assert hist.shape == hist_pra.shape
848+
self.assertEqual(hist.to(torch.float32), hist_pra)

test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,43 @@ def test_simulate_rir_ism_multi_band(self, channel):
112112
F.simulate_rir_ism,
113113
(room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0),
114114
)
115+
116+
@parameterized.expand(
117+
[
118+
([20, 25], [2, 2], [[8, 8], [7, 6]], 1_000), # 2D with 2 mics
119+
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic
120+
]
121+
)
122+
def test_ray_tracing(self, room_dim, source, mic_array, num_rays):
123+
num_walls = 4 if len(room_dim) == 2 else 6
124+
num_bands = 3
125+
126+
absorption = torch.rand(num_bands, num_walls, dtype=torch.float32)
127+
scattering = torch.rand(num_bands, num_walls, dtype=torch.float32)
128+
129+
energy_thres = 1e-7
130+
time_thres = 10.0
131+
hist_bin_size = 0.004
132+
mic_radius = 0.5
133+
sound_speed = 343.0
134+
135+
room_dim = torch.tensor(room_dim, dtype=self.dtype)
136+
source = torch.tensor(source, dtype=self.dtype)
137+
mic_array = torch.tensor(mic_array, dtype=self.dtype)
138+
139+
self._assert_consistency(
140+
F.ray_tracing,
141+
(
142+
room_dim,
143+
source,
144+
mic_array,
145+
num_rays,
146+
absorption,
147+
scattering,
148+
mic_radius,
149+
sound_speed,
150+
energy_thres,
151+
time_thres,
152+
hist_bin_size,
153+
),
154+
)

torchaudio/csrc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ if(BUILD_RNNT)
4242
endif()
4343

4444
if(BUILD_RIR)
45-
list(APPEND sources rir/rir.cpp)
45+
list(APPEND sources rir/rir.cpp rir/ray_tracing.cpp)
4646
list(APPEND compile_definitions INCLUDE_RIR)
4747
endif()
4848

0 commit comments

Comments
 (0)