Skip to content

Commit 44ea317

Browse files
committed
fixup! Update stuff
1 parent 4b419c9 commit 44ea317

File tree

5 files changed

+51
-52
lines changed

5 files changed

+51
-52
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,14 @@
11
def name_func(func, _, params):
    """Build a readable test-case name for a parameterized test.

    The name is the function's ``__name__`` followed by every positional
    argument in ``params.args``, each converted with ``str`` and joined by
    underscores.

    Args:
        func: The test function being parameterized; only ``__name__`` is read.
        _: Ignored.
        params: Object exposing the positional arguments as ``params.args``.

    Returns:
        str: ``'<func name>_<arg0>_<arg1>_...'``.
    """
    return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
3+
4+
5+
def get_enc_params(dtype):
    """Return the ``(encoding, bits_per_sample)`` pair for a dtype name.

    Args:
        dtype (str): One of ``'float32'``, ``'int32'``, ``'int16'`` or
            ``'uint8'``.

    Returns:
        tuple: ``(encoding, bits_per_sample)``, for example ``('PCM_S', 16)``.

    Raises:
        ValueError: If ``dtype`` is not one of the names listed above.
    """
    enc_params = {
        'float32': ('PCM_F', 32),
        'int32': ('PCM_S', 32),
        'int16': ('PCM_S', 16),
        'uint8': ('PCM_U', 8),
    }
    try:
        return enc_params[dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the equality-based
        # original also rejected with ValueError.
        raise ValueError(f'Unexpected dtype: {dtype}') from None

test/torchaudio_unittest/backend/sox_io/roundtrip_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from .common import (
1414
name_func,
15+
get_enc_params,
1516
)
1617

1718

@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
2728
def test_wav(self, dtype, sample_rate, num_channels):
2829
"""save/load round trip should not degrade data for wav formats"""
2930
original = get_wav_data(dtype, num_channels, normalize=False)
31+
enc, bps = get_enc_params(dtype)
3032
data = original
3133
for i in range(10):
3234
path = self.get_temp_path(f'{i}.wav')
33-
sox_io_backend.save(path, data, sample_rate)
35+
sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
3436
data, sr = sox_io_backend.load(path, normalize=False)
3537
assert sr == sample_rate
3638
self.assertEqual(original, data)

test/torchaudio_unittest/backend/sox_io/save_test.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from .common import (
1818
name_func,
19+
get_enc_params,
1920
)
2021

2122

@@ -41,7 +42,6 @@ def assert_save_consistency(
4142
sample_rate: float = 8000,
4243
num_channels: int = 2,
4344
num_frames: float = 3 * 8000,
44-
src_dtype: str = 'int32',
4545
rtol: float = 1.3e-06,
4646
atol: float = 1e-05,
4747
):
@@ -52,7 +52,6 @@ def assert_save_consistency(
5252
compression (float, optional): `compression` value for `save` function
5353
encoding (str, optional): `encoding` value for `save` function
5454
bits_per_sample (int, optional): `bits_per_sample` value for `save` function.
55-
src_dtype: (str, optional): Dtype for generating the source WAV file.
5655
5756
To compare the file produced by the `save` function against the file produced by
5857
the equivalent `sox` command, we need to load both files.
@@ -99,7 +98,7 @@ def assert_save_consistency(
9998
ref_path = self.get_temp_path('3.2.ref.wav')
10099

101100
# 1. Generate original wav
102-
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
101+
data = get_wav_data('int32', num_channels, normalize=False, num_frames=num_frames)
103102
save_wav(src_path, data, sample_rate)
104103

105104
# 2.1. Convert the original wav to target format with torchaudio
@@ -128,17 +127,15 @@ def assert_save_consistency(
128127

129128
class SaveTestEncode(SaveTestBase):
130129
@parameterized.expand([
131-
('PCM_U', 8, 'uint8'),
132-
('PCM_S', 16, 'int16'),
133-
('PCM_S', 32, 'int32'),
134-
('PCM_F', 32, 'float32'),
135-
('ULAW', 8, 'float32'),
136-
('ALAW', 8, 'float32'),
130+
('PCM_U', 8),
131+
('PCM_S', 16),
132+
('PCM_S', 32),
133+
('PCM_F', 32),
134+
('ULAW', 8),
135+
('ALAW', 8),
137136
], name_func=name_func)
138-
def test_wav(self, encoding, bits_per_sample, src_dtype):
139-
self.assert_save_consistency(
140-
"wav", encoding=encoding, bits_per_sample=bits_per_sample, src_dtype=src_dtype,
141-
)
137+
def test_wav(self, encoding, bits_per_sample):
138+
self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample)
142139

143140
@parameterized.expand([
144141
(None, ),
@@ -249,7 +246,9 @@ def test_large(self, format):
249246
], name_func=name_func)
250247
def test_multi_channels(self, num_channels):
    """`sox_io_backend.save` can save audio with many channels"""
    # 16-bit samples are paired with signed PCM; unsigned PCM is only
    # paired with 8-bit samples in these tests.
    self.assert_save_consistency(
        "wav", encoding="PCM_S", bits_per_sample=16,
        num_channels=num_channels)
253252

254253

255254
@skipIfNoExec('sox')
@@ -260,10 +259,11 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
260259
def test_channels_first(self, channels_first):
261260
"""channels_first swaps axes"""
262261
path = self.get_temp_path('data.wav')
263-
data = get_wav_data('int32', 2, channels_first=channels_first)
262+
data = get_wav_data(
263+
'int16', 2, channels_first=channels_first, normalize=False)
264264
sox_io_backend.save(
265265
path, data, 8000, channels_first=channels_first)
266-
found = load_wav(path)[0]
266+
found = load_wav(path, normalize=False)[0]
267267
expected = data if channels_first else data.transpose(1, 0)
268268
self.assertEqual(found, expected)
269269

@@ -273,10 +273,12 @@ def test_channels_first(self, channels_first):
273273
def test_noncontiguous(self, dtype):
274274
"""Noncontiguous tensors are saved correctly"""
275275
path = self.get_temp_path('data.wav')
276-
expected = get_wav_data(dtype, 4)[::2, ::2]
276+
enc, bps = get_enc_params(dtype)
277+
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
277278
assert not expected.is_contiguous()
278-
sox_io_backend.save(path, expected, 8000)
279-
found = load_wav(path)[0]
279+
sox_io_backend.save(
280+
path, expected, 8000, encoding=enc, bits_per_sample=bps)
281+
found = load_wav(path, normalize=False)[0]
280282
self.assertEqual(found, expected)
281283

282284
@parameterized.expand([
@@ -285,7 +287,7 @@ def test_noncontiguous(self, dtype):
285287
def test_tensor_preserve(self, dtype):
286288
"""save function should not alter Tensor"""
287289
path = self.get_temp_path('data.wav')
288-
expected = get_wav_data(dtype, 4)[::2, ::2]
290+
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
289291

290292
data = expected.clone()
291293
sox_io_backend.save(path, data, 8000)

test/torchaudio_unittest/backend/sox_io/torchscript_test.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from .common import (
1919
name_func,
20+
get_enc_params,
2021
)
2122

2223

@@ -35,8 +36,12 @@ def py_save_func(
3536
sample_rate: int,
3637
channels_first: bool = True,
3738
compression: Optional[float] = None,
39+
encoding: Optional[str] = None,
40+
bits_per_sample: Optional[int] = None,
3841
):
39-
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)
42+
torchaudio.save(
43+
filepath, tensor, sample_rate, channels_first,
44+
compression, None, encoding, bits_per_sample)
4045

4146

4247
@skipIfNoExec('sox')
@@ -102,15 +107,16 @@ def test_save_wav(self, dtype, sample_rate, num_channels):
102107
torch.jit.script(py_save_func).save(script_path)
103108
ts_save_func = torch.jit.load(script_path)
104109

105-
expected = get_wav_data(dtype, num_channels)
110+
expected = get_wav_data(dtype, num_channels, normalize=False)
106111
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
107112
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
113+
enc, bps = get_enc_params(dtype)
108114

109-
py_save_func(py_path, expected, sample_rate, True, None)
110-
ts_save_func(ts_path, expected, sample_rate, True, None)
115+
py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
116+
ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps)
111117

112-
py_data, py_sr = load_wav(py_path)
113-
ts_data, ts_sr = load_wav(ts_path)
118+
py_data, py_sr = load_wav(py_path, normalize=False)
119+
ts_data, ts_sr = load_wav(ts_path, normalize=False)
114120

115121
self.assertEqual(sample_rate, py_sr)
116122
self.assertEqual(sample_rate, ts_sr)
@@ -131,8 +137,8 @@ def test_save_flac(self, sample_rate, num_channels, compression_level):
131137
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
132138
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')
133139

134-
py_save_func(py_path, expected, sample_rate, True, compression_level)
135-
ts_save_func(ts_path, expected, sample_rate, True, compression_level)
140+
py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
141+
ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)
136142

137143
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
138144
py_path_wav = f'{py_path}.wav'

torchaudio/backend/sox_io_backend.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -147,29 +147,6 @@ def load(
147147
filepath, frame_offset, num_frames, normalize, channels_first, format)
148148

149149

150-
@torch.jit.unused
151-
def _save(
152-
filepath: str,
153-
src: torch.Tensor,
154-
sample_rate: int,
155-
channels_first: bool = True,
156-
compression: Optional[float] = None,
157-
format: Optional[str] = None,
158-
encoding: Optional[str] = None,
159-
bits_per_sample: Optional[int] = None,
160-
):
161-
if hasattr(filepath, 'write'):
162-
if format is None:
163-
raise RuntimeError('`format` is required when saving to file object.')
164-
torchaudio._torchaudio.save_audio_fileobj(
165-
filepath, src, sample_rate, channels_first, compression,
166-
format, encoding, bits_per_sample)
167-
else:
168-
torch.ops.torchaudio.sox_io_save_audio_file(
169-
os.fspath(filepath), src, sample_rate, channels_first, compression,
170-
format, encoding, bits_per_sample)
171-
172-
173150
@_mod_utils.requires_module('torchaudio._torchaudio')
174151
def save(
175152
filepath: str,

0 commit comments

Comments
 (0)