Skip to content

Commit 4117ac6

Browse files
committed
WIP: Add Ray Tracing (#3604)
Summary: Revamped version of #3234 (which was also revamp of #2850) Differential Revision: D49197174 Pulled By: mthrok
1 parent d9942ba commit 4117ac6

File tree

7 files changed

+478
-9
lines changed

7 files changed

+478
-9
lines changed

docs/source/prototype.functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ Room Impulse Response Simulation
3535
:toctree: generated
3636
:nosignatures:
3737

38+
ray_tracing
3839
simulate_rir_ism

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,258 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet
412412

413413
self.assertEqual(torch_out, torch.tensor(np_out))
414414

415+
@parameterized.expand(
416+
[
417+
# both float
418+
(0.1, 0.2, (2, 1, 2500)),
419+
# Per-wall
420+
((6, ), 0.2, (2, 1, 2500)),
421+
(0.1, (6, ), (2, 1, 2500)),
422+
((6, ), (6, ), (2, 1, 2500)),
423+
# Per-band and per-wall
424+
((3, 6), 0.2, (2, 3, 2500)),
425+
(0.1, (5, 6), (2, 5, 2500)),
426+
((7, 6), (7, 6), (2, 7, 2500)),
427+
]
428+
)
429+
def test_ray_tracing_output_shape(self, abs_, scat_, expected_shape):
430+
if isinstance(abs_, float):
431+
absorption = abs_
432+
else:
433+
absorption = torch.rand(abs_, dtype=self.dtype)
434+
if isinstance(scat_, float):
435+
scattering = scat_
436+
else:
437+
scattering = torch.rand(scat_, dtype=self.dtype)
438+
439+
room_dim = torch.tensor([3, 4, 5], dtype=self.dtype)
440+
mic_array = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=self.dtype)
441+
source = torch.tensor([1, 2, 3], dtype=self.dtype)
442+
num_rays = 100
443+
444+
hist = F.ray_tracing(
445+
room=room_dim,
446+
source=source,
447+
mic_array=mic_array,
448+
num_rays=num_rays,
449+
absorption=absorption,
450+
scattering=scattering,
451+
)
452+
assert hist.shape == expected_shape
453+
454+
def test_ray_tracing_input_errors(self):
455+
if self.dtype == torch.float64:
456+
import unittest
457+
458+
raise unittest.SkipTest("float64 is not supported yet")
459+
460+
room = torch.tensor([3., 4., 5.], dtype=self.dtype)
461+
source = torch.tensor([0., 0., 0.], dtype=self.dtype)
462+
mic = torch.tensor([[1., 2., 3.]], dtype=self.dtype)
463+
464+
_ = F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10)
465+
466+
for invalid in ([[4, 5]], [4, 5, 4, 5]):
467+
invalid = torch.tensor(invalid, dtype=self.dtype)
468+
with self.assertRaises(ValueError) as cm:
469+
F.ray_tracing(room=invalid, source=source, mic_array=mic, num_rays=10)
470+
471+
error = str(cm.exception)
472+
self.assertIn("`room` must be a 1D Tensor with 3 elements.", error)
473+
self.assertIn(str(invalid.shape), error)
474+
475+
invalid = torch.tensor([[[3, 4]]], dtype=self.dtype)
476+
with self.assertRaises(ValueError) as cm:
477+
F.ray_tracing(room=room, source=source, mic_array=invalid, num_rays=10)
478+
479+
error = str(cm.exception)
480+
self.assertIn("`mic_array` must be a 2D Tensor with shape (num_channels, 3).", error)
481+
self.assertIn(str(invalid.shape), error)
482+
483+
with self.assertRaises(ValueError) as cm:
484+
F.ray_tracing(
485+
room=room.to(torch.float64),
486+
source=source.to(torch.float32),
487+
mic_array=mic,
488+
num_rays=10,
489+
)
490+
error = str(cm.exception)
491+
self.assertIn("dtype of `room`, `source` and `mic_array` must match.", error)
492+
self.assertIn("`room` (torch.float64)", error)
493+
self.assertIn("`source` (torch.float32)", error)
494+
self.assertIn("`mic_array` (torch.float32)", error)
495+
496+
with self.assertRaises(ValueError) as cm:
497+
F.ray_tracing(
498+
room=room,
499+
source=source,
500+
mic_array=mic,
501+
num_rays=10,
502+
time_thres=10,
503+
hist_bin_size=11,
504+
)
505+
error = str(cm.exception)
506+
self.assertIn("`time_thres` must be greater than `hist_bin_size`.", error)
507+
self.assertIn("hist_bin_size=11", error)
508+
self.assertIn("time_thres=10", error)
509+
510+
invalid_abs = torch.tensor([1, 2, 3], dtype=self.dtype)
511+
with self.assertRaises(ValueError) as cm:
512+
F.ray_tracing(
513+
room=room,
514+
source=source,
515+
mic_array=mic,
516+
num_rays=10,
517+
absorption=invalid_abs,
518+
)
519+
error = str(cm.exception)
520+
self.assertIn("The shape of `absorption` must be (6,) when", error)
521+
self.assertIn(str(invalid_abs.shape), error)
522+
523+
invalid_scat = torch.tensor([1, 2, 3], dtype=self.dtype)
524+
with self.assertRaises(ValueError) as cm:
525+
F.ray_tracing(
526+
room=room,
527+
source=source,
528+
mic_array=mic,
529+
num_rays=10,
530+
scattering=invalid_scat,
531+
)
532+
error = str(cm.exception)
533+
self.assertIn("The shape of `scattering` must be (6,) when", error)
534+
self.assertIn(str(invalid_scat.shape), error)
535+
536+
invalid_abs = torch.tensor([[1, 2, 3]], dtype=self.dtype)
537+
with self.assertRaises(ValueError) as cm:
538+
F.ray_tracing(
539+
room=room,
540+
source=source,
541+
mic_array=mic,
542+
num_rays=10,
543+
absorption=invalid_abs
544+
)
545+
error = str(cm.exception)
546+
self.assertIn("The shape of `absorption` must be (NUM_BANDS, 6) when", error)
547+
self.assertIn(str(invalid_abs.shape), error)
548+
549+
invalid_scat = torch.tensor([[1, 2, 3]], dtype=self.dtype)
550+
with self.assertRaises(ValueError) as cm:
551+
F.ray_tracing(
552+
room=room,
553+
source=source,
554+
mic_array=mic,
555+
num_rays=10,
556+
scattering=invalid_scat
557+
)
558+
error = str(cm.exception)
559+
self.assertIn("The shape of `scattering` must be (NUM_BANDS, 6) when", error)
560+
self.assertIn(str(invalid_scat.shape), error)
561+
562+
abs_ = torch.randn((7, 6), dtype=self.dtype)
563+
scat = torch.randn((5, 6), dtype=self.dtype)
564+
with self.assertRaises(ValueError) as cm:
565+
F.ray_tracing(
566+
room=room,
567+
source=source,
568+
mic_array=mic,
569+
num_rays=10,
570+
absorption=abs_,
571+
scattering=scat,
572+
)
573+
error = str(cm.exception)
574+
self.assertIn("`absorption` and `scattering` must be broadcastable to the same number of bands and walls", error)
575+
self.assertIn(f"absorption={abs_.shape}", error)
576+
self.assertIn(f"scattering={scat.shape}", error)
577+
578+
# Make sure passing different shapes for absorption or scattering doesn't raise an error
579+
# float and tensor
580+
F.ray_tracing(
581+
room=room,
582+
source=source,
583+
mic_array=mic,
584+
num_rays=10,
585+
absorption=0.1,
586+
scattering=torch.randn((5, 6), dtype=self.dtype),
587+
)
588+
F.ray_tracing(
589+
room=room,
590+
source=source,
591+
mic_array=mic,
592+
num_rays=10,
593+
absorption=torch.randn((7, 6), dtype=self.dtype),
594+
scattering=0.1,
595+
)
596+
# per-wall only and per-band + per-wall
597+
F.ray_tracing(
598+
room=room,
599+
source=source,
600+
mic_array=mic,
601+
num_rays=10,
602+
absorption=torch.rand(6, dtype=self.dtype),
603+
scattering=torch.rand(7, 6, dtype=self.dtype),
604+
)
605+
F.ray_tracing(
606+
room=room,
607+
source=source,
608+
mic_array=mic,
609+
num_rays=10,
610+
absorption=torch.rand(7, 6, dtype=self.dtype),
611+
scattering=torch.rand(6, dtype=self.dtype),
612+
)
613+
614+
def test_ray_tracing_per_band_per_wall_absorption(self):
615+
"""Check that when the value of absorption and scattering are the same
616+
across walls and frequency bands, the output histograms are:
617+
- all equal across frequency bands
618+
- equal to simply passing a float value instead of a (num_bands, D) or
619+
(D,) tensor.
620+
"""
621+
622+
room_dim = torch.tensor([20, 25, 5], dtype=self.dtype)
623+
mic_array = torch.tensor([[2, 2, 0], [8, 8, 0]], dtype=self.dtype)
624+
source = torch.tensor([7, 6, 0], dtype=self.dtype)
625+
num_rays = 1_000
626+
ABS, SCAT = 0.1, 0.2
627+
628+
absorption = torch.full(fill_value=ABS, size=(7, 6), dtype=self.dtype)
629+
scattering = torch.full(fill_value=SCAT, size=(7, 6), dtype=self.dtype)
630+
hist_per_band_per_wall = F.ray_tracing(
631+
room=room_dim,
632+
source=source,
633+
mic_array=mic_array,
634+
num_rays=num_rays,
635+
absorption=absorption,
636+
scattering=scattering,
637+
)
638+
absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype)
639+
scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype)
640+
hist_per_wall = F.ray_tracing(
641+
room=room_dim,
642+
source=source,
643+
mic_array=mic_array,
644+
num_rays=num_rays,
645+
absorption=absorption,
646+
scattering=scattering,
647+
)
648+
649+
absorption = ABS
650+
scattering = SCAT
651+
hist_single = F.ray_tracing(
652+
room=room_dim,
653+
source=source,
654+
mic_array=mic_array,
655+
num_rays=num_rays,
656+
absorption=absorption,
657+
scattering=scattering,
658+
)
659+
self.assertEqual(hist_per_band_per_wall.shape, (2, 7, 2500))
660+
self.assertEqual(hist_per_wall.shape, (2, 1, 2500))
661+
self.assertEqual(hist_single.shape, (2, 1, 2500))
662+
torch.testing.assert_close(hist_single, hist_per_wall)
663+
664+
hist_single = hist_single.expand(hist_per_band_per_wall.shape)
665+
torch.testing.assert_close(hist_single, hist_per_band_per_wall)
666+
415667

