Skip to content

Commit 50de1e9

Browse files
committed
add simulate_rir_ism method
1 parent 76fca37 commit 50de1e9

File tree

7 files changed

+292
-1
lines changed

7 files changed

+292
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ endif()
5858
# Options
5959
option(BUILD_SOX "Build libsox statically" ON)
6060
option(BUILD_KALDI "Build kaldi statically" ON)
61+
option(BUILD_RIR "Enable RIR simulation" ON)
6162
option(BUILD_RNNT "Enable RNN transducer" ON)
6263
option(BUILD_CTC_DECODER "Build Flashlight CTC decoder" ON)
6364
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)

test/torchaudio_unittest/prototype/functional/autograd_test_impl.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,12 @@ def test_add_noise(self):
3131

3232
self.assertTrue(gradcheck(F.add_noise, (waveform, noise, lengths, snr)))
3333
self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, lengths, snr)))
34+
35+
def test_simulate_rir_ism(self):
    # First- and second-order gradient checks for F.simulate_rir_ism.
    # Room dimensions, given as a 3-element tensor.
    room = torch.tensor([9.0, 7.0, 3.0], dtype=self.dtype, device=self.device, requires_grad=True)
    # One microphone coordinate repeated six times -> shape (6, 3).
    mic_array = torch.tensor([0.1, 3.5, 1.5], dtype=self.dtype, device=self.device, requires_grad=True).reshape(1, -1).repeat(6,1)
    source = torch.tensor([8.8,3.5,1.5],dtype=self.dtype, device=self.device, requires_grad=True)
    max_order= 3
    # Random absorption coefficients of shape (7, 6); the values are not seeded here.
    e_absorption= torch.rand(7, 6, dtype=self.dtype, device=self.device, requires_grad=True)
    # NOTE(review): eps/atol are much looser than the gradcheck defaults — confirm this is intended.
    self.assertTrue(gradcheck(F.simulate_rir_ism, (room, source, mic_array, max_order, e_absorption), eps=1e-2, atol=1e-2))
    self.assertTrue(gradgradcheck(F.simulate_rir_ism, (room, source, mic_array, max_order, e_absorption), eps=1e-2, atol=1e-2))

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pyroomacoustics as pra
23
import torch
34
import torchaudio.prototype.functional as F
45
from parameterized import parameterized
@@ -107,3 +108,37 @@ def test_add_noise_length_check(self):
107108

108109
with self.assertRaisesRegex(ValueError, "Length dimensions"):
109110
F.add_noise(waveform, noise, lengths, snr)
111+
112+
def test_simulate_rir_ism(self):
    """Compare ``F.simulate_rir_ism`` with the pyroomacoustics image source model."""
    room_dim = torch.tensor([9.0, 9.0, 9.0], dtype=self.dtype, device=self.device, requires_grad=True)
    mic_array = (
        torch.tensor([1, 1, 1], dtype=self.dtype, device=self.device, requires_grad=True)
        .reshape(1, -1)
        .repeat(6, 1)
    )
    source = torch.tensor([7, 7, 7], dtype=self.dtype, device=self.device, requires_grad=True)
    max_order = 3
    e_absorption = torch.rand(7, 6, dtype=self.dtype, device=self.device, requires_grad=True)
    walls = ["west", "east", "south", "north", "floor", "ceiling"]
    # Build the reference room from the very same tensors that are passed to
    # simulate_rir_ism, so the two setups cannot drift apart. `.cpu()` keeps the
    # conversion valid when the test runs on a non-CPU device.
    room2 = pra.ShoeBox(
        room_dim.detach().cpu().numpy(),
        fs=16000,
        materials={
            wall: pra.Material(
                {
                    "coeffs": e_absorption[:, i].reshape(-1).detach().cpu().numpy(),
                    "center_freqs": [125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0],
                }
            )
            for i, wall in enumerate(walls)
        },
        max_order=max_order,
        ray_tracing=False,
        air_absorption=False,
    )
    # pyroomacoustics takes microphone locations as (D, num_mics).
    mic_locs = mic_array.detach().cpu().double().numpy().swapaxes(0, 1)
    room2.add_microphone_array(mic_locs)
    room2.add_source(source.detach().cpu().double().numpy())
    room2.compute_rir()
    # room2.rir[i][j] is the response between microphone i and source j; take
    # each microphone's own response instead of reusing microphone 0.
    num_mics = mic_array.shape[0]
    expected = torch.stack([torch.tensor(room2.rir[i][0]) for i in range(num_mics)]).to(
        dtype=self.dtype, device=self.device
    )
    actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, e_absorption)
    self.assertEqual(actual, expected)
