Skip to content

Commit 99f16c6

Browse files
committed
WIP: Add Ray Tracing (#3604)
Summary: Revamped version of #3234 (which was also revamp of #2850) Differential Revision: D49197174 Pulled By: mthrok
1 parent d9942ba commit 99f16c6

File tree

9 files changed

+562
-17
lines changed

9 files changed

+562
-17
lines changed

docs/source/prototype.functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ Room Impulse Response Simulation
3535
:toctree: generated
3636
:nosignatures:
3737

38+
ray_tracing
3839
simulate_rir_ism

test/cpp/rir/wall_collision.cpp

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@
33

44
using namespace torchaudio::rir;
55

6+
using DTYPE = double;
7+
68
struct CollisionTestParam {
79
// Input
810
torch::Tensor origin;
911
torch::Tensor direction;
1012
// Expected
1113
torch::Tensor hit_point;
1214
int next_wall_index;
13-
float hit_distance;
15+
DTYPE hit_distance;
1416
};
1517

1618
CollisionTestParam par(
17-
torch::ArrayRef<float> origin,
18-
torch::ArrayRef<float> direction,
19-
torch::ArrayRef<float> hit_point,
19+
torch::ArrayRef<DTYPE> origin,
20+
torch::ArrayRef<DTYPE> direction,
21+
torch::ArrayRef<DTYPE> hit_point,
2022
int next_wall_index,
21-
float hit_distance) {
23+
DTYPE hit_distance) {
2224
auto dir = torch::tensor(direction);
2325
return {
2426
torch::tensor(origin),
@@ -50,7 +52,7 @@ TEST_P(Simple3DRoomCollisionTest, CollisionTest3D) {
5052

5153
auto param = GetParam();
5254
auto [hit_point, next_wall_index, hit_distance] =
53-
find_collision_wall<float>(room, param.origin, param.direction);
55+
find_collision_wall<DTYPE>(room, param.origin, param.direction);
5456

5557
EXPECT_EQ(param.next_wall_index, next_wall_index);
5658
EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
@@ -100,3 +102,41 @@ INSTANTIATE_TEST_CASE_P(
100102
par({.5, .5, 1}, {0.0, -1., -1.}, {.5, .0, .5}, 2, ISQRT2),
101103
par({.5, .5, 1}, {0.0, 1.0, -1.}, {.5, 1., .5}, 3, ISQRT2),
102104
par({.5, .5, 1}, {0.0, 0.0, -1.}, {.5, .5, .0}, 4, 1.0)));
105+
106+
107+
INSTANTIATE_TEST_CASE_P(
108+
EdgeCollisionTest,
109+
Simple3DRoomCollisionTest,
110+
::testing::Values(
111+
par({1, 1, 0}, {1., 1., 0.}, {1., 1., 0.}, 1, 0.0),
112+
par({1, 1, 0}, {-1., 1., 0.}, {1., 1., 0.}, 3, 0.0),
113+
//
114+
par({1, 1, 1}, {1., 1., 1.}, {1., 1., 1.}, 1, 0.0),
115+
par({1, 1, 1}, {-1., 1., 1.}, {1., 1., 1.}, 3, 0.0),
116+
par({1, 1, 1}, {-1., -1., 1.}, {1., 1., 1.}, 5, 0.0)
117+
));
118+
119+
class Simple3DRoomCollisionTest2
120+
: public ::testing::TestWithParam<CollisionTestParam> {};
121+
122+
TEST_P(Simple3DRoomCollisionTest2, CollisionTest3D) {
123+
auto room = torch::tensor({3, 4, 5});
124+
125+
auto param = GetParam();
126+
auto [hit_point, next_wall_index, hit_distance] =
127+
find_collision_wall<DTYPE>(room, param.origin, param.direction);
128+
129+
EXPECT_EQ(param.next_wall_index, next_wall_index);
130+
EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
131+
EXPECT_NEAR(param.hit_point[0].item<DTYPE>(), hit_point[0].item<DTYPE>(), 1e-5);
132+
EXPECT_NEAR(param.hit_point[1].item<DTYPE>(), hit_point[1].item<DTYPE>(), 1e-5);
133+
EXPECT_NEAR(param.hit_point[2].item<DTYPE>(), hit_point[2].item<DTYPE>(), 1e-5);
134+
}
135+
136+
137+
INSTANTIATE_TEST_CASE_P(
138+
EdgeCollisionTest2,
139+
Simple3DRoomCollisionTest2,
140+
::testing::Values(
141+
par({3., 4., 4.6542}, {-0.9798, 0.1733, 0.1000}, {3., 4., 0.}, 3, 0.0)
142+
));

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,269 @@ def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_paramet
412412

413413
self.assertEqual(torch_out, torch.tensor(np_out))
414414

415+
@parameterized.expand(
416+
[
417+
# both float
418+
(0.1, 0.2, (2, 1, 2500)),
419+
# Per-wall
420+
((6,), 0.2, (2, 1, 2500)),
421+
(0.1, (6,), (2, 1, 2500)),
422+
((6,), (6,), (2, 1, 2500)),
423+
# Per-band and per-wall
424+
((3, 6), 0.2, (2, 3, 2500)),
425+
(0.1, (5, 6), (2, 5, 2500)),
426+
((7, 6), (7, 6), (2, 7, 2500)),
427+
]
428+
)
429+
def test_ray_tracing_output_shape(self, abs_, scat_, expected_shape):
430+
if isinstance(abs_, float):
431+
absorption = abs_
432+
else:
433+
absorption = torch.rand(abs_, dtype=self.dtype)
434+
if isinstance(scat_, float):
435+
scattering = scat_
436+
else:
437+
scattering = torch.rand(scat_, dtype=self.dtype)
438+
439+
room_dim = torch.tensor([3, 4, 5], dtype=self.dtype)
440+
mic_array = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=self.dtype)
441+
source = torch.tensor([1, 2, 3], dtype=self.dtype)
442+
num_rays = 100
443+
444+
hist = F.ray_tracing(
445+
room=room_dim,
446+
source=source,
447+
mic_array=mic_array,
448+
num_rays=num_rays,
449+
absorption=absorption,
450+
scattering=scattering,
451+
)
452+
assert hist.shape == expected_shape
453+
454+
def test_ray_tracing_input_errors(self):
455+
room = torch.tensor([3.0, 4.0, 5.0], dtype=self.dtype)
456+
source = torch.tensor([0.0, 0.0, 0.0], dtype=self.dtype)
457+
mic = torch.tensor([[1.0, 2.0, 3.0]], dtype=self.dtype)
458+
459+
# baseline. This should not raise
460+
_ = F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10)
461+
462+
# invlaid room shape
463+
for invalid in ([[4, 5]], [4, 5, 4, 5]):
464+
invalid = torch.tensor(invalid, dtype=self.dtype)
465+
with self.assertRaises(ValueError) as cm:
466+
F.ray_tracing(room=invalid, source=source, mic_array=mic, num_rays=10)
467+
468+
error = str(cm.exception)
469+
self.assertIn("`room` must be a 1D Tensor with 3 elements.", error)
470+
self.assertIn(str(invalid.shape), error)
471+
472+
# invalid microphone shape
473+
invalid = torch.tensor([[[3, 4]]], dtype=self.dtype)
474+
with self.assertRaises(ValueError) as cm:
475+
F.ray_tracing(room=room, source=source, mic_array=invalid, num_rays=10)
476+
477+
error = str(cm.exception)
478+
self.assertIn("`mic_array` must be a 2D Tensor with shape (num_channels, 3).", error)
479+
self.assertIn(str(invalid.shape), error)
480+
481+
# incompatible dtypes
482+
with self.assertRaises(ValueError) as cm:
483+
F.ray_tracing(
484+
room=room.to(torch.float64),
485+
source=source.to(torch.float32),
486+
mic_array=mic.to(torch.float32),
487+
num_rays=10,
488+
)
489+
error = str(cm.exception)
490+
self.assertIn("dtype of `room`, `source` and `mic_array` must match.", error)
491+
self.assertIn("`room` (torch.float64)", error)
492+
self.assertIn("`source` (torch.float32)", error)
493+
self.assertIn("`mic_array` (torch.float32)", error)
494+
495+
# invalid time configuration
496+
with self.assertRaises(ValueError) as cm:
497+
F.ray_tracing(
498+
room=room,
499+
source=source,
500+
mic_array=mic,
501+
num_rays=10,
502+
time_thres=10,
503+
hist_bin_size=11,
504+
)
505+
error = str(cm.exception)
506+
self.assertIn("`time_thres` must be greater than `hist_bin_size`.", error)
507+
self.assertIn("hist_bin_size=11", error)
508+
self.assertIn("time_thres=10", error)
509+
510+
# invalid absorption shape 1D
511+
invalid_abs = torch.tensor([1, 2, 3], dtype=self.dtype)
512+
with self.assertRaises(ValueError) as cm:
513+
F.ray_tracing(
514+
room=room,
515+
source=source,
516+
mic_array=mic,
517+
num_rays=10,
518+
absorption=invalid_abs,
519+
)
520+
error = str(cm.exception)
521+
self.assertIn("The shape of `absorption` must be (6,) when", error)
522+
self.assertIn(str(invalid_abs.shape), error)
523+
524+
# invalid absorption shape 2D
525+
invalid_abs = torch.tensor([[1, 2, 3]], dtype=self.dtype)
526+
with self.assertRaises(ValueError) as cm:
527+
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_abs)
528+
error = str(cm.exception)
529+
self.assertIn("The shape of `absorption` must be (NUM_BANDS, 6) when", error)
530+
self.assertIn(str(invalid_abs.shape), error)
531+
532+
# invalid scattering shape 1D
533+
invalid_scat = torch.tensor([1, 2, 3], dtype=self.dtype)
534+
with self.assertRaises(ValueError) as cm:
535+
F.ray_tracing(
536+
room=room,
537+
source=source,
538+
mic_array=mic,
539+
num_rays=10,
540+
scattering=invalid_scat,
541+
)
542+
error = str(cm.exception)
543+
self.assertIn("The shape of `scattering` must be (6,) when", error)
544+
self.assertIn(str(invalid_scat.shape), error)
545+
546+
# invalid scattering shape 2D
547+
invalid_scat = torch.tensor([[1, 2, 3]], dtype=self.dtype)
548+
with self.assertRaises(ValueError) as cm:
549+
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_scat)
550+
error = str(cm.exception)
551+
self.assertIn("The shape of `scattering` must be (NUM_BANDS, 6) when", error)
552+
self.assertIn(str(invalid_scat.shape), error)
553+
554+
# Invalid absorption value
555+
for invalid_val in [-1., torch.tensor([i - 1. for i in range(6)])]:
556+
with self.assertRaises(ValueError) as cm:
557+
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_val)
558+
559+
error = str(cm.exception)
560+
self.assertIn("`absorption` must be non-negative`")
561+
562+
# Invalid scattering value
563+
for invalid_val in [-1., torch.tensor([i - 1. for i in range(6)])]:
564+
with self.assertRaises(ValueError) as cm:
565+
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_val)
566+
567+
error = str(cm.exception)
568+
self.assertIn("`scattering` must be non-negative`")
569+
570+
# incompatible scattering and absorption
571+
abs_ = torch.zeros((7, 6), dtype=self.dtype)
572+
scat = torch.zeros((5, 6), dtype=self.dtype)
573+
with self.assertRaises(ValueError) as cm:
574+
F.ray_tracing(
575+
room=room,
576+
source=source,
577+
mic_array=mic,
578+
num_rays=10,
579+
absorption=abs_,
580+
scattering=scat,
581+
)
582+
error = str(cm.exception)
583+
self.assertIn(
584+
"`absorption` and `scattering` must be broadcastable to the same number of bands and walls", error
585+
)
586+
self.assertIn(f"absorption={abs_.shape}", error)
587+
self.assertIn(f"scattering={scat.shape}", error)
588+
589+
# Make sure passing different shapes for absorption or scattering doesn't raise an error
590+
# float and tensor
591+
F.ray_tracing(
592+
room=room,
593+
source=source,
594+
mic_array=mic,
595+
num_rays=10,
596+
absorption=0.1,
597+
scattering=torch.rand((5, 6), dtype=self.dtype),
598+
)
599+
F.ray_tracing(
600+
room=room,
601+
source=source,
602+
mic_array=mic,
603+
num_rays=10,
604+
absorption=torch.rand((7, 6), dtype=self.dtype),
605+
scattering=0.1,
606+
)
607+
# per-wall only and per-band + per-wall
608+
F.ray_tracing(
609+
room=room,
610+
source=source,
611+
mic_array=mic,
612+
num_rays=10,
613+
absorption=torch.rand(6, dtype=self.dtype),
614+
scattering=torch.rand(7, 6, dtype=self.dtype),
615+
)
616+
F.ray_tracing(
617+
room=room,
618+
source=source,
619+
mic_array=mic,
620+
num_rays=10,
621+
absorption=torch.rand(7, 6, dtype=self.dtype),
622+
scattering=torch.rand(6, dtype=self.dtype),
623+
)
624+
625+
def test_ray_tracing_per_band_per_wall_absorption(self):
626+
"""Check that when the value of absorption and scattering are the same
627+
across walls and frequency bands, the output histograms are:
628+
- all equal across frequency bands
629+
- equal to simply passing a float value instead of a (num_bands, D) or
630+
(D,) tensor.
631+
"""
632+
633+
room_dim = torch.tensor([20, 25, 5], dtype=self.dtype)
634+
mic_array = torch.tensor([[2, 2, 0], [8, 8, 0]], dtype=self.dtype)
635+
source = torch.tensor([7, 6, 0], dtype=self.dtype)
636+
num_rays = 1_000
637+
ABS, SCAT = 0.1, 0.2
638+
639+
absorption = torch.full(fill_value=ABS, size=(7, 6), dtype=self.dtype)
640+
scattering = torch.full(fill_value=SCAT, size=(7, 6), dtype=self.dtype)
641+
hist_per_band_per_wall = F.ray_tracing(
642+
room=room_dim,
643+
source=source,
644+
mic_array=mic_array,
645+
num_rays=num_rays,
646+
absorption=absorption,
647+
scattering=scattering,
648+
)
649+
absorption = torch.full(fill_value=ABS, size=(6,), dtype=self.dtype)
650+
scattering = torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype)
651+
hist_per_wall = F.ray_tracing(
652+
room=room_dim,
653+
source=source,
654+
mic_array=mic_array,
655+
num_rays=num_rays,
656+
absorption=absorption,
657+
scattering=scattering,
658+
)
659+
660+
absorption = ABS
661+
scattering = SCAT
662+
hist_single = F.ray_tracing(
663+
room=room_dim,
664+
source=source,
665+
mic_array=mic_array,
666+
num_rays=num_rays,
667+
absorption=absorption,
668+
scattering=scattering,
669+
)
670+
self.assertEqual(hist_per_band_per_wall.shape, (2, 7, 2500))
671+
self.assertEqual(hist_per_wall.shape, (2, 1, 2500))
672+
self.assertEqual(hist_single.shape, (2, 1, 2500))
673+
torch.testing.assert_close(hist_single, hist_per_wall)
674+
675+
hist_single = hist_single.expand(hist_per_band_per_wall.shape)
676+
torch.testing.assert_close(hist_single, hist_per_band_per_wall)
677+
415678

416679
class Functional64OnlyTestImpl(TestBaseMixin):
417680
@nested_params(

0 commit comments

Comments
 (0)