Skip to content

Commit 49be591

Browse files
committed
WIP: Add Ray Tracing (#3604)
Summary: Revamped version of #3234 (which was also revamp of #2850) Differential Revision: D49197174 Pulled By: mthrok
1 parent d9942ba commit 49be591

File tree

8 files changed

+500
-10
lines changed

8 files changed

+500
-10
lines changed

docs/source/prototype.functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ Room Impulse Response Simulation
3535
:toctree: generated
3636
:nosignatures:
3737

38+
ray_tracing
3839
simulate_rir_ism

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,255 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet
412412

413413
self.assertEqual(torch_out, torch.tensor(np_out))
414414

415+
@parameterized.expand(
416+
[
417+
# both float
418+
(0.1, 0.2, (2, 1, 2500)),
419+
# Per-wall
420+
((6,), 0.2, (2, 1, 2500)),
421+
(0.1, (6,), (2, 1, 2500)),
422+
((6,), (6,), (2, 1, 2500)),
423+
# Per-band and per-wall
424+
((3, 6), 0.2, (2, 3, 2500)),
425+
(0.1, (5, 6), (2, 5, 2500)),
426+
((7, 6), (7, 6), (2, 7, 2500)),
427+
]
428+
)
429+
def test_ray_tracing_output_shape(self, abs_, scat_, expected_shape):
430+
if isinstance(abs_, float):
431+
absorption = abs_
432+
else:
433+
absorption = torch.rand(abs_, dtype=self.dtype)
434+
if isinstance(scat_, float):
435+
scattering = scat_
436+
else:
437+
scattering = torch.rand(scat_, dtype=self.dtype)
438+
439+
room_dim = torch.tensor([3, 4, 5], dtype=self.dtype)
440+
mic_array = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=self.dtype)
441+
source = torch.tensor([1, 2, 3], dtype=self.dtype)
442+
num_rays = 100
443+
444+
hist = F.ray_tracing(
445+
room=room_dim,
446+
source=source,
447+
mic_array=mic_array,
448+
num_rays=num_rays,
449+
absorption=absorption,
450+
scattering=scattering,
451+
)
452+
assert hist.shape == expected_shape
453+
454+
def test_ray_tracing_input_errors(self):
455+
room = torch.tensor([3.0, 4.0, 5.0], dtype=self.dtype)
456+
source = torch.tensor([0.0, 0.0, 0.0], dtype=self.dtype)
457+
mic = torch.tensor([[1.0, 2.0, 3.0]], dtype=self.dtype)
458+
459+
# baseline. This should not raise
460+
_ = F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10)
461+
462+
# invlaid room shape
463+
for invalid in ([[4, 5]], [4, 5, 4, 5]):
464+
invalid = torch.tensor(invalid, dtype=self.dtype)
465+
with self.assertRaises(ValueError) as cm:
466+
F.ray_tracing(room=invalid, source=source, mic_array=mic, num_rays=10)
467+
468+
error = str(cm.exception)
469+
self.assertIn("`room` must be a 1D Tensor with 3 elements.", error)
470+
self.assertIn(str(invalid.shape), error)
471+
472+
# invalid microphone shape
473+
invalid = torch.tensor([[[3, 4]]], dtype=self.dtype)
474+
with self.assertRaises(ValueError) as cm:
475+
F.ray_tracing(room=room, source=source, mic_array=invalid, num_rays=10)
476+
477+
error = str(cm.exception)
478+
self.assertIn("`mic_array` must be a 2D Tensor with shape (num_channels, 3).", error)
479+
self.assertIn(str(invalid.shape), error)
480+
481+
# incompatible dtypes
482+
with self.assertRaises(ValueError) as cm:
483+
F.ray_tracing(
484+
room=room.to(torch.float64),
485+
source=source.to(torch.float32),
486+
mic_array=mic,
487+
num_rays=10,
488+
)
489+
error = str(cm.exception)
490+
self.assertIn("dtype of `room`, `source` and `mic_array` must match.", error)
491+
self.assertIn("`room` (torch.float64)", error)
492+
self.assertIn("`source` (torch.float32)", error)
493+
self.assertIn("`mic_array` (torch.float32)", error)
494+
495+
# invalid time configuration
496+
with self.assertRaises(ValueError) as cm:
497+
F.ray_tracing(
498+
room=room,
499+
source=source,
500+
mic_array=mic,
501+
num_rays=10,
502+
time_thres=10,
503+
hist_bin_size=11,
504+
)
505+
error = str(cm.exception)
506+
self.assertIn("`time_thres` must be greater than `hist_bin_size`.", error)
507+
self.assertIn("hist_bin_size=11", error)
508+
self.assertIn("time_thres=10", error)
509+
510+
# invalid absorption shape 1D
511+
invalid_abs = torch.tensor([1, 2, 3], dtype=self.dtype)
512+
with self.assertRaises(ValueError) as cm:
513+
F.ray_tracing(
514+
room=room,
515+
source=source,
516+
mic_array=mic,
517+
num_rays=10,
518+
absorption=invalid_abs,
519+
)
520+
error = str(cm.exception)
521+
self.assertIn("The shape of `absorption` must be (6,) when", error)
522+
self.assertIn(str(invalid_abs.shape), error)
523+
524+
# invalid absorption shape 2D
525+
invalid_abs = torch.tensor([[1, 2, 3]], dtype=self.dtype)
526+
with self.assertRaises(ValueError) as cm:
527+
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_abs)
528+
error = str(cm.exception)
529+
self.assertIn("The shape of `absorption` must be (NUM_BANDS, 6) when", error)
530+
self.assertIn(str(invalid_abs.shape), error)
531+
532+
# invalid scattering shape 1D
533+
invalid_scat = torch.tensor([1, 2, 3], dtype=self.dtype)
534+
with self.assertRaises(ValueError) as cm:
535+
F.ray_tracing(
536+
room=room,
537+
source=source,
538+
mic_array=mic,
539+
num_rays=10,
540+
scattering=invalid_scat,
541+
)
542+
error = str(cm.exception)
543+
self.assertIn("The shape of `scattering` must be (6,) when", error)
544+
self.assertIn(str(invalid_scat.shape), error)
545+
546+
# invalid scattering shape 2D
547+
invalid_scat = torch.tensor([[1, 2, 3]], dtype=self.dtype)
548+
with self.assertRaises(ValueError) as cm:
549+
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_scat)
550+
error = str(cm.exception)
551+
self.assertIn("The shape of `scattering` must be (NUM_BANDS, 6) when", error)
552+
self.assertIn(str(invalid_scat.shape), error)
553+
554+
# TODO: Invalid absorption/scattering value
555+
556+
# incompatible scattering and absorption
557+
abs_ = torch.zeros((7, 6), dtype=self.dtype)
558+
scat = torch.zeros((5, 6), dtype=self.dtype)
559+
with self.assertRaises(ValueError) as cm:
560+
F.ray_tracing(
561+
room=room,
562+
source=source,
563+
mic_array=mic,
564+
num_rays=10,
565+
absorption=abs_,
566+
scattering=scat,
567+
)
568+
error = str(cm.exception)
569+
self.assertIn(
570+
"`absorption` and `scattering` must be broadcastable to the same number of bands and walls", error
571+
)
572+
self.assertIn(f"absorption={abs_.shape}", error)
573+
self.assertIn(f"scattering={scat.shape}", error)
574+
575+
# Make sure passing different shapes for absorption or scattering doesn't raise an error
576+
# float and tensor
577+
F.ray_tracing(
578+
room=room,
579+
source=source,
580+
mic_array=mic,
581+
num_rays=10,
582+
absorption=0.1,
583+
scattering=torch.randn((5, 6), dtype=self.dtype),
584+
)
585+
F.ray_tracing(
586+
room=room,
587+
source=source,
588+
mic_array=mic,
589+
num_rays=10,
590+
absorption=torch.randn((7, 6), dtype=self.dtype),
591+
scattering=0.1,
592+
)
593+
# per-wall only and per-band + per-wall
594+
F.ray_tracing(
595+
room=room,
596+
source=source,
597+
mic_array=mic,
598+
num_rays=10,
599+
absorption=torch.rand(6, dtype=self.dtype),
600+
scattering=torch.rand(7, 6, dtype=self.dtype),
601+
)
602+
F.ray_tracing(
603+
room=room,
604+
source=source,
605+
mic_array=mic,
606+
num_rays=10,
607+
absorption=torch.rand(7, 6, dtype=self.dtype),
608+
scattering=torch.rand(6, dtype=self.dtype),
609+
)
610+
611+
def test_ray_tracing_per_band_per_wall_absorption(self):
612+
"""Check that when the value of absorption and scattering are the same
613+
across walls and frequency bands, the output histograms are:
614+
- all equal across frequency bands
615+
- equal to simply passing a float value instead of a (num_bands, D) or
616+
(D,) tensor.
617+
"""
618+
619+
room_dim = torch.tensor([20, 25, 5], dtype=self.dtype)
620+
mic_array = torch.tensor([[2, 2, 0], [8, 8, 0]], dtype=self.dtype)
621+
source = torch.tensor([7, 6, 0], dtype=self.dtype)
622+
num_rays = 1_000
623+
ABS, SCAT = 0.1, 0.2
624+
625+
absorption = torch.full(fill_value=ABS, size=(7, 6), dtype=self.dtype)
626+
scattering = torch.full(fill_value=SCAT, size=(7, 6), dtype=self.dtype)
627+
hist_per_band_per_wall = F.ray_tracing(
628+
room=room_dim,
629+
source=source,
630+
mic_array=mic_array,
631+
num_rays=num_rays,
632+
absorption=absorption,
633+
scattering=scattering,
634+
)
635+
absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype)
636+
scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype)
637+
hist_per_wall = F.ray_tracing(
638+
room=room_dim,
639+
source=source,
640+
mic_array=mic_array,
641+
num_rays=num_rays,
642+
absorption=absorption,
643+
scattering=scattering,
644+
)
645+
646+
absorption = ABS
647+
scattering = SCAT
648+
hist_single = F.ray_tracing(
649+
room=room_dim,
650+
source=source,
651+
mic_array=mic_array,
652+
num_rays=num_rays,
653+
absorption=absorption,
654+
scattering=scattering,
655+
)
656+
self.assertEqual(hist_per_band_per_wall.shape, (2, 7, 2500))
657+
self.assertEqual(hist_per_wall.shape, (2, 1, 2500))
658+
self.assertEqual(hist_single.shape, (2, 1, 2500))
659+
torch.testing.assert_close(hist_single, hist_per_wall)
660+
661+
hist_single = hist_single.expand(hist_per_band_per_wall.shape)
662+
torch.testing.assert_close(hist_single, hist_per_band_per_wall)
663+
415664

