Skip to content

Commit be44256

Browse files
authored
Add format override to load and related I/O functions (#1104)
1 parent c4f0a11 commit be44256

File tree

14 files changed

+100
-62
lines changed

14 files changed

+100
-62
lines changed
15.1 KB
Binary file not shown.

test/torchaudio_unittest/sox_effect/sox_effect_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
TempDirMixin,
99
PytorchTestCase,
1010
skipIfNoExtension,
11+
get_asset_path,
1112
get_sinusoid,
1213
get_wav_data,
1314
save_wav,
@@ -243,3 +244,21 @@ def test_vorbis(self, sample_rate, num_channels):
243244

244245
assert sr == expected_sr
245246
self.assertEqual(found, expected)
247+
248+
249+
@skipIfNoExtension
250+
class TestApplyEffectFileWithoutExtension(PytorchTestCase):
251+
def test_mp3(self):
252+
"""Providing format allows to read mp3 without extension
253+
254+
libsox does not check header for mp3
255+
256+
https://github.com/pytorch/audio/issues/1040
257+
258+
The file was generated with the following command
259+
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
260+
"""
261+
effects = [['band', '300', '10']]
262+
path = get_asset_path("mp3_without_ext")
263+
_, sr = sox_effects.apply_effects_file(path, effects, format="mp3")
264+
assert sr == 16000

test/torchaudio_unittest/sox_io_backend/info_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,20 @@ def test_opus(self, bitrate, num_channels, compression_level):
167167
assert info.sample_rate == 48000
168168
assert info.num_frames == 32768
169169
assert info.num_channels == num_channels
170+
171+
172+
@skipIfNoExtension
173+
class TestLoadWithoutExtension(PytorchTestCase):
174+
def test_mp3(self):
175+
"""Providing `format` allows to read mp3 without extension
176+
177+
libsox does not check header for mp3
178+
179+
https://github.com/pytorch/audio/issues/1040
180+
181+
The file was generated with the following command
182+
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
183+
"""
184+
path = get_asset_path("mp3_without_ext")
185+
sinfo = sox_io_backend.info(path, format="mp3")
186+
assert sinfo.sample_rate == 16000

test/torchaudio_unittest/sox_io_backend/load_test.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import itertools
32

43
from torchaudio.backend import sox_io_backend
@@ -355,26 +354,18 @@ def test_channels_first(self, channels_first):
355354
self.assertEqual(found, expected)
356355

357356

358-
@skipIfNoExec('sox')
359357
@skipIfNoExtension
360-
class TestLoadExtensionLess(TempDirMixin, PytorchTestCase):
361-
"""Given `format` parameter, `sox_io_backend.load` can load files without extension"""
362-
original = None
363-
path = None
358+
class TestLoadWithoutExtension(PytorchTestCase):
359+
def test_mp3(self):
360+
"""Providing format allows to read mp3 without extension
364361
365-
def _make_file(self, format_):
366-
sample_rate = 8000
367-
path = self.get_temp_path(f'test.{format_}')
368-
sox_utils.gen_audio_file(f'{path}', sample_rate, num_channels=2)
369-
self.original = sox_io_backend.load(path)[0]
370-
self.path = os.path.splitext(path)[0]
371-
os.rename(path, self.path)
372-
373-
@parameterized.expand([
374-
('WAV', ), ('wav', ), ('MP3', ), ('mp3', ), ('FLAC', ), ('flac',),
375-
], name_func=name_func)
376-
def test_format(self, format_):
377-
"""Providing format allows to read file without extension"""
378-
self._make_file(format_)
379-
found, _ = sox_io_backend.load(self.path)
380-
self.assertEqual(found, self.original)
362+
libsox does not check header for mp3
363+
364+
https://github.com/pytorch/audio/issues/1040
365+
366+
The file was generated with the following command
367+
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
368+
"""
369+
path = get_asset_path("mp3_without_ext")
370+
_, sr = sox_io_backend.load(path, format="mp3")
371+
assert sr == 16000

third_party/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ ExternalProject_Add(libsox
8484
DOWNLOAD_DIR ${ARCHIVE_DIR}
8585
URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2
8686
URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c
87-
PATCH_COMMAND patch -p0 < ${CMAKE_CURRENT_SOURCE_DIR}/patch/libsox.patch
8887
# OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses.
8988
# See https://github.com/pytorch/audio/pull/1026
9089
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp

third_party/patch/libsox.patch

Lines changed: 0 additions & 24 deletions
This file was deleted.

torchaudio/backend/_soundfile_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def load(
4545
num_frames: int = -1,
4646
normalize: bool = True,
4747
channels_first: bool = True,
48+
format: Optional[str] = None,
4849
) -> Tuple[torch.Tensor, int]:
4950
"""Load audio data from file.
5051
@@ -99,6 +100,8 @@ def load(
99100
channels_first (bool):
100101
When True, the returned Tensor has dimension ``[channel, time]``.
101102
Otherwise, the returned Tensor's dimension is ``[time, channel]``.
103+
format (str, optional):
104+
Not used. PySoundFile does not accept a format hint.
102105
103106
Returns:
104107
torch.Tensor:

torchaudio/backend/sox_io_backend.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,27 @@
99

1010

1111
@_mod_utils.requires_module('torchaudio._torchaudio')
12-
def info(filepath: str) -> AudioMetaData:
12+
def info(
13+
filepath: str,
14+
format: Optional[str] = None,
15+
) -> AudioMetaData:
1316
"""Get signal information of an audio file.
1417
1518
Args:
1619
filepath (str or pathlib.Path):
1720
Path to audio file. This function also handles ``pathlib.Path`` objects,
1821
but is annotated as ``str`` for TorchScript compatibility.
22+
format (str, optional):
23+
Override the format detection with the given format.
24+
Providing the argument might help when libsox cannot infer the format
25+
from header or extension.
1926
2027
Returns:
2128
AudioMetaData: Metadata of the given audio.
2229
"""
2330
# Cast to str in case type is `pathlib.Path`
2431
filepath = str(filepath)
25-
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath)
32+
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
2633
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
2734

2835

@@ -33,6 +40,7 @@ def load(
3340
num_frames: int = -1,
3441
normalize: bool = True,
3542
channels_first: bool = True,
43+
format: Optional[str] = None,
3644
) -> Tuple[torch.Tensor, int]:
3745
"""Load audio data from file.
3846
@@ -93,6 +101,10 @@ def load(
93101
channels_first (bool):
94102
When True, the returned Tensor has dimension ``[channel, time]``.
95103
Otherwise, the returned Tensor's dimension is ``[time, channel]``.
104+
format (str, optional):
105+
Override the format detection with the given format.
106+
Providing the argument might help when libsox cannot infer the format
107+
from header or extension.
96108
97109
Returns:
98110
torch.Tensor:
@@ -103,7 +115,7 @@ def load(
103115
# Cast to str in case type is `pathlib.Path`
104116
filepath = str(filepath)
105117
signal = torch.ops.torchaudio.sox_io_load_audio_file(
106-
filepath, frame_offset, num_frames, normalize, channels_first)
118+
filepath, frame_offset, num_frames, normalize, channels_first, format)
107119
return signal.get_tensor(), signal.get_sample_rate()
108120

109121

torchaudio/csrc/register.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,14 @@ TORCH_LIBRARY(torchaudio, m) {
4949

5050
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
5151
m.def(
52-
"torchaudio::sox_io_load_audio_file",
52+
"torchaudio::sox_io_load_audio_file("
53+
"str path,"
54+
"int? frame_offset=None,"
55+
"int? num_frames=None,"
56+
"bool? normalize=True,"
57+
"bool? channels_first=False,"
58+
"str? format=None"
59+
") -> __torch__.torch.classes.torchaudio.TensorSignal",
5360
&torchaudio::sox_io::load_audio_file);
5461
m.def(
5562
"torchaudio::sox_io_save_audio_file",

torchaudio/csrc/sox_effects.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,14 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
9292
const std::string path,
9393
std::vector<std::vector<std::string>> effects,
9494
c10::optional<bool>& normalize,
95-
c10::optional<bool>& channels_first) {
95+
c10::optional<bool>& channels_first,
96+
c10::optional<std::string>& format) {
9697
// Open input file
9798
SoxFormat sf(sox_open_read(
9899
path.c_str(),
99100
/*signal=*/nullptr,
100101
/*encoding=*/nullptr,
101-
/*filetype=*/nullptr));
102+
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
102103

103104
validate_input_file(sf);
104105

0 commit comments

Comments
 (0)