Skip to content

Commit 42c8292

Browse files
committed
Add the rest of C++ ray tracing implementation
Taken from pytorch#3234
1 parent b7791ea commit 42c8292

File tree

3 files changed

+379
-3
lines changed

3 files changed

+379
-3
lines changed

torchaudio/csrc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ if(BUILD_RNNT)
4242
endif()
4343

4444
if(BUILD_RIR)
45-
list(APPEND sources rir/rir.cpp)
45+
list(APPEND sources rir/rir.cpp rir/ray_tracing.cpp)
4646
list(APPEND compile_definitions INCLUDE_RIR)
4747
endif()
4848

torchaudio/csrc/rir/ray_tracing.cpp

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
/*
2+
Copyright (c) 2014-2017 EPFL-LCAV
3+
4+
Permission is hereby granted, free of charge, to any person obtaining a copy
5+
of this software and associated documentation files (the "Software"), to deal
6+
in the Software without restriction, including without limitation the rights
7+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
copies of the Software, and to permit persons to whom the Software is
9+
furnished to do so, subject to the following conditions:
10+
11+
The above copyright notice and this permission notice shall be included in all
12+
copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
SOFTWARE.
21+
*/
22+
23+
/**
24+
* Ray tracing implementation. This is heavily based on PyRoomAcoustics:
25+
* https://github.com/LCAV/pyroomacoustics
26+
*/
27+
#include <torch/script.h>
28+
#include <torch/torch.h>
29+
#include <torchaudio/csrc/rir/wall.h>
30+
#include <cmath>
31+
32+
namespace torchaudio {
33+
namespace rir {
34+
namespace {
35+
36+
// TODO: remove this once hybrid method is supported
37+
const bool IS_HYBRID_SIM = false;
38+
39+
// TODO: remove this once ISM method is supported
40+
const int ISM_ORDER = 10;
41+
42+
#define EPS ((scalar_t)(1e-5))
43+
#define VAL(x) ((x).template item<scalar_t>())
44+
#define NORM(x) (VAL((x).norm()))
45+
#define MAX(x) (VAL((x).max()))
46+
#define IN_RANGE(x, y) ((-EPS < (x)) && ((x) < (y) + EPS))
47+
48+
template <typename scalar_t, unsigned int D>
49+
const std::array<Wall<scalar_t>, D * 2> make_walls(
50+
const torch::Tensor& room,
51+
const torch::Tensor& absorption,
52+
const torch::Tensor& scattering) {
53+
if constexpr (D == 2) {
54+
auto w = room.index({0}).item<scalar_t>();
55+
auto l = room.index({1}).item<scalar_t>();
56+
return make_room<scalar_t>(w, l, absorption, scattering);
57+
}
58+
if constexpr (D == 3) {
59+
auto w = room.index({0}).item<scalar_t>();
60+
auto l = room.index({1}).item<scalar_t>();
61+
auto h = room.index({2}).item<scalar_t>();
62+
return make_room<scalar_t>(w, l, h, absorption, scattering);
63+
}
64+
}
65+
66+
inline double get_energy_coeff(
67+
const double travel_dist,
68+
const double mic_radius_sq) {
69+
double sq = travel_dist * travel_dist;
70+
auto p_hit = 1. - std::sqrt(1. - mic_radius_sq / std::max(mic_radius_sq, sq));
71+
return sq * p_hit;
72+
}
73+
74+
/// RayTracer class helper for ray tracing.
75+
/// For attribute description, Python wrapper.
76+
template <typename scalar_t, unsigned int D>
77+
class RayTracer {
78+
// Provided parameters
79+
const torch::Tensor& room;
80+
const torch::Tensor& mic_array;
81+
const double mic_radius;
82+
83+
// Values derived from the parameters
84+
const int num_bands;
85+
const double mic_radius_sq;
86+
const bool do_scattering; // Whether scattering is needed (scattering != 0)
87+
const std::array<Wall<scalar_t>, D * 2> walls; // The walls of the room
88+
89+
// Runtime value caches
90+
// Updated at the beginning of the simulation
91+
double sound_speed = 343.0;
92+
double distance_thres = 10.0 * sound_speed; // upper bound
93+
double energy_thres = 0.0; // lower bound
94+
double hist_bin_width = 0.004; // [second]
95+
96+
public:
97+
RayTracer(
98+
const torch::Tensor& room,
99+
const torch::Tensor& absorption,
100+
const torch::Tensor& scattering,
101+
const torch::Tensor& mic_array,
102+
const double mic_radius)
103+
: room(room),
104+
mic_array(mic_array),
105+
mic_radius(mic_radius),
106+
num_bands(absorption.size(0)),
107+
mic_radius_sq(mic_radius * mic_radius),
108+
do_scattering(MAX(scattering) > 0.),
109+
walls(make_walls<scalar_t, D>(room, absorption, scattering)) {}
110+
111+
/**
112+
* The main (and only) public entry point of this class. The histograms Tensor
113+
* reference is passed along and modified in the subsequent private method
114+
* calls. This method spawns num_rays rays in all directions from the source
115+
* and calls simul_ray() on each of them.
116+
*/
117+
torch::Tensor compute_histograms(
118+
const torch::Tensor& origin,
119+
int num_rays,
120+
double time_thres,
121+
double energy_thres_ratio,
122+
double sound_speed_,
123+
int num_bins) {
124+
scalar_t energy_0 = 2. / num_rays;
125+
auto energies = torch::full({num_bands}, energy_0, room.options());
126+
127+
auto histograms =
128+
torch::zeros({mic_array.size(0), num_bins, num_bands}, room.options());
129+
130+
// Cache runtime parameters
131+
sound_speed = sound_speed_;
132+
energy_thres = energy_0 * energy_thres_ratio;
133+
distance_thres = time_thres * sound_speed;
134+
hist_bin_width = time_thres / num_bins;
135+
136+
// TODO: the for loop can be parallelized over num_rays by creating
137+
// `num_threads` histograms and then sum-reducing them into a single
138+
// histogram.
139+
static_assert(D == 2 || D == 3, "Only 2D and 3D are supported.");
140+
if constexpr (D == 2) {
141+
scalar_t delta = 2. * M_PI / num_rays;
142+
for (int i = 0; i < num_rays; ++i) {
143+
scalar_t phi = i * delta;
144+
auto dir = torch::tensor({cos(phi), sin(phi)}, room.scalar_type());
145+
simul_ray(energies, origin, dir, histograms);
146+
}
147+
} else {
148+
scalar_t delta = 2. / num_rays;
149+
scalar_t increment = M_PI * (3. - std::sqrt(5.)); // phi increment
150+
151+
for (auto i = 0; i < num_rays; ++i) {
152+
auto z = (i * delta - 1) + delta / 2.;
153+
auto rho = std::sqrt(1. - z * z);
154+
155+
scalar_t phi = i * increment;
156+
157+
auto x = cos(phi) * rho;
158+
auto y = sin(phi) * rho;
159+
160+
auto azimuth = atan2(y, x);
161+
auto colatitude = atan2(std::sqrt(x * x + y * y), z);
162+
163+
auto dir = torch::tensor(
164+
{sin(colatitude) * cos(azimuth),
165+
sin(colatitude) * sin(azimuth),
166+
cos(colatitude)},
167+
room.scalar_type());
168+
169+
simul_ray(energies, origin, dir, histograms);
170+
}
171+
}
172+
return histograms.transpose(1, 2); // (num_mics, num_bands, num_bins)
173+
}
174+
175+
private:
176+
/// Get the bin index from the distance traveled to a mic.
177+
inline int get_bin_idx(scalar_t travel_dist_at_mic) {
178+
auto time_at_mic = travel_dist_at_mic / sound_speed;
179+
return (int)floor(time_at_mic / hist_bin_width);
180+
}
181+
182+
///
183+
/// Traces a single ray. phi (horizontal) and theta (vectorical) are the
184+
/// angles of the ray from the source. Theta is 0 for 2D rooms. When a ray
185+
/// intersects a wall, it is reflected and part of its energy is absorbed. It
186+
/// is also scattered (sent directly to the microphone(s)) according to the
187+
/// scattering coefficient. When a ray is close to the microphone, its current
188+
/// energy is recoreded in the output histogram for that given time slot.
189+
///
190+
/// See also:
191+
/// https://github.com/LCAV/pyroomacoustics/blob/df8af24c88a87b5d51c6123087cd3cd2d361286a/pyroomacoustics/libroom_src/room.cpp#L855-L986
192+
void simul_ray(
193+
torch::Tensor& energies,
194+
torch::Tensor origin,
195+
torch::Tensor dir,
196+
torch::Tensor& histograms) {
197+
auto travel_dist = 0.;
198+
// To count the number of times the ray bounces on the walls
199+
// For hybrid generation we add a ray to output only if specular_counter
200+
// is higher than the ism order.
201+
int specular_counter = 0;
202+
while (true) {
203+
// Find the next hit point
204+
auto [hit_point, next_wall_index, hit_distance] =
205+
find_collision_wall<scalar_t, D>(room, origin, dir);
206+
207+
auto& wall = walls[next_wall_index];
208+
209+
// Check if the specular ray hits any of the microphone
210+
if (!(IS_HYBRID_SIM && specular_counter < ISM_ORDER)) {
211+
// Compute the distance between the line defined by (origin, hit_point)
212+
// and the center of the microphone (mic_pos)
213+
214+
for (auto mic_idx = 0; mic_idx < mic_array.size(0); mic_idx++) {
215+
//
216+
// _ o microphone
217+
// to_mic / | ^
218+
// / | wall
219+
// / | mic radious | |
220+
// origin / | | |
221+
// / v | |
222+
// x ---------------------------> |x| collision
223+
//
224+
// | <--------> |
225+
// impact_distance
226+
// | <--------------------------> |
227+
// hit_distance
228+
//
229+
torch::Tensor to_mic = mic_array[mic_idx] - origin;
230+
scalar_t impact_distance = VAL(to_mic.dot(dir));
231+
232+
// mic is further than the collision point.
233+
// So microphone did not pick up the sound.
234+
if (!IN_RANGE(impact_distance, hit_distance)) {
235+
continue;
236+
}
237+
238+
// If the ray hit the coverage of the mic, compute the energy
239+
if (NORM(to_mic - dir * impact_distance) < mic_radius + EPS) {
240+
// The length of this last hop
241+
auto travel_dist_at_mic = travel_dist + std::abs(impact_distance);
242+
auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
243+
auto energy = energies / coeff;
244+
histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
245+
}
246+
}
247+
}
248+
249+
travel_dist += hit_distance;
250+
energies *= wall.reflection;
251+
252+
// Let's shoot the scattered ray induced by the rebound on the wall
253+
if (do_scattering) {
254+
scat_ray(histograms, wall, energies, origin, hit_point, travel_dist);
255+
energies *= (1. - wall.scattering);
256+
}
257+
258+
// Check if we reach the thresholds for this ray
259+
if (travel_dist > distance_thres || VAL(energies.max()) < energy_thres) {
260+
break;
261+
}
262+
263+
// set up for next iteration
264+
specular_counter += 1;
265+
dir = reflect(wall, dir);
266+
origin = hit_point;
267+
}
268+
}
269+
270+
///
271+
/// Scatters a ray towards the microphone(s), i.e. records its scattered
272+
/// energy in the histogram. Called when a ray hits a wall.
273+
///
274+
/// See also:
275+
/// https://github.com/LCAV/pyroomacoustics/blob/df8af24c88a87b5d51c6123087cd3cd2d361286a/pyroomacoustics/libroom_src/room.cpp#L761-L853
276+
void scat_ray(
277+
torch::Tensor& histograms,
278+
const Wall<scalar_t>& wall,
279+
const torch::Tensor& energies,
280+
const torch::Tensor& prev_hit_point,
281+
const torch::Tensor& hit_point,
282+
scalar_t travel_dist) {
283+
for (auto mic_idx = 0; mic_idx < mic_array.size(0); mic_idx++) {
284+
auto mic_pos = mic_array[mic_idx];
285+
if (side(wall, mic_pos) != side(wall, prev_hit_point)) {
286+
continue;
287+
}
288+
289+
// As the ray is shot towards the microphone center,
290+
// the hop dist can be easily computed
291+
torch::Tensor hit_point_to_mic = mic_pos - hit_point;
292+
auto hop_dist = NORM(hit_point_to_mic);
293+
auto travel_dist_at_mic = travel_dist + hop_dist;
294+
295+
// compute the scattered energy reaching the microphone
296+
auto h_sq = hop_dist * hop_dist;
297+
auto p_hit_equal = 1. - std::sqrt(1. - mic_radius_sq / h_sq);
298+
// cosine angle should be positive, but could be negative if normal is
299+
// facing out of room so we take abs
300+
auto p_lambert = (scalar_t)2. * std::abs(cosine(wall, hit_point_to_mic));
301+
auto scat_trans = wall.scattering * energies * p_hit_equal * p_lambert;
302+
303+
if (travel_dist_at_mic < distance_thres &&
304+
MAX(scat_trans) > energy_thres) {
305+
auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
306+
auto energy = scat_trans / coeff;
307+
histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
308+
}
309+
}
310+
}
311+
};
312+
313+
/**
314+
* @brief Compute energy histogram via ray tracing. See Python wrapper for
315+
* detail about parameters and output.
316+
*/
317+
torch::Tensor ray_tracing(
318+
const torch::Tensor& room,
319+
const torch::Tensor& source,
320+
const torch::Tensor& mic_array,
321+
int64_t num_rays,
322+
const torch::Tensor& absorption,
323+
const torch::Tensor& scattering,
324+
double mic_radius,
325+
double sound_speed,
326+
double energy_thres,
327+
double time_thres, // TODO: rename to duration
328+
double hist_bin_size) {
329+
// TODO: Raise this to Python layer
330+
auto num_bins = (int)ceil(time_thres / hist_bin_size);
331+
switch (room.size(0)) {
332+
case 2: {
333+
return AT_DISPATCH_FLOATING_TYPES(
334+
room.scalar_type(), "ray_tracing_2d", [&] {
335+
RayTracer<scalar_t, 2> rt(
336+
room, mic_array, absorption, scattering, mic_radius);
337+
return rt.compute_histograms(
338+
source,
339+
num_rays,
340+
time_thres,
341+
energy_thres,
342+
sound_speed,
343+
num_bins);
344+
});
345+
}
346+
case 3: {
347+
return AT_DISPATCH_FLOATING_TYPES(
348+
room.scalar_type(), "ray_tracing_3d", [&] {
349+
RayTracer<scalar_t, 3> rt(
350+
room, mic_array, absorption, scattering, mic_radius);
351+
return rt.compute_histograms(
352+
source,
353+
num_rays,
354+
time_thres,
355+
energy_thres,
356+
sound_speed,
357+
num_bins);
358+
});
359+
}
360+
default:
361+
TORCH_CHECK(false, "Only 2D and 3D are supported.");
362+
}
363+
}
364+
365+
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
366+
m.impl("torchaudio::ray_tracing", torchaudio::rir::ray_tracing);
367+
}
368+
369+
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
370+
m.def(
371+
"torchaudio::ray_tracing(Tensor room, Tensor source, Tensor mic_array, int num_rays, Tensor absorption, Tensor scattering, float mic_radius, float sound_speed, float energy_thres, float time_thres, float hist_bin_size) -> Tensor");
372+
}
373+
374+
} // namespace
375+
} // namespace rir
376+
} // namespace torchaudio

0 commit comments

Comments
 (0)