@@ -220,31 +220,134 @@ const std::string get_filetype(const std::string path) {
220
220
return ext;
221
221
}
222
222
223
- sox_encoding_t get_encoding (
224
- const std::string filetype,
225
- const caffe2::TypeMeta dtype) {
226
- if (filetype == " mp3" )
227
- return SOX_ENCODING_MP3;
228
- if (filetype == " flac" )
229
- return SOX_ENCODING_FLAC;
230
- if (filetype == " ogg" || filetype == " vorbis" )
231
- return SOX_ENCODING_VORBIS;
232
- if (filetype == " wav" || filetype == " amb" ) {
233
- if (dtype == torch::kUInt8 )
234
- return SOX_ENCODING_UNSIGNED;
235
- if (dtype == torch::kInt16 )
236
- return SOX_ENCODING_SIGN2;
237
- if (dtype == torch::kInt32 )
238
- return SOX_ENCODING_SIGN2;
239
- if (dtype == torch::kFloat32 )
240
- return SOX_ENCODING_FLOAT;
241
- throw std::runtime_error (" Unsupported dtype." );
223
+ namespace {
224
+
225
+ std::tuple<sox_encoding_t , unsigned > get_save_encoding_for_wav (
226
+ const std::string format,
227
+ const c10::optional<std::string>& encoding,
228
+ const c10::optional<int64_t >& bits_per_sample) {
229
+ if (!encoding.has_value ()) {
230
+ if (!bits_per_sample.has_value ())
231
+ return std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
232
+ auto val = static_cast <unsigned >(bits_per_sample.value ());
233
+ if (val == 8 )
234
+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
235
+ return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
242
236
}
243
- if (filetype == " sph" )
244
- return SOX_ENCODING_SIGN2;
245
- if (filetype == " amr-nb" )
246
- return SOX_ENCODING_AMR_NB;
247
- throw std::runtime_error (" Unsupported file type: " + filetype);
237
+ if (encoding == ENCODING_PCM_SIGNED) {
238
+ if (!bits_per_sample.has_value ())
239
+ return std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
240
+ auto val = static_cast <unsigned >(bits_per_sample.value ());
241
+ if (val == 8 ) {
242
+ TORCH_WARN_ONCE (" %s does not support 8-bit signed PCM encoding. Using 16-bit." , format);
243
+ val = 16 ;
244
+ }
245
+ return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
246
+ }
247
+ if (encoding == ENCODING_PCM_UNSIGNED) {
248
+ if (!bits_per_sample.has_value ())
249
+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
250
+ auto val = static_cast <unsigned >(bits_per_sample.value ());
251
+ if (val != 8 )
252
+ TORCH_WARN_ONCE (" %s only supports 8-bit for unsigned PCM encoding. Using 8-bit." , format);
253
+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
254
+ }
255
+ if (encoding == ENCODING_PCM_FLOAT) {
256
+ auto val = static_cast <unsigned >(bits_per_sample.value_or (32 ));
257
+ if (val != 32 )
258
+ TORCH_WARN_ONCE (" %s only supports 32-bit for floating point PCM encoding. Using 32-bit." , format);
259
+ return std::make_tuple<>(SOX_ENCODING_FLOAT, 32 );
260
+ }
261
+ if (encoding == ENCODING_ULAW) {
262
+ auto val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
263
+ if (val != 8 )
264
+ TORCH_WARN_ONCE (" %s only supports 8-bit for mu-law encoding. Using 8-bit." , format);
265
+ return std::make_tuple<>(SOX_ENCODING_ULAW, 8 );
266
+ }
267
+ if (encoding == ENCODING_ALAW) {
268
+ auto val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
269
+ if (val != 8 )
270
+ TORCH_WARN_ONCE (" %s only supports 8-bit for a-law encoding. Using 8-bit." , format);
271
+ return std::make_tuple<>(SOX_ENCODING_ALAW, 8 );
272
+ }
273
+ std::ostringstream message;
274
+ message << format << " format does not support encoding: " << encoding.value ();
275
+ throw std::runtime_error (message.str ());
276
+ }
277
+
278
+ std::tuple<sox_encoding_t , unsigned > get_save_encoding (
279
+ const std::string& format,
280
+ const c10::optional<std::string>& encoding,
281
+ const c10::optional<int64_t >& bits_per_sample) {
282
+ if (format == " mp3" ) {
283
+ if (encoding.has_value ()) {
284
+ TORCH_WARN_ONCE (" mp3 does not support `encoding` option. Ignoring." );
285
+ }
286
+ if (bits_per_sample.has_value ()) {
287
+ TORCH_WARN_ONCE (" mp3 does not `bits_per_sample` option. Ignoring." );
288
+ }
289
+ return std::make_tuple<>(SOX_ENCODING_MP3, 16 );
290
+ }
291
+ if (format == " ogg" || format == " vorbis" ) {
292
+ if (encoding.has_value ()) {
293
+ TORCH_WARN_ONCE (" ogg/vorbis does not support `encoding` option. Ignoring." );
294
+ }
295
+ if (bits_per_sample.has_value ()) {
296
+ TORCH_WARN_ONCE (" ogg/vorbis does not `bits_per_sample` option. Ignoring." );
297
+ }
298
+ return std::make_tuple<>(SOX_ENCODING_VORBIS, 16 );
299
+ }
300
+ if (format == " amr-nb" ) {
301
+ if (encoding.has_value ()) {
302
+ TORCH_WARN_ONCE (" amr-nb does not support `encoding` option. Ignoring." );
303
+ }
304
+ if (bits_per_sample.has_value ()) {
305
+ TORCH_WARN_ONCE (" amr-nb does not `bits_per_sample` option. Ignoring." );
306
+ }
307
+ return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16 );
308
+ }
309
+ if (format == " wav" || format == " amb" ) {
310
+ return get_save_encoding_for_wav (format, encoding, bits_per_sample);
311
+ }
312
+ if (format == " flac" ) {
313
+ if (encoding.has_value ()) {
314
+ TORCH_WARN_ONCE (" flac does not support `encoding` option. Ignoring." );
315
+ }
316
+ unsigned bps = [&](){
317
+ unsigned val = static_cast <unsigned >(bits_per_sample.value_or (24 ));
318
+ if (val > 24 ) {
319
+ TORCH_WARN_ONCE (" flac does not support bits_per_sample larger than 24. Using 24." );
320
+ val = 24 ;
321
+ }
322
+ return val;
323
+ }();
324
+ return std::make_tuple<>(SOX_ENCODING_FLAC, bps);
325
+ }
326
+ if (format == " sph" ) {
327
+ if (!encoding.has_value () || encoding == ENCODING_PCM_SIGNED) {
328
+ if (!bits_per_sample.has_value ())
329
+ return std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
330
+ auto val = static_cast <unsigned >(bits_per_sample.value ());
331
+ return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
332
+ }
333
+ if (encoding == ENCODING_PCM_UNSIGNED || encoding == ENCODING_PCM_FLOAT) {
334
+ TORCH_WARN_ONCE (" sph does not support unsigned integer PCM or floating point PCM. Using signed interger PCM" );
335
+ auto val = static_cast <unsigned >(bits_per_sample.value_or (16 ));
336
+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, val);
337
+ }
338
+ if (encoding == ENCODING_ULAW) {
339
+ auto val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
340
+ if (val != 8 )
341
+ TORCH_WARN_ONCE (" sph only supports 8-bit for mu-law encoding. Using 8-bit." );
342
+ return std::make_tuple<>(SOX_ENCODING_ULAW, 8 );
343
+ }
344
+ if (encoding == ENCODING_ALAW) {
345
+ auto val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
346
+ return std::make_tuple<>(SOX_ENCODING_ALAW, val);
347
+ }
348
+ throw std::runtime_error (" sph format does not support encoding: " + encoding.value ());
349
+ }
350
+ throw std::runtime_error (" Unsupported format: " + format);
248
351
}
249
352
250
353
unsigned get_precision (
@@ -278,6 +381,8 @@ unsigned get_precision(
278
381
throw std::runtime_error (" Unsupported file type: " + filetype);
279
382
}
280
383
384
+ } // namespace
385
+
281
386
sox_signalinfo_t get_signalinfo (
282
387
const torch::Tensor* waveform,
283
388
const int64_t sample_rate,
@@ -326,12 +431,14 @@ sox_encodinginfo_t get_tensor_encodinginfo(
326
431
}
327
432
328
433
sox_encodinginfo_t get_encodinginfo_for_save (
329
- const std::string filetype,
330
- const caffe2::TypeMeta dtype,
331
- c10::optional<double >& compression) {
434
+ const std::string& format,
435
+ const c10::optional<double >& compression,
436
+ const c10::optional<std::string>& encoding,
437
+ const c10::optional<int64_t >& bits_per_sample) {
438
+ auto enc = get_save_encoding (format, encoding, bits_per_sample);
332
439
return sox_encodinginfo_t {
333
- /* encoding=*/ get_encoding (filetype, dtype ),
334
- /* bits_per_sample=*/ get_precision (filetype, dtype ),
440
+ /* encoding=*/ std::get< 0 >(enc ),
441
+ /* bits_per_sample=*/ std::get< 1 >(enc ),
335
442
/* compression=*/ compression.value_or (HUGE_VAL),
336
443
/* reverse_bytes=*/ sox_option_default,
337
444
/* reverse_nibbles=*/ sox_option_default,
0 commit comments