Skip to content

Commit 4406a6b

Browse files
authored
Add AMB/AMR-NB/AMR-WB support to "sox_io" backend (#1066)
1 parent 2a02d7f commit 4406a6b

File tree

8 files changed

+207
-13
lines changed

8 files changed

+207
-13
lines changed

build_tools/setup_helpers/extension.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def _get_extra_objects():
8989
'libvorbisfile.a',
9090
'libvorbis.a',
9191
'libogg.a',
92+
'libopencore-amrnb.a',
93+
'libopencore-amrwb.a',
9294
]
9395
for lib in libs:
9496
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))

test/torchaudio_unittest/sox_io_backend/info_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,36 @@ def test_sphere(self, sample_rate, num_channels):
122122
assert info.num_frames == sample_rate * duration
123123
assert info.num_channels == num_channels
124124

125+
@parameterized.expand(list(itertools.product(
126+
['float32', 'int32', 'int16', 'uint8'],
127+
[8000, 16000],
128+
[1, 2],
129+
)), name_func=name_func)
130+
def test_amb(self, dtype, sample_rate, num_channels):
131+
"""`sox_io_backend.info` can check amb file correctly"""
132+
duration = 1
133+
path = self.get_temp_path('data.amb')
134+
sox_utils.gen_audio_file(
135+
path, sample_rate, num_channels,
136+
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
137+
info = sox_io_backend.info(path)
138+
assert info.sample_rate == sample_rate
139+
assert info.num_frames == sample_rate * duration
140+
assert info.num_channels == num_channels
141+
142+
def test_amr_nb(self):
143+
"""`sox_io_backend.info` can check amr-nb file correctly"""
144+
duration = 1
145+
num_channels = 1
146+
sample_rate = 8000
147+
path = self.get_temp_path('data.amr-nb')
148+
sox_utils.gen_audio_file(
149+
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration)
150+
info = sox_io_backend.info(path)
151+
assert info.sample_rate == sample_rate
152+
assert info.num_frames == sample_rate * duration
153+
assert info.num_channels == num_channels
154+
125155

126156
@skipIfNoExtension
127157
class TestInfoOpus(PytorchTestCase):

test/torchaudio_unittest/sox_io_backend/load_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,53 @@ def assert_sphere(self, sample_rate, num_channels, duration):
142142
assert sr == sample_rate
143143
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
144144

145+
def assert_amb(self, dtype, sample_rate, num_channels, normalize, duration):
146+
"""`sox_io_backend.load` can load amb format.
147+
148+
This test takes the same strategy as mp3 to compare the result
149+
"""
150+
path = self.get_temp_path('1.original.amb')
151+
ref_path = self.get_temp_path('2.reference.wav')
152+
153+
# 1. Generate amb with sox
154+
sox_utils.gen_audio_file(
155+
path, sample_rate, num_channels,
156+
encoding=sox_utils.get_encoding(dtype),
157+
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
158+
# 2. Convert to wav with sox
159+
sox_utils.convert_audio_file(path, ref_path)
160+
# 3. Load amb with torchaudio
161+
data, sr = sox_io_backend.load(path, normalize=normalize)
162+
# 4. Load wav with scipy
163+
data_ref = load_wav(ref_path, normalize=normalize)[0]
164+
# 5. Compare
165+
assert sr == sample_rate
166+
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
167+
168+
def assert_amr_nb(self, duration):
169+
"""`sox_io_backend.load` can load amr-nb format.
170+
171+
This test takes the same strategy as mp3 to compare the result
172+
"""
173+
sample_rate = 8000
174+
num_channels = 1
175+
path = self.get_temp_path('1.original.amr-nb')
176+
ref_path = self.get_temp_path('2.reference.wav')
177+
178+
# 1. Generate amr-nb with sox
179+
sox_utils.gen_audio_file(
180+
path, sample_rate, num_channels,
181+
bit_depth=32, duration=duration)
182+
# 2. Convert to wav with sox
183+
sox_utils.convert_audio_file(path, ref_path)
184+
# 3. Load amr-nb with torchaudio
185+
data, sr = sox_io_backend.load(path)
186+
# 4. Load wav with scipy
187+
data_ref = load_wav(ref_path)[0]
188+
# 5. Compare
189+
assert sr == sample_rate
190+
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
191+
145192

146193
@skipIfNoExec('sox')
147194
@skipIfNoExtension
@@ -260,6 +307,20 @@ def test_sphere(self, sample_rate, num_channels):
260307
"""`sox_io_backend.load` can load sph format correctly."""
261308
self.assert_sphere(sample_rate, num_channels, duration=1)
262309

310+
@parameterized.expand(list(itertools.product(
311+
['float32', 'int32', 'int16'],
312+
[8000, 16000],
313+
[1, 2],
314+
[False, True],
315+
)), name_func=name_func)
316+
def test_amb(self, dtype, sample_rate, num_channels, normalize):
317+
"""`sox_io_backend.load` can load amb format correctly."""
318+
self.assert_amb(dtype, sample_rate, num_channels, normalize, duration=1)
319+
320+
def test_amr_nb(self):
321+
"""`sox_io_backend.load` can load amr_nb format correctly."""
322+
self.assert_amr_nb(duration=1)
323+
263324

264325
@skipIfNoExec('sox')
265326
@skipIfNoExtension

test/torchaudio_unittest/sox_io_backend/save_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,68 @@ def assert_sphere(self, sample_rate, num_channels, duration):
200200

201201
self.assertEqual(found, expected)
202202

203+
def assert_amb(self, dtype, sample_rate, num_channels, duration):
204+
"""`sox_io_backend.save` can save amb format.
205+
206+
This test takes the same strategy as mp3 to compare the result
207+
"""
208+
src_path = self.get_temp_path('1.reference.wav')
209+
amb_path = self.get_temp_path('2.1.torchaudio.amb')
210+
wav_path = self.get_temp_path('2.2.torchaudio.wav')
211+
amb_path_sox = self.get_temp_path('3.1.sox.amb')
212+
wav_path_sox = self.get_temp_path('3.2.sox.wav')
213+
214+
# 1. Generate original wav
215+
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
216+
save_wav(src_path, data, sample_rate)
217+
# 2.1. Convert the original wav to amb with torchaudio
218+
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate)
219+
# 2.2. Convert the amb to wav with Sox
220+
sox_utils.convert_audio_file(amb_path, wav_path)
221+
# 2.3. Load
222+
found = load_wav(wav_path)[0]
223+
224+
# 3.1. Convert the original wav to amb with SoX
225+
sox_utils.convert_audio_file(src_path, amb_path_sox)
226+
# 3.2. Convert the amb to wav with Sox
227+
sox_utils.convert_audio_file(amb_path_sox, wav_path_sox)
228+
# 3.3. Load
229+
expected = load_wav(wav_path_sox)[0]
230+
231+
self.assertEqual(found, expected)
232+
233+
def assert_amr_nb(self, duration):
234+
"""`sox_io_backend.save` can save amr_nb format.
235+
236+
This test takes the same strategy as mp3 to compare the result
237+
"""
238+
sample_rate = 8000
239+
num_channels = 1
240+
src_path = self.get_temp_path('1.reference.wav')
241+
amr_path = self.get_temp_path('2.1.torchaudio.amr-nb')
242+
wav_path = self.get_temp_path('2.2.torchaudio.wav')
243+
amr_path_sox = self.get_temp_path('3.1.sox.amr-nb')
244+
wav_path_sox = self.get_temp_path('3.2.sox.wav')
245+
246+
# 1. Generate original wav
247+
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
248+
save_wav(src_path, data, sample_rate)
249+
# 2.1. Convert the original wav to amr_nb with torchaudio
250+
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate)
251+
# 2.2. Convert the amr_nb to wav with Sox
252+
sox_utils.convert_audio_file(amr_path, wav_path)
253+
# 2.3. Load
254+
found = load_wav(wav_path)[0]
255+
256+
# 3.1. Convert the original wav to amr_nb with SoX
257+
sox_utils.convert_audio_file(src_path, amr_path_sox)
258+
# 3.2. Convert the amr_nb to wav with Sox
259+
sox_utils.convert_audio_file(amr_path_sox, wav_path_sox)
260+
# 3.3. Load
261+
expected = load_wav(wav_path_sox)[0]
262+
263+
self.assertEqual(found, expected)
264+
203265