416668
class Functional64OnlyTestImpl(TestBaseMixin):
417669
@nested_params(

test/torchaudio_unittest/prototype/functional/pyroomacoustics_compatibility_test.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23
import torchaudio.prototype.functional as F
34

@@ -91,3 +92,76 @@ def test_simulate_rir_ism_multi_band(self, channel):
9192
expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
9293
actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
9394
self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)
95+
96+
@parameterized.expand(
97+
[
98+
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic
99+
]
100+
)
101+
def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays):
102+
103+
walls = ["west", "east", "south", "north", "floor", "ceiling"]
104+
num_walls = len(walls)
105+
num_bands = 6
106+
107+
energy_thres = 1e-7
108+
time_thres = 10.0
109+
hist_bin_size = 0.004
110+
mic_radius = 0.5
111+
sound_speed = 343.0
112+
113+
absorption = torch.rand(num_bands, num_walls, dtype=self.dtype)
114+
scattering = torch.rand(num_bands, num_walls, dtype=self.dtype)
115+
room_dim = torch.tensor(room_dim, dtype=self.dtype)
116+
source = torch.tensor(source, dtype=self.dtype)
117+
mic_array = torch.tensor(mic_array, dtype=self.dtype)
118+
119+
room = pra.ShoeBox(
120+
room_dim.tolist(),
121+
ray_tracing=True,
122+
materials={
123+
walls[i]: pra.Material(
124+
energy_absorption={
125+
"coeffs": absorption[:, i].reshape(-1).detach().numpy(),
126+
"center_freqs": 125 * 2 ** np.arange(num_bands),
127+
},
128+
scattering={
129+
"coeffs": scattering[:, i].reshape(-1).detach().numpy(),
130+
"center_freqs": 125 * 2 ** np.arange(num_bands),
131+
},
132+
)
133+
for i in range(num_walls)
134+
},
135+
air_absorption=False,
136+
max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
137+
)
138+
room.add_microphone_array(mic_array.T.tolist())
139+
room.add_source(source.tolist())
140+
room.set_ray_tracing(
141+
n_rays=num_rays,
142+
energy_thres=energy_thres,
143+
time_thres=time_thres,
144+
hist_bin_size=hist_bin_size,
145+
receiver_radius=mic_radius,
146+
)
147+
room.set_sound_speed(sound_speed)
148+
room.compute_rir()
149+
hist_pra = np.array(room.rt_histograms, dtype=np.double)[:, 0, 0]
150+
151+
hist = F.ray_tracing(
152+
room=room_dim,
153+
source=source,
154+
mic_array=mic_array,
155+
num_rays=num_rays,
156+
absorption=absorption,
157+
scattering=scattering,
158+
sound_speed=sound_speed,
159+
mic_radius=mic_radius,
160+
energy_thres=energy_thres,
161+
time_thres=time_thres,
162+
hist_bin_size=hist_bin_size,
163+
)
164+
165+
assert hist.ndim == 3
166+
assert hist.shape == hist_pra.shape
167+
self.assertEqual(hist, hist_pra)

