Skip to content

Commit 20b92cd

Browse files
committed
WIP: Add Ray Tracing (#3604)
Summary: Revamped version of #3234 (which was also a revamp of #2850) Differential Revision: D49197174 Pulled By: mthrok
1 parent c89e7a5 commit 20b92cd

File tree

9 files changed

+550
-22
lines changed

9 files changed

+550
-22
lines changed

docs/source/prototype.functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ Room Impulse Response Simulation
3535
:toctree: generated
3636
:nosignatures:
3737

38+
ray_tracing
3839
simulate_rir_ism

src/torchaudio/csrc/rir/ray_tracing.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,17 +220,21 @@ class RayTracer {
220220
if (NORM(to_mic - dir * impact_distance) < mic_radius + EPS) {
221221
// The length of this last hop
222222
auto travel_dist_at_mic = travel_dist + std::abs(impact_distance);
223+
auto bin_idx = get_bin_idx(travel_dist_at_mic);
224+
if (bin_idx >= histograms.size(1)) {
225+
continue;
226+
}
223227
auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
224228
auto energy = energies / coeff;
225-
histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
229+
histograms[mic_idx][bin_idx] += energy;
226230
}
227231
}
228232
}
229233

230234
travel_dist += hit_distance;
231235
energies *= wall.reflection;
232236

