Skip to content

Commit 7d124e3

Browse files
committed
WIP: Add Ray Tracing (#3604)
Summary: Revamped version of #3234 (which was also revamp of #2850) Differential Revision: D49197174 Pulled By: mthrok
1 parent 22281d3 commit 7d124e3

File tree

6 files changed

+472
-23
lines changed

6 files changed

+472
-23
lines changed

docs/source/prototype.functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ Room Impulse Response Simulation
3535
:toctree: generated
3636
:nosignatures:
3737

38+
ray_tracing
3839
simulate_rir_ism

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,3 +460,226 @@ def _debug_plot():
460460
except AssertionError:
461461
_debug_plot()
462462
raise
463+
464+
@parameterized.expand(
465+
[
466+
(0.1, 0.2, (2, 1, 2500)), # both float
467+
# Per-wall
468+
(torch.rand(4), 0.2, (2, 1, 2500)),
469+
(0.1, torch.rand(4), (2, 1, 2500)),
470+
(torch.rand(4), torch.rand(4), (2, 1, 2500)),
471+
# Per-band and per-wall
472+
(torch.rand(6, 4), 0.2, (2, 6, 2500)),
473+
(0.1, torch.rand(6, 4), (2, 6, 2500)),
474+
(torch.rand(6, 4), torch.rand(6, 4), (2, 6, 2500)),
475+
]
476+
)
477+
def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape):
478+
room_dim = torch.tensor([20, 25], dtype=self.dtype)
479+
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
480+
source = torch.tensor([7, 6], dtype=self.dtype)
481+
num_rays = 100
482+
483+
hist = F.ray_tracing(
484+
room=room_dim,
485+
source=source,
486+
mic_array=mic_array,
487+
num_rays=num_rays,
488+
absorption=absorption,
489+
scattering=scattering,
490+
)
491+
492+
assert hist.shape == expected_shape
493+
494+
def test_ray_tracing_input_errors(self):
495+
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
496+
F.ray_tracing(
497+
room=torch.tensor([[4, 5]]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10
498+
)
499+
with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"):
500+
F.ray_tracing(
501+
room=torch.tensor([4, 5, 4, 5]),
502+
source=torch.tensor([0, 0]),
503+
mic_array=torch.tensor([[3, 4]]),
504+
num_rays=10,
505+
)
506+
with self.assertRaisesRegex(ValueError, r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor"):
507+
F.ray_tracing(
508+
room=torch.tensor([4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[[3, 4]]]), num_rays=10
509+
)
510+
with self.assertRaisesRegex(ValueError, "room must be of float32 or float64 dtype"):
511+
F.ray_tracing(
512+
room=torch.tensor([4, 5]).to(torch.int),
513+
source=torch.tensor([0, 0]),
514+
mic_array=torch.tensor([3, 4]),
515+
num_rays=10,
516+
)
517+
with self.assertRaisesRegex(ValueError, "dtype of room, source and mic_array must be the same"):
518+
F.ray_tracing(
519+
room=torch.tensor([4, 5]).to(torch.float64),
520+
source=torch.tensor([0, 0]).to(torch.float32),
521+
mic_array=torch.tensor([3, 4]),
522+
num_rays=10,
523+
)
524+
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
525+
F.ray_tracing(
526+
room=torch.tensor([4, 5, 10], dtype=torch.float),
527+
source=torch.tensor([0, 0], dtype=torch.float),
528+
mic_array=torch.tensor([3, 4], dtype=torch.float),
529+
num_rays=10,
530+
)
531+
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
532+
F.ray_tracing(
533+
room=torch.tensor([4, 5], dtype=torch.float),
534+
source=torch.tensor([0, 0, 0], dtype=torch.float),
535+
mic_array=torch.tensor([3, 4], dtype=torch.float),
536+
num_rays=10,
537+
)
538+
with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"):
539+
F.ray_tracing(
540+
room=torch.tensor([4, 5, 10], dtype=torch.float),
541+
source=torch.tensor([0, 0, 0], dtype=torch.float),
542+
mic_array=torch.tensor([3, 4], dtype=torch.float),
543+
num_rays=10,
544+
)
545+
with self.assertRaisesRegex(ValueError, "time_thres=10 must be at least greater than hist_bin_size=11"):
546+
F.ray_tracing(
547+
room=torch.tensor([4, 5], dtype=torch.float),
548+
source=torch.tensor([0, 0], dtype=torch.float),
549+
mic_array=torch.tensor([3, 4], dtype=torch.float),
550+
num_rays=10,
551+
time_thres=10,
552+
hist_bin_size=11,
553+
)
554+
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
555+
F.ray_tracing(
556+
room=torch.tensor([4, 5], dtype=torch.float),
557+
source=torch.tensor([0, 0], dtype=torch.float),
558+
mic_array=torch.tensor([3, 4], dtype=torch.float),
559+
num_rays=10,
560+
absorption=torch.rand(5, dtype=torch.float),
561+
)
562+
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
563+
F.ray_tracing(
564+
room=torch.tensor([4, 5], dtype=torch.float),
565+
source=torch.tensor([0, 0], dtype=torch.float),
566+
mic_array=torch.tensor([3, 4], dtype=torch.float),
567+
num_rays=10,
568+
scattering=torch.rand(5, 5, dtype=torch.float),
569+
)
570+
with self.assertRaisesRegex(ValueError, "The shape of absorption must be"):
571+
F.ray_tracing(
572+
room=torch.tensor([4, 5], dtype=torch.float),
573+
source=torch.tensor([0, 0], dtype=torch.float),
574+
mic_array=torch.tensor([3, 4], dtype=torch.float),
575+
num_rays=10,
576+
absorption=torch.rand(5, 5, dtype=torch.float),
577+
)
578+
with self.assertRaisesRegex(ValueError, "The shape of scattering must be"):
579+
F.ray_tracing(
580+
room=torch.tensor([4, 5], dtype=torch.float),
581+
source=torch.tensor([0, 0], dtype=torch.float),
582+
mic_array=torch.tensor([3, 4], dtype=torch.float),
583+
num_rays=10,
584+
scattering=torch.rand(5, dtype=torch.float),
585+
)
586+
with self.assertRaisesRegex(
587+
ValueError, "absorption and scattering must have the same number of bands and walls"
588+
):
589+
F.ray_tracing(
590+
room=torch.tensor([4, 5], dtype=torch.float),
591+
source=torch.tensor([0, 0], dtype=torch.float),
592+
mic_array=torch.tensor([3, 4], dtype=torch.float),
593+
num_rays=10,
594+
absorption=torch.rand(6, 4, dtype=torch.float),
595+
scattering=torch.rand(5, 4, dtype=torch.float),
596+
)
597+
598+
# Make sure passing different shapes for absorption or scattering doesn't raise an error
599+
# float and tensor
600+
F.ray_tracing(
601+
room=torch.tensor([4, 5], dtype=torch.float),
602+
source=torch.tensor([0, 0], dtype=torch.float),
603+
mic_array=torch.tensor([3, 4], dtype=torch.float),
604+
num_rays=10,
605+
absorption=0.1,
606+
scattering=torch.rand(5, 4, dtype=torch.float),
607+
)
608+
F.ray_tracing(
609+
room=torch.tensor([4, 5], dtype=torch.float),
610+
source=torch.tensor([0, 0], dtype=torch.float),
611+
mic_array=torch.tensor([3, 4], dtype=torch.float),
612+
num_rays=10,
613+
absorption=torch.rand(5, 4, dtype=torch.float),
614+
scattering=0.1,
615+
)
616+
# per-wall only and per-band + per-wall
617+
F.ray_tracing(
618+
room=torch.tensor([4, 5], dtype=torch.float),
619+
source=torch.tensor([0, 0], dtype=torch.float),
620+
mic_array=torch.tensor([3, 4], dtype=torch.float),
621+
num_rays=10,
622+
absorption=torch.rand(4, dtype=torch.float),
623+
scattering=torch.rand(6, 4, dtype=torch.float),
624+
)
625+
F.ray_tracing(
626+
room=torch.tensor([4, 5], dtype=torch.float),
627+
source=torch.tensor([0, 0], dtype=torch.float),
628+
mic_array=torch.tensor([3, 4], dtype=torch.float),
629+
num_rays=10,
630+
absorption=torch.rand(6, 4, dtype=torch.float),
631+
scattering=torch.rand(4, dtype=torch.float),
632+
)
633+
634+
def test_ray_tracing_per_band_per_wall_absorption(self):
635+
"""Check that when the value of absorption and scattering are the same
636+
across walls and frequency bands, the output histograms are:
637+
- all equal across frequency bands
638+
- equal to simply passing a float value instead of a (num_bands, D) or
639+
(D,) tensor.
640+
"""
641+
642+
room_dim = torch.tensor([20, 25], dtype=self.dtype)
643+
mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype)
644+
source = torch.tensor([7, 6], dtype=self.dtype)
645+
num_rays = 1_000
646+
ABS, SCAT = 0.1, 0.2
647+
648+
absorption = torch.full(fill_value=ABS, size=(6, 4), dtype=self.dtype)
649+
scattering = torch.full(fill_value=SCAT, size=(6, 4), dtype=self.dtype)
650+
hist_per_band_per_wall = F.ray_tracing(
651+
room=room_dim,
652+
source=source,
653+
mic_array=mic_array,
654+
num_rays=num_rays,
655+
absorption=absorption,
656+
scattering=scattering,
657+
)
658+
absorption = torch.full(fill_value=ABS, size=(4,), dtype=self.dtype)
659+
scattering = torch.full(fill_value=SCAT, size=(4,), dtype=self.dtype)
660+
hist_per_wall = F.ray_tracing(
661+
room=room_dim,
662+
source=source,
663+
mic_array=mic_array,
664+
num_rays=num_rays,
665+
absorption=absorption,
666+
scattering=scattering,
667+
)
668+
669+
absorption = ABS
670+
scattering = SCAT
671+
hist_single = F.ray_tracing(
672+
room=room_dim,
673+
source=source,
674+
mic_array=mic_array,
675+
num_rays=num_rays,
676+
absorption=absorption,
677+
scattering=scattering,
678+
)
679+
assert hist_per_band_per_wall.shape == (2, 6, 2500)
680+
assert hist_per_wall.shape == (2, 1, 2500)
681+
assert hist_single.shape == (2, 1, 2500)
682+
torch.testing.assert_close(hist_single, hist_per_wall)
683+
684+
hist_single = hist_single.expand(2, 6, 2500)
685+
torch.testing.assert_close(hist_single, hist_per_band_per_wall)

test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23
import torchaudio.prototype.functional as F
34

@@ -91,3 +92,80 @@ def test_simulate_rir_ism_multi_band(self, channel):
9192
expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
9293
actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
9394
self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)
95+
96+
@parameterized.expand(
97+
[
98+
([20, 25], [2, 2], [[8, 8], [7, 6]], 10_000), # 2D with 2 mics
99+
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic
100+
]
101+
)
102+
def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):
103+
104+
walls = ["west", "east", "south", "north"]
105+
if len(room_dim) == 3:
106+
walls += ["floor", "ceiling"]
107+
num_walls = len(walls)
108+
num_bands = 6 # Note: in ray tracing, we don't need to restrict the number of bands to 7
109+
110+
absorption = torch.rand(num_bands, num_walls, dtype=self.dtype)
111+
scattering = torch.rand(num_bands, num_walls, dtype=self.dtype)
112+
energy_thres = 1e-7
113+
time_thres = 10.0
114+
hist_bin_size = 0.004
115+
mic_radius = 0.5
116+
sound_speed = 343.0
117+
118+
room_dim = torch.tensor(room_dim, dtype=self.dtype)
119+
source = torch.tensor(source, dtype=self.dtype)
120+
mic_array = torch.tensor(mic_array, dtype=self.dtype)
121+
122+
room = pra.ShoeBox(
123+
room_dim.tolist(),
124+
ray_tracing=True,
125+
materials={
126+
walls[i]: pra.Material(
127+
energy_absorption={
128+
"coeffs": absorption[:, i].reshape(-1).detach().numpy(),
129+
"center_freqs": 125 * 2 ** np.arange(num_bands),
130+
},
131+
scattering={
132+
"coeffs": scattering[:, i].reshape(-1).detach().numpy(),
133+
"center_freqs": 125 * 2 ** np.arange(num_bands),
134+
},
135+
)
136+
for i in range(num_walls)
137+
},
138+
air_absorption=False,
139+
max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
140+
)
141+
room.add_microphone_array(mic_array.T.tolist())
142+
room.add_source(source.tolist())
143+
room.set_ray_tracing(
144+
n_rays=num_rays,
145+
energy_thres=energy_thres,
146+
time_thres=time_thres,
147+
hist_bin_size=hist_bin_size,
148+
receiver_radius=mic_radius,
149+
)
150+
room.set_sound_speed(sound_speed)
151+
152+
room.compute_rir()
153+
hist_pra = torch.tensor(np.array(room.rt_histograms))[:, 0, 0]
154+
155+
hist = F.ray_tracing(
156+
room=room_dim,
157+
source=source,
158+
mic_array=mic_array,
159+
num_rays=num_rays,
160+
absorption=absorption,
161+
scattering=scattering,
162+
sound_speed=sound_speed,
163+
mic_radius=mic_radius,
164+
energy_thres=energy_thres,
165+
time_thres=time_thres,
166+
hist_bin_size=hist_bin_size,
167+
)
168+
169+
assert hist.ndim == 3
170+
assert hist.shape == hist_pra.shape
171+
self.assertEqual(hist.to(torch.float32), hist_pra)

