Skip to content

Commit b119836

Browse files
committed
WIP: adding encoding and bits_per_sample option
1 parent 58ac6b9 commit b119836

File tree

5 files changed

+193
-89
lines changed

5 files changed

+193
-89
lines changed

torchaudio/backend/sox_io_backend.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,19 @@ def _save(
175175
channels_first: bool = True,
176176
compression: Optional[float] = None,
177177
format: Optional[str] = None,
178-
dtype: Optional[str] = None,
178+
encoding: Optional[str] = None,
179+
bits_per_sample: Optional[int] = None,
179180
):
180181
if hasattr(filepath, 'write'):
181182
if format is None:
182183
raise RuntimeError('`format` is required when saving to file object.')
183184
torchaudio._torchaudio.save_audio_fileobj(
184-
filepath, src, sample_rate, channels_first, compression, format, dtype)
185+
filepath, src, sample_rate, channels_first, compression,
186+
format, encoding, bits_per_sample)
185187
else:
186188
torch.ops.torchaudio.sox_io_save_audio_file(
187-
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)
189+
os.fspath(filepath), src, sample_rate, channels_first, compression,
190+
format, encoding, bits_per_sample)
188191

189192

190193
@_mod_utils.requires_module('torchaudio._torchaudio')
@@ -195,7 +198,8 @@ def save(
195198
channels_first: bool = True,
196199
compression: Optional[float] = None,
197200
format: Optional[str] = None,
198-
dtype: Optional[str] = None,
201+
encoding: Optional[str] = None,
202+
bits_per_sample: Optional[int] = None,
199203
):
200204
"""Save audio data to file.
201205
@@ -248,16 +252,11 @@ def save(
248252
``dtype=None`` means no conversion is performed.
249253
``dtype`` parameter is only effective for ``float32`` Tensor.
250254
"""
251-
if src.dtype == torch.float32 and dtype is None:
252-
warnings.warn(
253-
'`dtype` default value will be changed to `int16` in 0.9 release.'
254-
'Specify `dtype` to suppress this warning.'
255-
)
256255
if not torch.jit.is_scripting():
257-
_save(filepath, src, sample_rate, channels_first, compression, format, dtype)
256+
_save(filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
258257
return
259258
torch.ops.torchaudio.sox_io_save_audio_file(
260-
filepath, src, sample_rate, channels_first, compression, format, dtype)
259+
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
261260

262261

263262
@_mod_utils.requires_module('torchaudio._torchaudio')

torchaudio/csrc/sox/io.cpp

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -46,32 +46,31 @@ namespace {
4646

4747
// Translate a libsox `sox_encoding_t` value into the encoding name reported
// to callers. Values without a dedicated case (and SOX_ENCODING_UNKNOWN) map
// to the "unknown" name.
// NOTE: this is a diff rendering -- the removed lines return string literals,
// the added lines return the ENCODING_* constants that replace them.
std::string get_encoding(sox_encoding_t encoding) {
4848
switch (encoding) {
49-
case SOX_ENCODING_UNKNOWN:
50-
return "UNKNOWN";
5149
case SOX_ENCODING_SIGN2:
52-
return "PCM_S";
50+
return ENCODING_PCM_SIGNED;
5351
case SOX_ENCODING_UNSIGNED:
54-
return "PCM_U";
52+
return ENCODING_PCM_UNSIGNED;
5553
case SOX_ENCODING_FLOAT:
56-
return "PCM_F";
54+
return ENCODING_PCM_FLOAT;
5755
case SOX_ENCODING_FLAC:
58-
return "FLAC";
56+
return ENCODING_FLAC;
5957
case SOX_ENCODING_ULAW:
60-
return "ULAW";
58+
return ENCODING_ULAW;
6159
case SOX_ENCODING_ALAW:
62-
return "ALAW";
60+
return ENCODING_ALAW;
6361
case SOX_ENCODING_MP3:
64-
return "MP3";
62+
return ENCODING_MP3;
6563
case SOX_ENCODING_VORBIS:
66-
return "VORBIS";
64+
return ENCODING_VORBIS;
6765
case SOX_ENCODING_AMR_WB:
68-
return "AMR_WB";
66+
return ENCODING_AMR_WB;
6967
case SOX_ENCODING_AMR_NB:
70-
return "AMR_NB";
68+
return ENCODING_AMR_NB;
7169
case SOX_ENCODING_OPUS:
72-
return "OPUS";
70+
return ENCODING_OPUS;
71+
case SOX_ENCODING_UNKNOWN:
7372
default:
74-
return "UNKNOWN";
73+
return ENCODING_UNKNOWN;
7574
}
7675
}
7776

@@ -148,34 +147,26 @@ void save_audio_file(
148147
torch::Tensor tensor,
149148
int64_t sample_rate,
150149
bool channels_first,
151-
c10::optional<double> compression,
152-
c10::optional<std::string> format,
153-
c10::optional<std::string> dtype) {
150+
c10::optional<double>& compression,
151+
c10::optional<std::string>& format,
152+
c10::optional<std::string>& encoding,
153+
c10::optional<int64_t>& bits_per_sample) {
154154
validate_input_tensor(tensor);
155155

156-
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
157-
throw std::runtime_error(
158-
"dtype conversion only supported for float32 tensors");
159-
}
160-
const auto tgt_dtype =
161-
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
162-
? get_dtype_from_str(dtype.value())
163-
: tensor.dtype();
164-
165156
const auto filetype = [&]() {
166157
if (format.has_value())
167158
return format.value();
168159
return get_filetype(path);
169160
}();
161+
170162
if (filetype == "amr-nb") {
171163
const auto num_channels = tensor.size(channels_first ? 0 : 1);
172164
TORCH_CHECK(
173165
num_channels == 1, "amr-nb format only supports single channel audio.");
174-
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
175166
}
176167
const auto signal_info =
177168
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
178-
const auto encoding_info = get_encodinginfo_for_save(filetype, tgt_dtype, compression);
169+
const auto encoding_info = get_encodinginfo_for_save(filetype, compression, encoding, bits_per_sample);
179170

180171
SoxFormat sf(sox_open_write(
181172
path.c_str(),
@@ -289,31 +280,22 @@ void save_audio_fileobj(
289280
torch::Tensor tensor,
290281
int64_t sample_rate,
291282
bool channels_first,
292-
c10::optional<double> compression,
283+
c10::optional<double>& compression,
293284
std::string filetype,
294-
c10::optional<std::string> dtype) {
285+
c10::optional<std::string>& encoding,
286+
c10::optional<int64_t>& bits_per_sample) {
295287
validate_input_tensor(tensor);
296288

297-
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
298-
throw std::runtime_error(
299-
"dtype conversion only supported for float32 tensors");
300-
}
301-
const auto tgt_dtype =
302-
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
303-
? get_dtype_from_str(dtype.value())
304-
: tensor.dtype();
305-
306289
if (filetype == "amr-nb") {
307290
const auto num_channels = tensor.size(channels_first ? 0 : 1);
308291
if (num_channels != 1) {
309292
throw std::runtime_error(
310293
"amr-nb format only supports single channel audio.");
311294
}
312-
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
313295
}
314296
const auto signal_info =
315297
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
316-
const auto encoding_info = get_encodinginfo_for_save(filetype, tgt_dtype, compression);
298+
const auto encoding_info = get_encodinginfo_for_save(filetype, compression, encoding, bits_per_sample);
317299

318300
AutoReleaseBuffer buffer;
319301

torchaudio/csrc/sox/io.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ void save_audio_file(
4848
torch::Tensor tensor,
4949
int64_t sample_rate,
5050
bool channels_first,
51-
c10::optional<double> compression,
52-
c10::optional<std::string> format,
53-
c10::optional<std::string> dtype);
51+
c10::optional<double>& compression,
52+
c10::optional<std::string>& format,
53+
c10::optional<std::string>& encoding,
54+
c10::optional<int64_t>& bits_per_sample);
5455

5556
#ifdef TORCH_API_INCLUDE_EXTENSION_H
5657

@@ -71,9 +72,10 @@ void save_audio_fileobj(
7172
torch::Tensor tensor,
7273
int64_t sample_rate,
7374
bool channels_first,
74-
c10::optional<double> compression,
75+
c10::optional<double>& compression,
7576
std::string filetype,
76-
c10::optional<std::string> dtype);
77+
c10::optional<std::string>& encoding,
78+
c10::optional<int64_t>& bits_per_sample);
7779

7880
#endif // TORCH_API_INCLUDE_EXTENSION_H
7981

torchaudio/csrc/sox/utils.cpp

Lines changed: 136 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -220,31 +220,134 @@ const std::string get_filetype(const std::string path) {
220220
return ext;
221221
}
222222

223-
sox_encoding_t get_encoding(
224-
const std::string filetype,
225-
const caffe2::TypeMeta dtype) {
226-
if (filetype == "mp3")
227-
return SOX_ENCODING_MP3;
228-
if (filetype == "flac")
229-
return SOX_ENCODING_FLAC;
230-
if (filetype == "ogg" || filetype == "vorbis")
231-
return SOX_ENCODING_VORBIS;
232-
if (filetype == "wav" || filetype == "amb") {
233-
if (dtype == torch::kUInt8)
234-
return SOX_ENCODING_UNSIGNED;
235-
if (dtype == torch::kInt16)
236-
return SOX_ENCODING_SIGN2;
237-
if (dtype == torch::kInt32)
238-
return SOX_ENCODING_SIGN2;
239-
if (dtype == torch::kFloat32)
240-
return SOX_ENCODING_FLOAT;
241-
throw std::runtime_error("Unsupported dtype.");
223+
namespace {
224+
225+
std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
226+
const std::string format,
227+
const c10::optional<std::string>& encoding,
228+
const c10::optional<int64_t>& bits_per_sample) {
229+
if (!encoding.has_value()) {
230+
if (!bits_per_sample.has_value())
231+
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
232+
auto val = static_cast<unsigned>(bits_per_sample.value());
233+
if (val == 8)
234+
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
235+
return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
242236
}
243-
if (filetype == "sph")
244-
return SOX_ENCODING_SIGN2;
245-
if (filetype == "amr-nb")
246-
return SOX_ENCODING_AMR_NB;
247-
throw std::runtime_error("Unsupported file type: " + filetype);
237+
if (encoding == ENCODING_PCM_SIGNED) {
238+
if (!bits_per_sample.has_value())
239+
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
240+
auto val = static_cast<unsigned>(bits_per_sample.value());
241+
if (val == 8) {
242+
TORCH_WARN_ONCE("%s does not support 8-bit signed PCM encoding. Using 16-bit.", format);
243+
val = 16;
244+
}
245+
return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
246+
}
247+
if (encoding == ENCODING_PCM_UNSIGNED) {
248+
if (!bits_per_sample.has_value())
249+
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
250+
auto val = static_cast<unsigned>(bits_per_sample.value());
251+
if (val != 8)
252+
TORCH_WARN_ONCE("%s only supports 8-bit for unsigned PCM encoding. Using 8-bit.", format);
253+
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
254+
}
255+
if (encoding == ENCODING_PCM_FLOAT) {
256+
auto val = static_cast<unsigned>(bits_per_sample.value_or(32));
257+
if (val != 32)
258+
TORCH_WARN_ONCE("%s only supports 32-bit for floating point PCM encoding. Using 32-bit.", format);
259+
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
260+
}
261+
if (encoding == ENCODING_ULAW) {
262+
auto val = static_cast<unsigned>(bits_per_sample.value_or(8));
263+
if (val != 8)
264+
TORCH_WARN_ONCE("%s only supports 8-bit for mu-law encoding. Using 8-bit.", format);
265+
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
266+
}
267+
if (encoding == ENCODING_ALAW) {
268+
auto val = static_cast<unsigned>(bits_per_sample.value_or(8));
269+
if (val != 8)
270+
TORCH_WARN_ONCE("%s only supports 8-bit for a-law encoding. Using 8-bit.", format);
271+
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
272+
}
273+
std::ostringstream message;
274+
message << format << " format does not support encoding: " << encoding.value();
275+
throw std::runtime_error(message.str());
276+
}
277+
278+
std::tuple<sox_encoding_t, unsigned> get_save_encoding(
279+
const std::string& format,
280+
const c10::optional<std::string>& encoding,
281+
const c10::optional<int64_t>& bits_per_sample) {
282+
if (format == "mp3") {
283+
if (encoding.has_value()) {
284+
TORCH_WARN_ONCE("mp3 does not support `encoding` option. Ignoring.");
285+
}
286+
if (bits_per_sample.has_value()) {
287+
TORCH_WARN_ONCE("mp3 does not `bits_per_sample` option. Ignoring.");
288+
}
289+
return std::make_tuple<>(SOX_ENCODING_MP3, 16);
290+
}
291+
if (format == "ogg" || format == "vorbis") {
292+
if (encoding.has_value()) {
293+
TORCH_WARN_ONCE("ogg/vorbis does not support `encoding` option. Ignoring.");
294+
}
295+
if (bits_per_sample.has_value()) {
296+
TORCH_WARN_ONCE("ogg/vorbis does not `bits_per_sample` option. Ignoring.");
297+
}
298+
return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
299+
}
300+
if (format == "amr-nb") {
301+
if (encoding.has_value()) {
302+
TORCH_WARN_ONCE("amr-nb does not support `encoding` option. Ignoring.");
303+
}
304+
if (bits_per_sample.has_value()) {
305+
TORCH_WARN_ONCE("amr-nb does not `bits_per_sample` option. Ignoring.");
306+
}
307+
return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
308+
}
309+
if (format == "wav" || format == "amb") {
310+
return get_save_encoding_for_wav(format, encoding, bits_per_sample);
311+
}
312+
if (format == "flac") {
313+
if (encoding.has_value()) {
314+
TORCH_WARN_ONCE("flac does not support `encoding` option. Ignoring.");
315+
}
316+
unsigned bps = [&](){
317+
unsigned val = static_cast<unsigned>(bits_per_sample.value_or(24));
318+
if (val > 24) {
319+
TORCH_WARN_ONCE("flac does not support bits_per_sample larger than 24. Using 24.");
320+
val = 24;
321+
}
322+
return val;
323+
}();
324+
return std::make_tuple<>(SOX_ENCODING_FLAC, bps);
325+
}
326+
if (format == "sph") {
327+
if (!encoding.has_value() || encoding == ENCODING_PCM_SIGNED) {
328+
if (!bits_per_sample.has_value())
329+
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
330+
auto val = static_cast<unsigned>(bits_per_sample.value());
331+
return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
332+
}
333+
if (encoding == ENCODING_PCM_UNSIGNED || encoding == ENCODING_PCM_FLOAT) {
334+
TORCH_WARN_ONCE("sph does not support unsigned integer PCM or floating point PCM. Using signed interger PCM");
335+
auto val = static_cast<unsigned>(bits_per_sample.value_or(16));
336+
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, val);
337+
}
338+
if (encoding == ENCODING_ULAW) {
339+
auto val = static_cast<unsigned>(bits_per_sample.value_or(8));
340+
if (val != 8)
341+
TORCH_WARN_ONCE("sph only supports 8-bit for mu-law encoding. Using 8-bit.");
342+
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
343+
}
344+
if (encoding == ENCODING_ALAW) {
345+
auto val = static_cast<unsigned>(bits_per_sample.value_or(8));
346+
return std::make_tuple<>(SOX_ENCODING_ALAW, val);
347+
}
348+
throw std::runtime_error("sph format does not support encoding: " + encoding.value());
349+
}
350+
throw std::runtime_error("Unsupported format: " + format);
248351
}
249352

250353
unsigned get_precision(
@@ -278,6 +381,8 @@ unsigned get_precision(
278381
throw std::runtime_error("Unsupported file type: " + filetype);
279382
}
280383

384+
} // namespace
385+
281386
sox_signalinfo_t get_signalinfo(
282387
const torch::Tensor* waveform,
283388
const int64_t sample_rate,
@@ -326,12 +431,14 @@ sox_encodinginfo_t get_tensor_encodinginfo(
326431
}
327432

328433
sox_encodinginfo_t get_encodinginfo_for_save(
329-
const std::string filetype,
330-
const caffe2::TypeMeta dtype,
331-
c10::optional<double>& compression) {
434+
const std::string& format,
435+
const c10::optional<double>& compression,
436+
const c10::optional<std::string>& encoding,
437+
const c10::optional<int64_t>& bits_per_sample) {
438+
auto enc = get_save_encoding(format, encoding, bits_per_sample);
332439
return sox_encodinginfo_t{
333-
/*encoding=*/get_encoding(filetype, dtype),
334-
/*bits_per_sample=*/get_precision(filetype, dtype),
440+
/*encoding=*/std::get<0>(enc),
441+
/*bits_per_sample=*/std::get<1>(enc),
335442
/*compression=*/compression.value_or(HUGE_VAL),
336443
/*reverse_bytes=*/sox_option_default,
337444
/*reverse_nibbles=*/sox_option_default,

0 commit comments

Comments
 (0)