233-
// Let's shoot the scattered ray induced by the rebound on the wall
237+
// Let's shoot the scattered ray induced by the rebound on the wall
234238
if (do_scattering) {
235239
scat_ray(histograms, wall, energies, origin, hit_point, travel_dist);
236240
energies *= (1. - wall.scattering);

src/torchaudio/csrc/rir/wall.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,15 @@ struct Wall {
1818
const torch::Tensor origin;
1919
const torch::Tensor normal;
2020
const torch::Tensor scattering;
21-
2221
const torch::Tensor reflection;
2322

2423
Wall(
2524
const torch::ArrayRef<scalar_t>& origin,
2625
const torch::ArrayRef<scalar_t>& normal,
2726
const torch::Tensor& absorption,
2827
const torch::Tensor& scattering)
29-
: origin(torch::tensor(origin)),
30-
normal(torch::tensor(normal)),
28+
: origin(torch::tensor(origin).to(scattering.dtype())),
29+
normal(torch::tensor(normal).to(scattering.dtype())),
3130
scattering(scattering),
3231
reflection(1. - absorption) {}
3332
};
@@ -136,7 +135,6 @@ std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
136135
for (unsigned int i = 0; i < 3; ++i) {
137136
auto dir0 = SCALAR(direction[i]);
138137
auto abs_dir0 = std::abs(dir0);
139-
140138
// If the ray is almost parallel to a plane, then we delegate the
141139
// computation to the other planes.
142140
if (abs_dir0 < EPS) {
@@ -147,6 +145,10 @@ std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
147145
scalar_t distance = (dir0 < 0.)
148146
? SCALAR(origin[i]) // Going towards origin
149147
: SCALAR(room[i] - origin[i]); // Going away from origin
148+
// sometimes origin is slightly outside of room
149+
if (distance < 0) {
150+
distance = 0.;
151+
}
150152
auto ratio = distance / abs_dir0;
151153
int i_increment = dir0 > 0.;
152154

src/torchaudio/prototype/functional/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
oscillator_bank,
88
sinc_impulse_response,
99
)
10-
from ._rir import simulate_rir_ism
10+
from ._rir import ray_tracing, simulate_rir_ism
1111
from .functional import barkscale_fbanks, chroma_filterbank
1212

1313

@@ -20,6 +20,7 @@
2020
"filter_waveform",
2121
"frequency_impulse_response",
2222
"oscillator_bank",
23+
"ray_tracing",
2324
"sinc_impulse_response",
2425
"simulate_rir_ism",
2526
]

src/torchaudio/prototype/functional/_rir.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,20 +133,24 @@ def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor
133133
"""
134134
num_walls = 6
135135
if isinstance(coeffs, float):
136+
if coeffs < 0:
137+
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
136138
return torch.full((1, num_walls), coeffs)
137139
if isinstance(coeffs, Tensor):
140+
if torch.any(coeffs < 0):
141+
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
138142
if coeffs.ndim == 1:
139143
if coeffs.numel() != num_walls:
140144
raise ValueError(
141-
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor."
145+
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. "
142146
f"Found the shape {coeffs.shape}."
143147
)
144148
return coeffs.unsqueeze(0)
145149
if coeffs.ndim == 2:
146-
if coeffs.shape != (7, num_walls):
150+
if coeffs.shape[1] != num_walls:
147151
raise ValueError(
148-
f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor."
149-
f"Found the shape {coeffs.shape}."
152+
f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it "
153+
f"is a 2D Tensor. Found: {coeffs.shape}."
150154
)
151155
return coeffs
152156
raise TypeError(f"`{name}` must be float or Tensor.")
@@ -169,7 +173,7 @@ def _validate_inputs(
169173
if not (source.ndim == 1 and source.numel() == 3):
170174
raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.")
171175
if not (mic_array.ndim == 2 and mic_array.shape[1] == 3):
172-
raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
176+
raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
173177

174178

175179
def simulate_rir_ism(
@@ -270,3 +274,106 @@ def simulate_rir_ism(
270274
rir = rir[..., :output_length]
271275

272276
return rir
277+
278+
279+
def ray_tracing(
280+
room: torch.Tensor,
281+
source: torch.Tensor,
282+
mic_array: torch.Tensor,
283+
num_rays: int,
284+
absorption: Union[float, torch.Tensor] = 0.0,
285+
scattering: Union[float, torch.Tensor] = 0.0,
286+
mic_radius: float = 0.5,
287+
sound_speed: float = 343.0,
288+
energy_thres: float = 1e-7,
289+
time_thres: float = 10.0,
290+
hist_bin_size: float = 0.004,
291+
) -> torch.Tensor:
292+
r"""Compute energy histogram via ray tracing.
293+
294+
The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.
295+
296+
``num_rays`` rays are cast uniformly in all directions from the source;
297+
when a ray intersects a wall, it is reflected and part of its energy is absorbed.
298+
It is also scattered (sent directly to the microphone(s)) according to the ``scattering``
299+
coefficient.
300+
When a ray is close to the microphone, its current energy is recorded in the output
301+
histogram for that given time slot.
302+
303+
.. devices:: CPU
304+
305+
.. properties:: TorchScript
306+
307+
Args:
308+
room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
309+
three dimensions of the room.
310+
source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
311+
mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
312+
absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
313+
(Default: ``0.0``).
314+
If the type is ``float``, the absorption coefficient is identical to all walls and
315+
all frequencies.
316+
If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption
317+
coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and
318+
``"ceiling"``, respectively.
319+
If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`.
320+
``num_bands`` is the number of frequency bands (usually 7).
321+
scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials. (Default: ``0.0``)
322+
The shape and type of this parameter is the same as for ``absorption``.
323+
mic_radius (float, optional): The radius of the microphone in meters. (Default: ``0.5``)
324+
sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
325+
energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``)
326+
The initial energy of each ray is ``2 / num_rays``.
327+
time_thres (float, optional): The maximal duration for which rays are traced. (Unit: seconds) (Default: 10.0)
328+
hist_bin_size (float, optional): The size of each bin in the output histogram. (Unit: seconds) (Default: 0.004)
329+
330+
Returns:
331+
(torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
332+
Each bin corresponds to a given time slot.
333+
The shape is `(channel, num_bands, num_bins)`, where
334+
``num_bins = ceil(time_thres / hist_bin_size)``.
335+
If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.
336+
"""
337+
if time_thres < hist_bin_size:
338+
raise ValueError(
339+
"`time_thres` must be greater than `hist_bin_size`. "
340+
f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
341+
)
342+
343+
if room.dtype != source.dtype or source.dtype != mic_array.dtype:
344+
raise ValueError(
345+
"dtype of `room`, `source` and `mic_array` must match. "
346+
f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
347+
f"`mic_array` ({mic_array.dtype})"
348+
)
349+
350+
_validate_inputs(room, source, mic_array)
351+
absorption = _adjust_coeff(absorption, "absorption").to(room.dtype)
352+
scattering = _adjust_coeff(scattering, "scattering").to(room.dtype)
353+
354+
# Bring absorption and scattering to the same shape
355+
if absorption.shape[0] == 1 and scattering.shape[0] > 1:
356+
absorption = absorption.expand(scattering.shape)
357+
if scattering.shape[0] == 1 and absorption.shape[0] > 1:
358+
scattering = scattering.expand(absorption.shape)
359+
if absorption.shape != scattering.shape:
360+
raise ValueError(
361+
"`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
362+
f"Inferred shapes absorption={absorption.shape} and scattering={scattering.shape}"
363+
)
364+
365+
histograms = torch.ops.torchaudio.ray_tracing(
366+
room,
367+
source,
368+
mic_array,
369+
num_rays,
370+
absorption,
371+
scattering,
372+
mic_radius,
373+
sound_speed,
374+
energy_thres,
375+
time_thres,
376+
hist_bin_size,
377+
)
378+
379+
return histograms