204266
@skipIfNoExec('sox')
205267
@skipIfNoExtension
@@ -302,6 +364,19 @@ def test_sphere(self, sample_rate, num_channels):
302364
"""`sox_io_backend.save` can save sph format."""
303365
self.assert_sphere(sample_rate, num_channels, duration=1)
304366

367+
@parameterized.expand(list(itertools.product(
368+
['float32', 'int32', 'int16', 'uint8'],
369+
[8000, 16000],
370+
[1, 2],
371+
)), name_func=name_func)
372+
def test_amb(self, dtype, sample_rate, num_channels):
373+
"""`sox_io_backend.save` can save amb format."""
374+
self.assert_amb(dtype, sample_rate, num_channels, duration=1)
375+
376+
def test_amr_nb(self):
377+
"""`sox_io_backend.save` can save amr-nb format."""
378+
self.assert_amr_nb(duration=1)
379+
305380

306381
@skipIfNoExec('sox')
307382
@skipIfNoExtension

third_party/CMakeLists.txt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ ExternalProject_Add(libmad
1616
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/libmad/configure ${COMMON_ARGS}
1717
)
1818

19+
ExternalProject_Add(amr
20+
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
21+
DOWNLOAD_DIR ${ARCHIVE_DIR}
22+
URL https://sourceforge.net/projects/opencore-amr/files/opencore-amr/opencore-amr-0.1.5.tar.gz
23+
URL_HASH SHA256=2c006cb9d5f651bfb5e60156dbff6af3c9d35c7bbcc9015308c0aff1e14cd341
24+
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/amr/configure ${COMMON_ARGS}
25+
)
26+
1927
ExternalProject_Add(libmp3lame
2028
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
2129
DOWNLOAD_DIR ${ARCHIVE_DIR}
@@ -72,11 +80,11 @@ ExternalProject_Add(opusfile
7280

7381
ExternalProject_Add(libsox
7482
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
75-
DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad
83+
DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad amr
7684
DOWNLOAD_DIR ${ARCHIVE_DIR}
7785
URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2
7886
URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c
7987
# OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses.
8088
# See https://github.com/pytorch/audio/pull/1026
81-
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --disable-openmp
89+
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp
8290
)

torchaudio/backend/sox_io_backend.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,19 @@ def load(
4040
This function can handle all the codecs that underlying libsox can handle,
4141
however it is tested on the following formats;
4242
43-
* WAV
43+
* WAV, AMB
4444
4545
* 32-bit floating-point
4646
* 32-bit signed integer
4747
* 16-bit signed integer
48-
* 8-bit unsigned integer
48+
* 8-bit unsigned integer (WAV only)
4949
5050
* MP3
5151
* FLAC
5252
* OGG/VORBIS
5353
* OPUS
5454
* SPHERE
55+
* AMR-NB
5556
5657
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
5758
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
@@ -119,7 +120,7 @@ def save(
119120
Note:
120121
Supported formats are;
121122
122-
* WAV
123+
* WAV, AMB
123124
124125
* 32-bit floating-point
125126
* 32-bit signed integer
@@ -130,6 +131,7 @@ def save(
130131
* FLAC
131132
* OGG/VORBIS
132133
* SPHERE
134+
* AMR-NB
133135
134136
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
135137
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
@@ -160,7 +162,7 @@ def save(
160162
filepath = str(filepath)
161163
if compression is None:
162164
ext = str(filepath).split('.')[-1].lower()
163-
if ext in ['wav', 'sph']:
165+
if ext in ['wav', 'sph', 'amb', 'amr-nb']:
164166
compression = 0.
165167
elif ext == 'mp3':
166168
compression = -4.5

torchaudio/csrc/sox_io.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,17 @@ void save_audio_file(
8585
const std::string& file_name,
8686
const c10::intrusive_ptr<TensorSignal>& signal,
8787
const double compression) {
88-
const auto tensor = signal->getTensor();
88+
auto tensor = signal->tensor;
8989

9090
validate_input_tensor(tensor);
9191

9292
const auto filetype = get_filetype(file_name);
93+
if (filetype == "amr-nb") {
94+
const auto num_channels = tensor.size(signal->channels_first ? 0 : 1);
95+
TORCH_CHECK(
96+
num_channels == 1, "amr-nb format only supports single channel audio.");
97+
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
98+
}
9399
const auto signal_info = get_signalinfo(signal.get(), filetype);
94100
const auto encoding_info =
95101
get_encodinginfo(filetype, tensor.dtype(), compression);

torchaudio/csrc/sox_utils.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ sox_encoding_t get_encoding(
223223
return SOX_ENCODING_FLAC;
224224
if (filetype == "ogg" || filetype == "vorbis")
225225
return SOX_ENCODING_VORBIS;
226-
if (filetype == "wav") {
226+
if (filetype == "wav" || filetype == "amb") {
227227
if (dtype == torch::kUInt8)
228228
return SOX_ENCODING_UNSIGNED;
229229
if (dtype == torch::kInt16)
@@ -236,7 +236,9 @@ sox_encoding_t get_encoding(
236236
}
237237
if (filetype == "sph")
238238
return SOX_ENCODING_SIGN2;
239-
throw std::runtime_error("Unsupported file type.");
239+
if (filetype == "amr-nb")
240+
return SOX_ENCODING_AMR_NB;
241+
throw std::runtime_error("Unsupported file type: " + filetype);
240242
}
241243

242244
unsigned get_precision(
@@ -248,7 +250,7 @@ unsigned get_precision(
248250
return 24;
249251
if (filetype == "ogg" || filetype == "vorbis")
250252
return SOX_UNSPEC;
251-
if (filetype == "wav") {
253+
if (filetype == "wav" || filetype == "amb") {
252254
if (dtype == torch::kUInt8)
253255
return 8;
254256
if (dtype == torch::kInt16)
@@ -261,7 +263,13 @@ unsigned get_precision(
261263
}
262264
if (filetype == "sph")
263265
return 32;
264-
throw std::runtime_error("Unsupported file type.");
266+
if (filetype == "amr-nb") {
267+
TORCH_INTERNAL_ASSERT(
268+
dtype == torch::kInt16,
269+
"When saving to AMR-NB format, the input tensor must be int16 type.");
270+
return 16;
271+
}
272+
throw std::runtime_error("Unsupported file type: " + filetype);
265273
}
266274

267275
sox_signalinfo_t get_signalinfo(
@@ -287,11 +295,13 @@ sox_encodinginfo_t get_encodinginfo(
287295
return compression;
288296
if (filetype == "ogg" || filetype == "vorbis")
289297
return compression;
290-
if (filetype == "wav")
298+
if (filetype == "wav" || filetype == "amb")
291299
return 0.;
292300
if (filetype == "sph")
293301
return 0.;
294-
throw std::runtime_error("Unsupported file type.");
302+
if (filetype == "amr-nb")
303+
return 0.;
304+
throw std::runtime_error("Unsupported file type: " + filetype);
295305
}();
296306

297307
return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),

0 commit comments

Comments
 (0)