test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,42 @@ def test_simulate_rir_ism_multi_band(self, channel):
112112
F.simulate_rir_ism,
113113
(room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0),
114114
)
115+
116+
@parameterized.expand(
117+
[
118+
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic
119+
]
120+
)
121+
def test_ray_tracing(self, room_dim, source, mic_array, num_rays):
122+
num_walls = 4 if len(room_dim) == 2 else 6
123+
num_bands = 3
124+
125+
absorption = torch.rand(num_bands, num_walls, dtype=torch.float32)
126+
scattering = torch.rand(num_bands, num_walls, dtype=torch.float32)
127+
128+
energy_thres = 1e-7
129+
time_thres = 10.0
130+
hist_bin_size = 0.004
131+
mic_radius = 0.5
132+
sound_speed = 343.0
133+
134+
room_dim = torch.tensor(room_dim, dtype=self.dtype)
135+
source = torch.tensor(source, dtype=self.dtype)
136+
mic_array = torch.tensor(mic_array, dtype=self.dtype)
137+
138+
self._assert_consistency(
139+
F.ray_tracing,
140+
(
141+
room_dim,
142+
source,
143+
mic_array,
144+
num_rays,
145+
absorption,
146+
scattering,
147+
mic_radius,
148+
sound_speed,
149+
energy_thres,
150+
time_thres,
151+
hist_bin_size,
152+
),
153+
)

0 commit comments

Comments
 (0)