test/cpp/rir/wall_collision.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@
33

44
using namespace torchaudio::rir;
55

6+
using DTYPE = double;
7+
68
struct CollisionTestParam {
79
// Input
810
torch::Tensor origin;
911
torch::Tensor direction;
1012
// Expected
1113
torch::Tensor hit_point;
1214
int next_wall_index;
13-
float hit_distance;
15+
DTYPE hit_distance;
1416
};
1517

1618
CollisionTestParam par(
17-
torch::ArrayRef<float> origin,
18-
torch::ArrayRef<float> direction,
19-
torch::ArrayRef<float> hit_point,
19+
torch::ArrayRef<DTYPE> origin,
20+
torch::ArrayRef<DTYPE> direction,
21+
torch::ArrayRef<DTYPE> hit_point,
2022
int next_wall_index,
21-
float hit_distance) {
23+
DTYPE hit_distance) {
2224
auto dir = torch::tensor(direction);
2325
return {
2426
torch::tensor(origin),
@@ -50,18 +52,22 @@ TEST_P(Simple3DRoomCollisionTest, CollisionTest3D) {
5052

5153
auto param = GetParam();
5254
auto [hit_point, next_wall_index, hit_distance] =
53-
find_collision_wall<float>(room, param.origin, param.direction);
55+
find_collision_wall<DTYPE>(room, param.origin, param.direction);
5456

5557
EXPECT_EQ(param.next_wall_index, next_wall_index);
5658
EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
57-
EXPECT_TRUE(torch::allclose(
58-
param.hit_point, hit_point, /*rtol*/ 1e-05, /*atol*/ 1e-07));
59+
EXPECT_NEAR(
60+
param.hit_point[0].item<DTYPE>(), hit_point[0].item<DTYPE>(), 1e-5);
61+
EXPECT_NEAR(
62+
param.hit_point[1].item<DTYPE>(), hit_point[1].item<DTYPE>(), 1e-5);
63+
EXPECT_NEAR(
64+
param.hit_point[2].item<DTYPE>(), hit_point[2].item<DTYPE>(), 1e-5);
5965
}
6066

6167
#define ISQRT2 0.70710678118
6268

6369
INSTANTIATE_TEST_CASE_P(
64-
Collision3DTests,
70+
BasicCollisionTests,
6571
Simple3DRoomCollisionTest,
6672
::testing::Values(
6773
// From 0
@@ -100,3 +106,13 @@ INSTANTIATE_TEST_CASE_P(
100106
par({.5, .5, 1}, {0.0, -1., -1.}, {.5, .0, .5}, 2, ISQRT2),
101107
par({.5, .5, 1}, {0.0, 1.0, -1.}, {.5, 1., .5}, 3, ISQRT2),
102108
par({.5, .5, 1}, {0.0, 0.0, -1.}, {.5, .5, .0}, 4, 1.0)));
109+
110+
INSTANTIATE_TEST_CASE_P(
111+
CornerCollisionTest,
112+
Simple3DRoomCollisionTest,
113+
::testing::Values(
114+
par({1, 1, 0}, {1., 1., 0.}, {1., 1., 0.}, 1, 0.0),
115+
par({1, 1, 0}, {-1., 1., 0.}, {1., 1., 0.}, 3, 0.0),
116+
par({1, 1, 1}, {1., 1., 1.}, {1., 1., 1.}, 1, 0.0),
117+
par({1, 1, 1}, {-1., 1., 1.}, {1., 1., 1.}, 3, 0.0),
118+
par({1, 1, 1}, {-1., -1., 1.}, {1., 1., 1.}, 5, 0.0)));

0 commit comments

Comments
 (0)