16
16
)
17
17
from .common import (
18
18
name_func ,
19
+ get_enc_params ,
19
20
)
20
21
21
22
@@ -41,7 +42,6 @@ def assert_save_consistency(
41
42
sample_rate : float = 8000 ,
42
43
num_channels : int = 2 ,
43
44
num_frames : float = 3 * 8000 ,
44
- src_dtype : str = 'int32' ,
45
45
rtol : float = 1.3e-06 ,
46
46
atol : float = 1e-05 ,
47
47
):
@@ -52,7 +52,6 @@ def assert_save_consistency(
52
52
compression (float, optional): `compression` value for `save` function
53
53
encoding (str, optional): `encoding` value for `save` function
54
54
bits_per_sample (int, optional): `bits_per_sample` value for `save` function.
55
- src_dtype: (str, optional): Dtype for generating the source WAV file.
56
55
57
56
To compare the file produced by `save` function against the file produced by
58
57
the equivalent `sox` command, we need to load both files.
@@ -99,7 +98,7 @@ def assert_save_consistency(
99
98
ref_path = self .get_temp_path ('3.2.ref.wav' )
100
99
101
100
# 1. Generate original wav
102
- data = get_wav_data (src_dtype , num_channels , normalize = False , num_frames = num_frames )
101
+ data = get_wav_data ('int32' , num_channels , normalize = False , num_frames = num_frames )
103
102
save_wav (src_path , data , sample_rate )
104
103
105
104
# 2.1. Convert the original wav to target format with torchaudio
@@ -128,17 +127,15 @@ def assert_save_consistency(
128
127
129
128
class SaveTestEncode (SaveTestBase ):
130
129
@parameterized .expand ([
131
- ('PCM_U' , 8 , 'uint8' ),
132
- ('PCM_S' , 16 , 'int16' ),
133
- ('PCM_S' , 32 , 'int32' ),
134
- ('PCM_F' , 32 , 'float32' ),
135
- ('ULAW' , 8 , 'float32' ),
136
- ('ALAW' , 8 , 'float32' ),
130
+ ('PCM_U' , 8 ),
131
+ ('PCM_S' , 16 ),
132
+ ('PCM_S' , 32 ),
133
+ ('PCM_F' , 32 ),
134
+ ('ULAW' , 8 ),
135
+ ('ALAW' , 8 ),
137
136
], name_func = name_func )
138
- def test_wav (self , encoding , bits_per_sample , src_dtype ):
139
- self .assert_save_consistency (
140
- "wav" , encoding = encoding , bits_per_sample = bits_per_sample , src_dtype = src_dtype ,
141
- )
137
+ def test_wav (self , encoding , bits_per_sample ):
138
+ self .assert_save_consistency ("wav" , encoding = encoding , bits_per_sample = bits_per_sample )
142
139
143
140
@parameterized .expand ([
144
141
(None , ),
@@ -249,7 +246,9 @@ def test_large(self, format):
249
246
], name_func = name_func )
250
247
def test_multi_channels (self , num_channels ):
251
248
"""`sox_io_backend.save` can save audio with many channels"""
252
- self .assert_save_consistency ("wav" , num_channels = num_channels )
249
+ self .assert_save_consistency (
250
+ "wav" , encoding = "PCM_U" , bits_per_sample = 16 ,
251
+ num_channels = num_channels )
253
252
254
253
255
254
@skipIfNoExec ('sox' )
@@ -260,10 +259,11 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
260
259
def test_channels_first (self , channels_first ):
261
260
"""channels_first swaps axes"""
262
261
path = self .get_temp_path ('data.wav' )
263
- data = get_wav_data ('int32' , 2 , channels_first = channels_first )
262
+ data = get_wav_data (
263
+ 'int16' , 2 , channels_first = channels_first , normalize = False )
264
264
sox_io_backend .save (
265
265
path , data , 8000 , channels_first = channels_first )
266
- found = load_wav (path )[0 ]
266
+ found = load_wav (path , normalize = False )[0 ]
267
267
expected = data if channels_first else data .transpose (1 , 0 )
268
268
self .assertEqual (found , expected )
269
269
@@ -273,10 +273,12 @@ def test_channels_first(self, channels_first):
273
273
def test_noncontiguous (self , dtype ):
274
274
"""Noncontiguous tensors are saved correctly"""
275
275
path = self .get_temp_path ('data.wav' )
276
- expected = get_wav_data (dtype , 4 )[::2 , ::2 ]
276
+ enc , bps = get_enc_params (dtype )
277
+ expected = get_wav_data (dtype , 4 , normalize = False )[::2 , ::2 ]
277
278
assert not expected .is_contiguous ()
278
- sox_io_backend .save (path , expected , 8000 )
279
- found = load_wav (path )[0 ]
279
+ sox_io_backend .save (
280
+ path , expected , 8000 , encoding = enc , bits_per_sample = bps )
281
+ found = load_wav (path , normalize = False )[0 ]
280
282
self .assertEqual (found , expected )
281
283
282
284
@parameterized .expand ([
@@ -285,7 +287,7 @@ def test_noncontiguous(self, dtype):
285
287
def test_tensor_preserve (self , dtype ):
286
288
"""save function should not alter Tensor"""
287
289
path = self .get_temp_path ('data.wav' )
288
- expected = get_wav_data (dtype , 4 )[::2 , ::2 ]
290
+ expected = get_wav_data (dtype , 4 , normalize = False )[::2 , ::2 ]
289
291
290
292
data = expected .clone ()
291
293
sox_io_backend .save (path , data , 8000 )
0 commit comments