Skip to content

Commit 0ec7746

Browse files
committed
refactor code
1 parent 50de1e9 commit 0ec7746

File tree

8 files changed

+302
-92
lines changed

8 files changed

+302
-92
lines changed

.circleci/unittest/linux/scripts/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ fi
8080
(
8181
set -x
8282
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
83-
pip install kaldi-io SoundFile coverage pytest pytest-cov 'scipy==1.7.3' transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag
83+
pip install kaldi-io SoundFile coverage pytest pytest-cov 'scipy==1.7.3' transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag pyroomacoustics
8484
)
8585
# Install fairseq
8686
git clone https://github.com/pytorch/fairseq

.circleci/unittest/windows/scripts/install.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ esac
9090
unidecode \
9191
'protobuf<4.21.0' \
9292
demucs \
93-
tinytag
93+
tinytag \
94+
pyroomacoustics
9495
)
9596
# Install fairseq
9697
git clone https://github.com/pytorch/fairseq

docs/source/prototype.functional.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,8 @@ fftconvolve
1818
~~~~~~~~~~~
1919

2020
.. autofunction:: fftconvolve
21+
22+
simulate_rir_ism
23+
~~~~~~~~~~~~~~~~
24+
25+
.. autofunction:: simulate_rir_ism

test/torchaudio_unittest/prototype/functional/autograd_test_impl.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,21 @@ def test_add_noise(self):
3232
self.assertTrue(gradcheck(F.add_noise, (waveform, noise, lengths, snr)))
3333
self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, lengths, snr)))
3434

35-
def test_simulate_rir_ism(self):
36-
room = torch.tensor([9.0, 7.0, 3.0], dtype=self.dtype, device=self.device, requires_grad=True)
37-
mic_array = torch.tensor([0.1, 3.5, 1.5], dtype=self.dtype, device=self.device, requires_grad=True).reshape(1, -1).repeat(6,1)
38-
source = torch.tensor([8.8,3.5,1.5],dtype=self.dtype, device=self.device, requires_grad=True)
39-
max_order= 3
40-
e_absorption= torch.rand(7, 6, dtype=self.dtype, device=self.device, requires_grad=True)
41-
self.assertTrue(gradcheck(F.simulate_rir_ism, (room, source, mic_array, max_order, e_absorption), eps=1e-2, atol=1e-2))
42-
self.assertTrue(gradgradcheck(F.simulate_rir_ism, (room, source, mic_array, max_order, e_absorption), eps=1e-2, atol=1e-2))
35+
@parameterized.expand([(2, 1), (3, 4)])
36+
def test_simulate_rir_ism(self, D, channel):
37+
room = torch.rand(D, dtype=self.dtype, device=self.device, requires_grad=True)
38+
mic_array = torch.rand(channel, D, dtype=self.dtype, device=self.device, requires_grad=True)
39+
source = torch.rand(D, dtype=self.dtype, device=self.device, requires_grad=True)
40+
max_order = 2
41+
e_absorption = 0.5
42+
output_length = 1000
43+
self.assertTrue(
44+
gradcheck(
45+
F.simulate_rir_ism, (room, source, mic_array, max_order, e_absorption, output_length), atol=1e-3, rtol=1
46+
)
47+
)
48+
self.assertTrue(
49+
gradgradcheck(
50+
F.simulate_rir_ism, (room, source, mic_array, max_order, e_absorption, output_length), atol=1e-3, rtol=1
51+
)
52+
)

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torchaudio.prototype.functional as F
55
from parameterized import parameterized
66
from scipy import signal
7-
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
7+
from torchaudio_unittest.common_utils import nested_params, skipIfNoModule, TestBaseMixin
88

99

1010
class FunctionalTestImpl(TestBaseMixin):
@@ -109,36 +109,76 @@ def test_add_noise_length_check(self):
109109
with self.assertRaisesRegex(ValueError, "Length dimensions"):
110110
F.add_noise(waveform, noise, lengths, snr)
111111