test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,43 @@ def test_simulate_rir_ism_multi_band(self, channel):
112112
F.simulate_rir_ism,
113113
(room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0),
114114
)
115+
116+
@parameterized.expand(
117+
[
118+
([20, 25], [2, 2], [[8, 8], [7, 6]], 1_000), # 2D with 2 mics
119+
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic
120+
]
121+
)
122+
def test_ray_tracing(self, room_dim, source, mic_array, num_rays):
123+
num_walls = 4 if len(room_dim) == 2 else 6
124+
num_bands = 3
125+
126+
absorption = torch.rand(num_bands, num_walls, dtype=torch.float32)
127+
scattering = torch.rand(num_bands, num_walls, dtype=torch.float32)
128+
129+
energy_thres = 1e-7
130+
time_thres = 10.0
131+
hist_bin_size = 0.004
132+
mic_radius = 0.5
133+
sound_speed = 343.0
134+
135+
room_dim = torch.tensor(room_dim, dtype=self.dtype)
136+
source = torch.tensor(source, dtype=self.dtype)
137+
mic_array = torch.tensor(mic_array, dtype=self.dtype)
138+
139+
self._assert_consistency(
140+
F.ray_tracing,
141+
(
142+
room_dim,
143+
source,
144+
mic_array,
145+
num_rays,
146+
absorption,
147+
scattering,
148+
mic_radius,
149+
sound_speed,
150+
energy_thres,
151+
time_thres,
152+
hist_bin_size,
153+
),
154+
)

torchaudio/prototype/functional/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
oscillator_bank,
88
sinc_impulse_response,
99
)
10-
from ._rir import simulate_rir_ism
10+
from ._rir import ray_tracing, simulate_rir_ism
1111
from .functional import barkscale_fbanks, chroma_filterbank
1212

1313

@@ -20,6 +20,7 @@
2020
"filter_waveform",
2121
"frequency_impulse_response",
2222
"oscillator_bank",
23+
"ray_tracing",
2324
"sinc_impulse_response",
2425
"simulate_rir_ism",
2526
]

0 commit comments

Comments
 (0)