Skip to content

Commit ff266b1

Browse files
authored
Add the rest of C++ ray tracing implementation (#3630)
Taken from #3234
1 parent b7791ea commit ff266b1

File tree

3 files changed

+377
-3
lines changed

3 files changed

+377
-3
lines changed

torchaudio/csrc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ if(BUILD_RNNT)
4242
endif()
4343

4444
if(BUILD_RIR)
45-
list(APPEND sources rir/rir.cpp)
45+
list(APPEND sources rir/rir.cpp rir/ray_tracing.cpp)
4646
list(APPEND compile_definitions INCLUDE_RIR)
4747
endif()
4848

torchaudio/csrc/rir/ray_tracing.cpp

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
/*
2+
Copyright (c) 2014-2017 EPFL-LCAV
3+
4+
Permission is hereby granted, free of charge, to any person obtaining a copy
5+
of this software and associated documentation files (the "Software"), to deal
6+
in the Software without restriction, including without limitation the rights
7+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
copies of the Software, and to permit persons to whom the Software is
9+
furnished to do so, subject to the following conditions:
10+
11+
The above copyright notice and this permission notice shall be included in all
12+
copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
SOFTWARE.
21+
*/
22+
23+
//
24+
// Ray tracing implementation. This is heavily based on PyRoomAcoustics:
25+
// https://github.com/LCAV/pyroomacoustics
26+
//
27+
#include <torch/script.h>
28+
#include <torch/torch.h>
29+
#include <torchaudio/csrc/rir/wall.h>
30+
#include <cmath>
31+
32+
namespace torchaudio {
33+
namespace rir {
34+
namespace {
35+
36+
// TODO: remove this once hybrid method is supported
37+
const bool IS_HYBRID_SIM = false;
38+
39+
// TODO: remove this once ISM method is supported
40+
const int ISM_ORDER = 10;
41+
42+
#define EPS ((scalar_t)(1e-5))
43+
#define VAL(x) ((x).template item<scalar_t>())
44+
#define NORM(x) (VAL((x).norm()))
45+
#define MAX(x) (VAL((x).max()))
46+
#define IN_RANGE(x, y) ((-EPS < (x)) && ((x) < (y) + EPS))
47+
48+
template <typename scalar_t, unsigned int D>
49+
const std::array<Wall<scalar_t>, D * 2> make_walls(
50+
const torch::Tensor& room,
51+
const torch::Tensor& absorption,
52+
const torch::Tensor& scattering) {
53+
if constexpr (D == 2) {
54+
auto w = room.index({0}).item<scalar_t>();
55+
auto l = room.index({1}).item<scalar_t>();
56+
return make_room<scalar_t>(w, l, absorption, scattering);
57+
}
58+
if constexpr (D == 3) {
59+
auto w = room.index({0}).item<scalar_t>();
60+
auto l = room.index({1}).item<scalar_t>();
61+
auto h = room.index({2}).item<scalar_t>();
62+
return make_room<scalar_t>(w, l, h, absorption, scattering);
63+
}
64+
}
65+
66+
inline double get_energy_coeff(
67+
const double travel_dist,
68+
const double mic_radius_sq) {
69+
double sq = travel_dist * travel_dist;
70+
auto p_hit = 1. - std::sqrt(1. - mic_radius_sq / std::max(mic_radius_sq, sq));
71+
return sq * p_hit;
72+
}
73+
74+
/// RayTracer class helper for ray tracing.
75+
/// For attribute description, Python wrapper.
76+
template <typename scalar_t, unsigned int D>
77+
class RayTracer {
78+
// Provided parameters
79+
const torch::Tensor& room;
80+
const torch::Tensor& mic_array;
81+
const double mic_radius;
82+
83+
// Values derived from the parameters
84+
const int num_bands;
85+
const double mic_radius_sq;
86+
const bool do_scattering; // Whether scattering is needed (scattering != 0)
87+
const std::array<Wall<scalar_t>, D * 2> walls; // The walls of the room
88+
89+
// Runtime value caches
90+
// Updated at the beginning of the simulation
91+
double sound_speed = 343.0;
92+
double distance_thres = 10.0 * sound_speed; // upper bound
93+
double energy_thres = 0.0; // lower bound
94+
double hist_bin_width = 0.004; // [second]
95+
96+
public:
97+
RayTracer(
98+
const torch::Tensor& room,
99+
const torch::Tensor& absorption,
100+
const torch::Tensor& scattering,
101+
const torch::Tensor& mic_array,
102+
const double mic_radius)
103+
: room(room),
104+
mic_array(mic_array),
105+
mic_radius(mic_radius),
106+
num_bands(absorption.size(0)),
107+
mic_radius_sq(mic_radius * mic_radius),
108+
do_scattering(MAX(scattering) > 0.),
109+
walls(make_walls<scalar_t, D>(room, absorption, scattering)) {}
110+
111+
// The main (and only) public entry point of this class. The histograms Tensor
112+
// reference is passed along and modified in the subsequent private method
113+
// calls. This method spawns num_rays rays in all directions from the source
114+
// and calls simul_ray() on each of them.
115+
torch::Tensor compute_histograms(
116+
const torch::Tensor& origin,
117+
int num_rays,
118+
double time_thres,
119+
double energy_thres_ratio,
120+
double sound_speed_,
121+
int num_bins) {
122+
scalar_t energy_0 = 2. / num_rays;
123+
auto energies = torch::full({num_bands}, energy_0, room.options());
124+
125+
auto histograms =
126+
torch::zeros({mic_array.size(0), num_bins, num_bands}, room.options());
127+
128+
// Cache runtime parameters
129+
sound_speed = sound_speed_;
130+
energy_thres = energy_0 * energy_thres_ratio;
131+
distance_thres = time_thres * sound_speed;
132+
hist_bin_width = time_thres / num_bins;
133+
134+
// TODO: the for loop can be parallelized over num_rays by creating
135+
// `num_threads` histograms and then sum-reducing them into a single
136+
// histogram.
137+
static_assert(D == 2 || D == 3, "Only 2D and 3D are supported.");
138+
if constexpr (D == 2) {
139+
scalar_t delta = 2. * M_PI / num_rays;
140+
for (int i = 0; i < num_rays; ++i) {
141+
scalar_t phi = i * delta;
142+
auto dir = torch::tensor({cos(phi), sin(phi)}, room.scalar_type());
143+
simul_ray(energies, origin, dir, histograms);
144+
}
145+
} else {
146+
scalar_t delta = 2. / num_rays;
147+
scalar_t increment = M_PI * (3. - std::sqrt(5.)); // phi increment
148+
149+
for (auto i = 0; i < num_rays; ++i) {
150+
auto z = (i * delta - 1) + delta / 2.;
151+
auto rho = std::sqrt(1. - z * z);
152+
153+
scalar_t phi = i * increment;
154+
155+
auto x = cos(phi) * rho;
156+
auto y = sin(phi) * rho;
157+
158+
auto azimuth = atan2(y, x);
159+
auto colatitude = atan2(std::sqrt(x * x + y * y), z);
160+
161+
auto dir = torch::tensor(
162+
{sin(colatitude) * cos(azimuth),
163+
sin(colatitude) * sin(azimuth),
164+
cos(colatitude)},
165+
room.scalar_type());
166+
167+
simul_ray(energies, origin, dir, histograms);
168+
}
169+
}
170+
return histograms.transpose(1, 2); // (num_mics, num_bands, num_bins)
171+
}
172+
173+
private:
174+
/// Get the bin index from the distance traveled to a mic.
175+
inline int get_bin_idx(scalar_t travel_dist_at_mic) {
176+
auto time_at_mic = travel_dist_at_mic / sound_speed;
177+
return (int)floor(time_at_mic / hist_bin_width);
178+
}
179+
180+
///
181+
/// Traces a single ray. phi (horizontal) and theta (vectorical) are the
182+
/// angles of the ray from the source. Theta is 0 for 2D rooms. When a ray
183+
/// intersects a wall, it is reflected and part of its energy is absorbed. It
184+
/// is also scattered (sent directly to the microphone(s)) according to the
185+
/// scattering coefficient. When a ray is close to the microphone, its current
186+
/// energy is recoreded in the output histogram for that given time slot.
187+
///
188+
/// See also:
189+
/// https://github.com/LCAV/pyroomacoustics/blob/df8af24c88a87b5d51c6123087cd3cd2d361286a/pyroomacoustics/libroom_src/room.cpp#L855-L986
190+
void simul_ray(
191+
torch::Tensor& energies,
192+
torch::Tensor origin,
193+
torch::Tensor dir,
194+
torch::Tensor& histograms) {
195+
auto travel_dist = 0.;
196+
// To count the number of times the ray bounces on the walls
197+
// For hybrid generation we add a ray to output only if specular_counter
198+
// is higher than the ism order.
199+
int specular_counter = 0;
200+
while (true) {
201+
// Find the next hit point
202+
auto [hit_point, next_wall_index, hit_distance] =
203+
find_collision_wall<scalar_t, D>(room, origin, dir);
204+
205+
auto& wall = walls[next_wall_index];
206+
207+
// Check if the specular ray hits any of the microphone
208+
if (!(IS_HYBRID_SIM && specular_counter < ISM_ORDER)) {
209+
// Compute the distance between the line defined by (origin, hit_point)
210+
// and the center of the microphone (mic_pos)
211+
212+
for (auto mic_idx = 0; mic_idx < mic_array.size(0); mic_idx++) {
213+
//
214+
// _ o microphone
215+
// to_mic / | ^
216+
// / | wall
217+
// / | mic radious | |
218+
// origin / | | |
219+
// / v | |
220+
// x ---------------------------> |x| collision
221+
//
222+
// | <--------> |
223+
// impact_distance
224+
// | <--------------------------> |
225+
// hit_distance
226+
//
227+
torch::Tensor to_mic = mic_array[mic_idx] - origin;
228+
scalar_t impact_distance = VAL(to_mic.dot(dir));
229+
230+
// mic is further than the collision point.
231+
// So microphone did not pick up the sound.
232+
if (!IN_RANGE(impact_distance, hit_distance)) {
233+
continue;
234+
}
235+
236+
// If the ray hit the coverage of the mic, compute the energy
237+
if (NORM(to_mic - dir * impact_distance) < mic_radius + EPS) {
238+
// The length of this last hop
239+
auto travel_dist_at_mic = travel_dist + std::abs(impact_distance);
240+
auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
241+
auto energy = energies / coeff;
242+
histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
243+
}
244+
}
245+
}
246+
247+
travel_dist += hit_distance;
248+
energies *= wall.reflection;
249+
250+
// Let's shoot the scattered ray induced by the rebound on the wall
251+
if (do_scattering) {
252+
scat_ray(histograms, wall, energies, origin, hit_point, travel_dist);
253+
energies *= (1. - wall.scattering);
254+
}
255+
256+
// Check if we reach the thresholds for this ray
257+
if (travel_dist > distance_thres || VAL(energies.max()) < energy_thres) {
258+
break;
259+
}
260+
261+
// set up for next iteration
262+
specular_counter += 1;
263+
dir = reflect(wall, dir);
264+
origin = hit_point;
265+
}
266+
}
267+
268+
///
269+
/// Scatters a ray towards the microphone(s), i.e. records its scattered
270+
/// energy in the histogram. Called when a ray hits a wall.
271+
///
272+
/// See also:
273+
/// https://github.com/LCAV/pyroomacoustics/blob/df8af24c88a87b5d51c6123087cd3cd2d361286a/pyroomacoustics/libroom_src/room.cpp#L761-L853
274+
void scat_ray(
275+
torch::Tensor& histograms,
276+
const Wall<scalar_t>& wall,
277+
const torch::Tensor& energies,
278+
const torch::Tensor& prev_hit_point,
279+
const torch::Tensor& hit_point,
280+
scalar_t travel_dist) {
281+
for (auto mic_idx = 0; mic_idx < mic_array.size(0); mic_idx++) {
282+
auto mic_pos = mic_array[mic_idx];
283+
if (side(wall, mic_pos) != side(wall, prev_hit_point)) {
284+
continue;
285+
}
286+
287+
// As the ray is shot towards the microphone center,
288+
// the hop dist can be easily computed
289+
torch::Tensor hit_point_to_mic = mic_pos - hit_point;
290+
auto hop_dist = NORM(hit_point_to_mic);
291+
auto travel_dist_at_mic = travel_dist + hop_dist;
292+
293+
// compute the scattered energy reaching the microphone
294+
auto h_sq = hop_dist * hop_dist;
295+
auto p_hit_equal = 1. - std::sqrt(1. - mic_radius_sq / h_sq);
296+
// cosine angle should be positive, but could be negative if normal is
297+
// facing out of room so we take abs
298+
auto p_lambert = (scalar_t)2. * std::abs(cosine(wall, hit_point_to_mic));
299+
auto scat_trans = wall.scattering * energies * p_hit_equal * p_lambert;
300+
301+
if (travel_dist_at_mic < distance_thres &&
302+
MAX(scat_trans) > energy_thres) {
303+
auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
304+
auto energy = scat_trans / coeff;
305+
histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
306+
}
307+
}
308+
}
309+
};
310+
311+
///
312+
/// @brief Compute energy histogram via ray tracing. See Python wrapper for
313+
/// detail about parameters and output.
314+
///
315+
torch::Tensor ray_tracing(
316+
const torch::Tensor& room,
317+
const torch::Tensor& source,
318+
const torch::Tensor& mic_array,
319+
int64_t num_rays,
320+
const torch::Tensor& absorption,
321+
const torch::Tensor& scattering,
322+
double mic_radius,
323+
double sound_speed,
324+
double energy_thres,
325+
double time_thres, // TODO: rename to duration
326+
double hist_bin_size) {
327+
// TODO: Raise this to Python layer
328+
auto num_bins = (int)ceil(time_thres / hist_bin_size);
329+
switch (room.size(0)) {
330+
case 2: {
331+
return AT_DISPATCH_FLOATING_TYPES(
332+
room.scalar_type(), "ray_tracing_2d", [&] {
333+
RayTracer<scalar_t, 2> rt(
334+
room, mic_array, absorption, scattering, mic_radius);
335+
return rt.compute_histograms(
336+
source,
337+
num_rays,
338+
time_thres,
339+
energy_thres,
340+
sound_speed,
341+
num_bins);
342+
});
343+
}
344+
case 3: {
345+
return AT_DISPATCH_FLOATING_TYPES(
346+
room.scalar_type(), "ray_tracing_3d", [&] {
347+
RayTracer<scalar_t, 3> rt(
348+
room, mic_array, absorption, scattering, mic_radius);
349+
return rt.compute_histograms(
350+
source,
351+
num_rays,
352+
time_thres,
353+
energy_thres,
354+
sound_speed,
355+
num_bins);
356+
});
357+
}
358+
default:
359+
TORCH_CHECK(false, "Only 2D and 3D are supported.");
360+
}
361+
}
362+
363+
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
364+
m.impl("torchaudio::ray_tracing", torchaudio::rir::ray_tracing);
365+
}
366+
367+
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
368+
m.def(
369+
"torchaudio::ray_tracing(Tensor room, Tensor source, Tensor mic_array, int num_rays, Tensor absorption, Tensor scattering, float mic_radius, float sound_speed, float energy_thres, float time_thres, float hist_bin_size) -> Tensor");
370+
}
371+
372+
} // namespace
373+
} // namespace rir
374+
} // namespace torchaudio

0 commit comments

Comments
 (0)