143+
144+

torchaudio/csrc/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ if(BUILD_RNNT)
5353
endif()
5454
endif()
5555

56+
# Add the RIR source file to LIBTORCHAUDIO_SOURCES only when BUILD_RIR is enabled.
if(BUILD_RIR)
  list(
    APPEND
    LIBTORCHAUDIO_SOURCES
    build_rir.cpp
    )
endif()
63+
5664
if(USE_CUDA)
5765
list(
5866
APPEND

torchaudio/csrc/build_rir.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include <math.h>
2+
#include <torch/script.h>
3+
#include <torch/torch.h>
4+
using namespace torch::indexing;
5+
6+
namespace torchaudio {
7+
namespace rir {
8+
9+
template <typename scalar_t>
10+
void build_rir_impl(
11+
const torch::Tensor& irs,
12+
const torch::Tensor& delay,
13+
torch::Tensor& rirs,
14+
const int64_t rir_length,
15+
const int64_t num_band,
16+
const int64_t num_image,
17+
const int64_t num_mic,
18+
const int64_t ir_length) {
19+
const scalar_t* input_data = irs.data_ptr<scalar_t>();
20+
const int* delay_data = delay.data_ptr<int>();
21+
scalar_t* output_data = rirs.data_ptr<scalar_t>();
22+
at::parallel_for(
23+
0, num_band * num_image * num_mic, 0, [&](int64_t start, int64_t end) {
24+
for (auto i = start; i < end; i++) {
25+
int64_t offset_input = i * ir_length;
26+
int64_t mic = i % num_mic;
27+
int64_t image = ((i - mic) / num_mic) % num_image;
28+
int64_t band = (i - mic - image * num_mic) / (num_image * num_mic);
29+
int64_t offset_output = (band * num_mic + mic) * rir_length;
30+
int64_t offset_delay = image * num_mic + mic;
31+
for (auto j = 0; j < ir_length; j++) {
32+
output_data[offset_output + j + delay_data[offset_delay]] +=
33+
input_data[offset_input + j];
34+
}
35+
}
36+
});
37+
}
38+
39+
// Build per-band, per-microphone room impulse responses from per-image IRs.
//
// irs        : 4-D floating point CPU tensor
//              (num_band, num_image, num_mic, ir_length).
// delay      : int32 CPU tensor with num_image * num_mic elements; the output
//              sample index at which each image/microphone IR is added.
// rir_length : number of samples of every output RIR.
//
// Returns a (num_band, num_mic, rir_length) tensor with the dtype of `irs`.
// Raises (via TORCH_CHECK) when the inputs do not match the layout above or a
// delay would place an IR outside [0, rir_length).
torch::Tensor build_rir(
    const torch::Tensor irs,
    const torch::Tensor delay,
    const int64_t rir_length) {
  TORCH_CHECK(
      irs.dim() == 4,
      "build_rir: irs must be 4-D (band, image, mic, time), got ",
      irs.dim(),
      "-D");
  TORCH_CHECK(
      irs.device().is_cpu() && delay.device().is_cpu(),
      "build_rir: irs and delay must be CPU tensors");
  TORCH_CHECK(
      delay.scalar_type() == torch::kInt, "build_rir: delay must be int32");
  TORCH_CHECK(rir_length >= 0, "build_rir: rir_length must be non-negative");

  const int64_t num_band = irs.size(0);
  const int64_t num_image = irs.size(1);
  const int64_t num_mic = irs.size(2);
  const int64_t ir_length = irs.size(3);
  TORCH_CHECK(
      delay.numel() == num_image * num_mic,
      "build_rir: delay must hold one entry per (image, mic) pair");

  // min()/max() are undefined on empty tensors, so only check when populated.
  if (delay.numel() > 0) {
    TORCH_CHECK(
        delay.min().item<int64_t>() >= 0,
        "build_rir: delays must be non-negative");
    TORCH_CHECK(
        delay.max().item<int64_t>() + ir_length <= rir_length,
        "build_rir: rir_length is too small for the largest delay");
  }

  // The kernel walks raw pointers and therefore needs dense memory.
  const torch::Tensor irs_dense = irs.contiguous();
  const torch::Tensor delay_dense = delay.contiguous();

  torch::Tensor rirs =
      torch::zeros({num_band, num_mic, rir_length}, irs.dtype());
  // NOTE(review): kept for compatibility. The output is filled through raw
  // pointers and no backward is registered in this file, so this flag alone
  // does not connect `rirs` to `irs` in the autograd graph — confirm intent.
  rirs.requires_grad_(true);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(irs.scalar_type(), "build_rir", [&] {
    build_rir_impl<scalar_t>(
        irs_dense,
        delay_dense,
        rirs,
        rir_length,
        num_band,
        num_image,
        num_mic,
        ir_length);
  });
  return rirs;
}
56+
57+
// Build one time-domain filter per band from raised-cosine frequency responses.
//
// centers     : 1-D tensor of at least two band center frequencies. It is read
//               as float32 on the CPU (converted if necessary).
// sample_rate : sampling rate used to map FFT bins to frequencies.
// n_fft       : FFT size; the frequency response has n_fft / 2 + 1 bins.
//
// Returns a float32 tensor of shape (num_band, n_fft - 1) on the device of
// `centers`.
torch::Tensor make_filter(
    torch::Tensor centers,
    double sample_rate,
    int64_t n_fft) {
  TORCH_CHECK(centers.dim() == 1, "make_filter: centers must be 1-D");
  const int64_t n = centers.size(0);
  // Each band looks at a neighbouring center, so a single band has no edges.
  TORCH_CHECK(n >= 2, "make_filter: at least two center frequencies required");
  TORCH_CHECK(n_fft > 0, "make_filter: n_fft must be positive");

  // Raw float pointers are used below: pin dtype/device/layout explicitly
  // instead of relying on the global default dtype and the caller's layout.
  const auto float_cpu =
      torch::TensorOptions().dtype(torch::kFloat).device(torch::kCPU);
  const torch::Tensor centers_dense =
      centers.to(torch::kCPU, torch::kFloat).contiguous();
  const float* centers_data = centers_dense.data_ptr<float>();

  const int64_t n_freq = n_fft / 2 + 1;
  torch::Tensor freq_resp = torch::zeros({n_freq, n}, float_cpu);
  const torch::Tensor freq =
      torch::arange(n_freq, float_cpu) / n_fft * sample_rate;
  const float* freq_data = freq.data_ptr<float>();
  float* resp_data = freq_resp.data_ptr<float>();

  // Each frequency bin owns one row of `freq_resp`, so rows can be filled
  // independently.
  at::parallel_for(0, n_freq, 0, [&](int64_t start, int64_t end) {
    for (int64_t j = start; j < end; ++j) {
      const float f = freq_data[j];
      float* row = resp_data + j * n;
      for (int64_t i = 0; i < n; ++i) {
        const float center = centers_data[i];
        // Lower edge: previous center, or half of the first center.
        const float lower =
            (i == 0) ? centers_data[0] / 2 : centers_data[i - 1];
        if (f >= lower && f < center) {
          row[i] = 0.5 * (1 + cos(2 * M_PI * f / center));
        } else if (f >= center) {
          if (i == n - 1) {
            // The last band passes everything at or above its center.
            row[i] = 1.0;
          } else {
            // Upper edge: next center.
            const float upper = centers_data[i + 1];
            if (f < upper) {
              row[i] = 0.5 * (1 - cos(2 * M_PI * f / upper));
            }
          }
        }
      }
    }
  });

  torch::Tensor filters =
      torch::fft::fftshift(torch::fft::irfft(freq_resp, n_fft, 0), 0);
  // Drop the first sample, then return band-major filters.
  return filters.index({torch::indexing::Slice(1)})
      .transpose(0, 1)
      .to(centers.device());
}
111+
112+
// Register the operators of this file in the `rir` operator library.
TORCH_LIBRARY(rir, m) {
  // build_rir is declared with an explicit schema string.
  m.def(
      "rir::build_rir(Tensor irs, Tensor delay_i, int rir_length) -> Tensor",
      &torchaudio::rir::build_rir);
  // make_filter is registered by name only; no schema string is given here.
  m.def("rir::make_filter", &torchaudio::rir::make_filter);
}
118+
119+
} // namespace rir
120+
} // namespace torchaudio
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .functional import add_noise, convolve, fftconvolve
2+
from .rir import simulate_rir_ism
23

3-
__all__ = ["add_noise", "convolve", "fftconvolve"]
4+
__all__ = ["add_noise", "convolve", "fftconvolve", "simulate_rir_ism"]
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import math
2+
from random import sample
3+
from typing import Union
4+
5+
import torch
6+
import torchaudio
7+
from torch import Tensor
8+
9+
_CENTER_FREQUENCY = torch.tensor([125, 250, 500, 1000, 2000, 4000, 8000], dtype=torch.float)
10+
11+
12+
def _compute_image_sources(room, source, max_order, e_abs, e_scatter=None):
13+
if e_scatter is None:
14+
e_scatter = torch.zeros_like(e_abs)
15+
# reflection coefficients
16+
tr = torch.sqrt(1 - e_abs) * torch.sqrt(1 - e_scatter)
17+
18+
ind = torch.arange(-max_order, max_order + 1, device=source.device)
19+
XYZ = torch.meshgrid(ind, ind, ind, indexing="ij")
20+
XYZ = torch.stack([c.reshape((-1,)) for c in XYZ], dim=-1)
21+
XYZ = XYZ[XYZ.abs().sum(dim=-1) <= max_order]
22+
23+
# location of image sources
24+
d = room[None, :]
25+
s = source[None, :]
26+
img_loc = torch.where(XYZ % 2 == 1, d * (XYZ + 1) - s, d * XYZ + s)
27+
28+
# attenuation
29+
exp_lo = abs(torch.floor(XYZ / 2))
30+
exp_hi = abs(torch.floor((XYZ + 1) / 2))
31+
t_lo = tr[:, ::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1) # num_band, left walls
32+
t_hi = tr[:, 1::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1) # num_band, right walls
33+
att = torch.prod((t_lo**exp_lo) * (t_hi**exp_hi), dim=-1) # num_band, num_image_source
34+
return img_loc, att
35+
36+
37+
def _hann(x, T):
38+
"""Compute he Hann window."""
39+
y = torch.where(
40+
torch.abs(x) <= T / 2,
41+
0.5 * (1 + torch.cos(2 * math.pi * x / T)),
42+
x.new_zeros(1),
43+
)
44+
return y
45+
46+
47+
def _frac_delay(tau, filter_len=41):
    """Build Hann-windowed sinc filters for the fractional delays ``tau``.

    A trailing dimension of size ``filter_len`` is appended to the shape of
    ``tau``.

    Raises:
        ValueError: If ``filter_len`` is not odd.
    """
    if filter_len % 2 != 1:
        raise ValueError("The filter length must be odd")

    half = filter_len // 2
    taps = torch.arange(-half, half + 1, device=tau.device)
    # Tap positions relative to each delay, shared by the sinc and the window.
    offset = taps - tau[..., None]
    return torch.special.sinc(offset) * _hann(offset, 2 * half)
56+
57+
58+
def simulate_rir_ism(
    room: Tensor,
    source: Tensor,
    mic_array: Tensor,
    max_order: int,
    e_absorption: Union[float, Tensor],
    sound_speed: float = 343.0,
    sample_rate: float = 16000.0,
) -> Tensor:
    """Compute Room Impulse Response (RIR) based on image source method.

    Args:
        room (torch.Tensor): The 1D Tensor to determine the room size. The shape is
            `(D,)`, where D is 2 if room is a 2D room, or 3 if room is a 3D room.
        source (torch.Tensor): The coordinate of the sound source. Tensor with dimensions
            `(D)`.
        mic_array (torch.Tensor): The coordinate of microphone array. Tensor with dimensions
            `(channel, D)`.
        max_order (int): The maximum order of reflections of image sources.
        e_absorption (float or torch.Tensor): The absorption coefficients of wall materials.
            If it is a Python number, the absorption coefficient is identical for all walls and
            all frequencies.
            If ``e_absorption`` is a 1D Tensor, the shape must be `(2 * D,)`, i.e. `(4,)` for a
            2D room or `(6,)` for a 3D room (one coefficient per wall, identical for all
            frequencies).
            If ``e_absorption`` is a 2D Tensor, the shape must be `(num_band, 2 * D)`, where
            ``num_band`` is either 1 or 7 (the number of frequency bands).
            Walls are ordered in pairs per axis: the wall at the low end of an axis first,
            followed by the wall at the high end.
        sound_speed (float): The speed of sound. (Default: ``343.0``)
        sample_rate (float): The sample rate of the generated room impulse response signal.
            (Default: ``16000.0``)

    Returns:
        (torch.Tensor): The simulated room impulse response waveform. Tensor with dimensions
        `(channel, rir_length)`.

    Raises:
        ValueError: If ``e_absorption`` does not have one of the shapes described above.
    """
    num_wall = 2 * room.shape[0]
    num_band = _CENTER_FREQUENCY.numel()
    if isinstance(e_absorption, (int, float)):
        # Scalar: one band, same coefficient for every wall, on the room's dtype/device.
        e_absorption = torch.full((1, num_wall), float(e_absorption), dtype=room.dtype, device=room.device)
    elif e_absorption.ndim == 1:
        # Per-wall coefficients shared by all frequencies: add the band dimension.
        e_absorption = e_absorption[None, :]
    if e_absorption.ndim != 2 or e_absorption.shape[1] != num_wall:
        raise ValueError(
            f"Expected e_absorption to have {num_wall} wall coefficients in the last dimension. "
            f"Found shape {tuple(e_absorption.shape)}."
        )
    if e_absorption.shape[0] not in (1, num_band):
        raise ValueError(
            f"Expected e_absorption to have 1 or {num_band} frequency bands. Found {e_absorption.shape[0]}."
        )

    img_location, att = _compute_image_sources(room, source, max_order, e_absorption)
    vec = img_location[:, None, :] - mic_array[None, :, :]

    dist = torch.linalg.norm(vec, dim=-1)  # (n_img, n_mics)

    img_src_att = att[..., None] / dist[None, ...]  # (n_band, n_img, n_mics)

    # separate delays in integer / frac part
    delay = dist / sound_speed * sample_rate  # distance to delay in samples
    delay_i = torch.round(delay)  # integer part
    delay_f = delay - delay_i  # frac part, in [-0.5, 0.5]

    # compute the short IRs corresponding to each image source
    irs = img_src_att[..., None] * _frac_delay(delay_f, filter_len=81)[None, ...]

    rir_length = int(delay_i.max() + irs.shape[-1])
    rir = torch.ops.rir.build_rir(irs, delay_i.type(torch.int32), rir_length)
    if rir.shape[0] > 1:
        # Band-limit every per-band RIR before the bands are summed.
        filters = torch.ops.rir.make_filter(_CENTER_FREQUENCY.to(room.device), sample_rate, 512)
        rir = torchaudio.prototype.functional.fftconvolve(rir, filters.unsqueeze(1).repeat(1, rir.shape[1], 1))
        # Trim the convolution tails so the length matches the unfiltered RIR.
        half = (filters.shape[-1] - 1) // 2
        rir = rir[..., half:-half]
    return rir.sum(0)

0 commit comments

Comments
 (0)