112-
def test_simulate_rir_ism(self):
113-
room_dim = torch.tensor([9.0, 9.0, 9.0], dtype=self.dtype, device=self.device, requires_grad=True)
114-
mic_array = torch.tensor([1, 1, 1], dtype=self.dtype, device=self.device, requires_grad=True).reshape(1, -1).repeat(6,1)
115-
source = torch.tensor([7,7,7],dtype=self.dtype, device=self.device, requires_grad=True)
116-
max_order= 3
117-
e_absorption= torch.rand(7, 6, dtype=self.dtype, device=self.device, requires_grad=True)
118-
walls = ["west", "east", "south", "north", "floor", "ceiling"]
119-
room2= pra.ShoeBox(
112+
@skipIfNoModule("pyroomacoustics")
113+
@parameterized.expand([(2, 1), (3, 4)])
114+
def test_simulate_rir_ism_single_band(self, D, channel):
115+
"""Test simulate_rir_ism when absorption coefficients are identical for all walls."""
116+
room_dim = torch.rand(D, dtype=self.dtype, device=self.device) + 10
117+
mic_array = torch.rand(channel, D, dtype=self.dtype, device=self.device)
118+
source = torch.rand(D, dtype=self.dtype, device=self.device)
119+
max_order = 3
120+
e_absorption = 0.5
121+
room = pra.ShoeBox(
122+
room_dim.detach().numpy(),
123+
fs=16000,
124+
materials=pra.Material(e_absorption),
125+
max_order=max_order,
126+
ray_tracing=False,
127+
air_absorption=False,
128+
)
129+
mic_locs = np.asarray([mic_array[i].tolist() for i in range(channel)]).swapaxes(0, 1)
130+
room.add_microphone_array(mic_locs)
131+
room.add_source(source.tolist())
132+
room.compute_rir()
133+
max_len = max([room.rir[i][0].shape[0] for i in range(channel)])
134+
actual = torch.zeros(channel, max_len, dtype=self.dtype, device=self.device)
135+
for i in range(channel):
136+
actual[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
137+
expected = F.simulate_rir_ism(room_dim, source, mic_array, max_order, e_absorption)
138+
self.assertEqual(expected, actual, atol=4e-4, rtol=2)
139+
140+
@skipIfNoModule("pyroomacoustics")
141+
@parameterized.expand([(2, 1), (3, 4)])
142+
def test_simulate_rir_ism_multi_band(self, D, channel):
143+
"""Test simulate_rir_ism when absorption coefficients are different for all walls."""
144+
room_dim = torch.rand(D, dtype=self.dtype, device=self.device) + 10
145+
mic_array = torch.rand(channel, D, dtype=self.dtype, device=self.device)
146+
source = torch.rand(D, dtype=self.dtype, device=self.device)
147+
max_order = 3
148+
if D == 2:
149+
e_absorption = torch.rand(7, 4, dtype=self.dtype, device=self.device)
150+
walls = ["west", "east", "south", "north"]
151+
else:
152+
e_absorption = torch.rand(7, 6, dtype=self.dtype, device=self.device)
153+
walls = ["west", "east", "south", "north", "floor", "ceiling"]
154+
room = pra.ShoeBox(
120155
room_dim.detach().numpy(),
121156
fs=16000,
122157
materials={
123-
walls[i] : pra.Material(
158+
walls[i]: pra.Material(
124159
{
125-
"coeffs": e_absorption[:, i].reshape(-1,).detach().numpy(),
160+
"coeffs": e_absorption[:, i]
161+
.reshape(
162+
-1,
163+
)
164+
.detach()
165+
.numpy(),
126166
"center_freqs": [125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0],
127167
}
128-
) for i in range(len(walls))
168+
)
169+
for i in range(len(walls))
129170
},
130171
max_order=max_order,
131172
ray_tracing=False,
132173
air_absorption=False,
133174
)
134-
mic_locs = np.asarray(
135-
[[1.0,1.0,1.0]for _ in range(6)] # mic 1
136-
).swapaxes(0,1)
137-
room2.add_microphone_array(mic_locs)
138-
room2.add_source([7.0,7.0,7.0])
139-
room2.compute_rir()
140-
actual = torch.concat([torch.tensor(room2.rir[0]) for i in range(6)]).to(self.dtype)
175+
mic_locs = np.asarray([mic_array[i].tolist() for i in range(channel)]).swapaxes(0, 1)
176+
room.add_microphone_array(mic_locs)
177+
room.add_source(source.tolist())
178+
room.compute_rir()
179+
max_len = max([room.rir[i][0].shape[0] for i in range(channel)])
180+
actual = torch.zeros(channel, max_len, dtype=self.dtype, device=self.device)
181+
for i in range(channel):
182+
actual[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
141183
expected = F.simulate_rir_ism(room_dim, source, mic_array, max_order, e_absorption)
142-
self.assertEqual(expected, actual)
143-
144-
184+
self.assertEqual(expected, actual, atol=4e-4, rtol=2)

test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,27 @@ def test_add_noise(self):
4848
snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10
4949

5050
self._assert_consistency(F.add_noise, (waveform, noise, lengths, snr))
51+
52+
def test_simulate_rir_ism_single_band(self):
53+
room_dim = torch.tensor([9.0, 9.0, 9.0], dtype=self.dtype, device=self.device)
54+
mic_array = torch.tensor([1, 1, 1], dtype=self.dtype, device=self.device).reshape(1, -1).repeat(6, 1)
55+
source = torch.tensor([7, 7, 7], dtype=self.dtype, device=self.device)
56+
max_order = 3
57+
e_absorption = 0.5
58+
center_frequency = torch.tensor([125, 250, 500, 1000, 2000, 4000, 8000], dtype=self.dtype, device=self.device)
59+
self._assert_consistency(
60+
F.simulate_rir_ism,
61+
(room_dim, source, mic_array, max_order, e_absorption, 1000, 81, center_frequency, 343.0, 16000.0),
62+
)
63+
64+
def test_simulate_rir_ism_multi_band(self):
65+
room_dim = torch.tensor([9.0, 9.0, 9.0], dtype=self.dtype, device=self.device)
66+
mic_array = torch.tensor([1, 1, 1], dtype=self.dtype, device=self.device).reshape(1, -1).repeat(6, 1)
67+
source = torch.tensor([7, 7, 7], dtype=self.dtype, device=self.device)
68+
max_order = 3
69+
e_absorption = torch.rand(7, 6, dtype=self.dtype, device=self.device)
70+
center_frequency = torch.tensor([125, 250, 500, 1000, 2000, 4000, 8000], dtype=self.dtype, device=self.device)
71+
self._assert_consistency(
72+
F.simulate_rir_ism,
73+
(room_dim, source, mic_array, max_order, e_absorption, 1000, 81, center_frequency, 343.0, 16000.0),
74+
)

torchaudio/csrc/build_rir.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,17 @@ torch::Tensor build_rir(
5454
return rirs;
5555
}
5656

57-
torch::Tensor make_filter(
58-
torch::Tensor centers,
57+
template <typename scalar_t>
58+
void make_filter_impl(
59+
torch::Tensor& centers,
5960
double sample_rate,
60-
int64_t n_fft) {
61+
int64_t n_fft,
62+
torch::Tensor& filters) {
6163
int64_t n = centers.size(0);
62-
torch::Tensor new_bands = torch::zeros({n, 2});
64+
torch::Tensor new_bands = torch::zeros({n, 2}, centers.dtype());
6365
new_bands.requires_grad_(true);
64-
float* newband_data = new_bands.data_ptr<float>();
65-
const float* centers_data = centers.data_ptr<float>();
66+
scalar_t* newband_data = new_bands.data_ptr<scalar_t>();
67+
const scalar_t* centers_data = centers.data_ptr<scalar_t>();
6668
at::parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
6769
for (int64_t i = start; i < end; i++) {
6870
if (i == 0) {
@@ -78,10 +80,11 @@ torch::Tensor make_filter(
7880
}
7981
});
8082
auto n_freq = n_fft / 2 + 1;
81-
torch::Tensor freq_resp = torch::zeros({n_freq, n});
82-
torch::Tensor freq = torch::arange(n_freq) / n_fft * sample_rate;
83-
const float* freq_data = freq.data_ptr<float>();
84-
float* freqreq_data = freq_resp.data_ptr<float>();
83+
torch::Tensor freq_resp = torch::zeros({n_freq, n}, centers.dtype());
84+
torch::Tensor freq =
85+
torch::arange(n_freq, centers.dtype()) / n_fft * sample_rate;
86+
const scalar_t* freq_data = freq.data_ptr<scalar_t>();
87+
scalar_t* freqreq_data = freq_resp.data_ptr<scalar_t>();
8588

8689
at::parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
8790
at::parallel_for(0, n_freq, 0, [&](int64_t start2, int64_t end2) {
@@ -104,9 +107,20 @@ torch::Tensor make_filter(
104107
}
105108
});
106109
});
107-
torch::Tensor filters =
108-
torch::fft::fftshift(torch::fft::irfft(freq_resp, n_fft, 0), 0);
109-
return filters.index({Slice(1)}).transpose(0, 1);
110+
filters = torch::fft::fftshift(torch::fft::irfft(freq_resp, n_fft, 0), 0);
111+
filters = filters.index({Slice(1)}).transpose(0, 1);
112+
}
113+
114+
torch::Tensor make_filter(
115+
torch::Tensor centers,
116+
double sample_rate,
117+
int64_t n_fft) {
118+
torch::Tensor filters;
119+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
120+
centers.scalar_type(), "make_filter", [&] {
121+
make_filter_impl<scalar_t>(centers, sample_rate, n_fft, filters);
122+
});
123+
return filters;
110124
}
111125

112126
TORCH_LIBRARY(rir, m) {

0 commit comments

Comments
 (0)