416665
class Functional64OnlyTestImpl(TestBaseMixin):
417666
@nested_params(

test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
import numpy as np
13
import torch
24
import torchaudio.prototype.functional as F
35

@@ -9,6 +11,43 @@
911
import pyroomacoustics as pra
1012

1113

14+
def _pra_ray_tracing(room_dim, absorption, scattering, num_bands, mic_array, source, num_rays, energy_thres, time_thres, hist_bin_size, mic_radius, sound_speed):
15+
walls = ["west", "east", "south", "north", "floor", "ceiling"]
16+
absorption = absorption.T.tolist()
17+
scattering = scattering.T.tolist()
18+
freqs = 125 * 2 ** np.arange(num_bands)
19+
20+
room = pra.ShoeBox(
21+
room_dim.tolist(),
22+
ray_tracing=True,
23+
materials={
24+
wall: pra.Material(
25+
energy_absorption={"coeffs": absorp, "center_freqs": freqs},
26+
scattering={"coeffs": scat, "center_freqs": freqs},
27+
)
28+
for wall, absorp, scat in zip(walls, absorption, scattering)
29+
},
30+
air_absorption=False,
31+
max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
32+
)
33+
room.add_microphone_array(mic_array.T.tolist())
34+
room.add_source(source.tolist())
35+
room.set_ray_tracing(
36+
n_rays=num_rays,
37+
energy_thres=energy_thres,
38+
time_thres=time_thres,
39+
hist_bin_size=hist_bin_size,
40+
receiver_radius=mic_radius,
41+
)
42+
room.set_sound_speed(sound_speed)
43+
room.compute_rir()
44+
hist_pra = np.array(room.rt_histograms, dtype=np.double)[:, 0, 0]
45+
46+
# PRA continues the simulation beyond time threshold, but torchaudio does not.
47+
num_bins = math.ceil(time_thres / hist_bin_size)
48+
return hist_pra[:, :, :num_bins]
49+
50+
1251
@skipIfNoModule("pyroomacoustics")
1352
@skipIfNoRIR
1453
class CompatibilityTest(PytorchTestCase):
@@ -91,3 +130,54 @@ def test_simulate_rir_ism_multi_band(self, channel):
91130
expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
92131
actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
93132
self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)
133+
134+
@parameterized.expand(
135+
[
136+
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic
137+
]
138+
)
139+
def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):
140+
num_bands = 6
141+
energy_thres = 1e-7
142+
time_thres = 10.0
143+
hist_bin_size = 0.004
144+
mic_radius = 0.5
145+
sound_speed = 343.0
146+
147+
absorption = torch.rand((num_bands, 6), dtype=self.dtype)
148+
scattering = torch.rand((num_bands, 6), dtype=self.dtype)
149+
room_dim = torch.tensor(room_dim, dtype=self.dtype)
150+
source = torch.tensor(source, dtype=self.dtype)
151+
mic_array = torch.tensor(mic_array, dtype=self.dtype)
152+
153+
hist_pra = _pra_ray_tracing(
154+
room_dim,
155+
absorption,
156+
scattering,
157+
num_bands,
158+
mic_array,
159+
source,
160+
num_rays,
161+
energy_thres,
162+
time_thres,
163+
hist_bin_size,
164+
mic_radius,
165+
sound_speed)
166+
167+
hist = F.ray_tracing(
168+
room=room_dim,
169+
source=source,
170+
mic_array=mic_array,
171+
num_rays=num_rays,
172+
absorption=absorption,
173+
scattering=scattering,
174+
sound_speed=sound_speed,
175+
mic_radius=mic_radius,
176+
energy_thres=energy_thres,
177+
time_thres=time_thres,
178+
hist_bin_size=hist_bin_size,
179+
)
180+
181+
assert hist.ndim == 3
182+
assert hist.shape == hist_pra.shape
183+
self.assertEqual(hist, hist_pra)

0 commit comments

Comments
 (0)