Skip to content

Commit 135e966

Browse files
authored
Refactor get_encodinginfo logic (#1233)
* Distinguish get_encodinginfo for Tensor I/O from get_encodinginfo for saving output
* Isolate get_tensor_encodinginfo so that it does not use the same helper function
1 parent 8b93bd6 commit 135e966

File tree

4 files changed

+40
-19
lines changed

4 files changed

+40
-19
lines changed

torchaudio/csrc/sox/effects.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ std::tuple<torch::Tensor, int64_t> apply_effects_tensor(
6060
// Create SoxEffectsChain
6161
const auto dtype = waveform.dtype();
6262
torchaudio::sox_effects_chain::SoxEffectsChain chain(
63-
/*input_encoding=*/get_encodinginfo("wav", dtype),
64-
/*output_encoding=*/get_encodinginfo("wav", dtype));
63+
/*input_encoding=*/get_tensor_encodinginfo(dtype),
64+
/*output_encoding=*/get_tensor_encodinginfo(dtype));
6565

6666
// Prepare output buffer
6767
std::vector<sox_sample_t> out_buffer;
@@ -112,7 +112,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_file(
112112
// Create and run SoxEffectsChain
113113
torchaudio::sox_effects_chain::SoxEffectsChain chain(
114114
/*input_encoding=*/sf->encoding,
115-
/*output_encoding=*/get_encodinginfo("wav", dtype));
115+
/*output_encoding=*/get_tensor_encodinginfo(dtype));
116116

117117
chain.addInputFile(sf);
118118
for (const auto& effect : effects) {
@@ -214,7 +214,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
214214
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
215215
torchaudio::sox_effects_chain::SoxEffectsChain chain(
216216
/*input_encoding=*/sf->encoding,
217-
/*output_encoding=*/get_encodinginfo("wav", dtype));
217+
/*output_encoding=*/get_tensor_encodinginfo(dtype));
218218
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
219219
for (const auto& effect : effects) {
220220
chain.addEffect(effect);

torchaudio/csrc/sox/io.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ void save_audio_file(
143143
}
144144
const auto signal_info =
145145
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
146-
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);
146+
const auto encoding_info =
147+
get_encodinginfo_for_save(filetype, tgt_dtype, compression);
147148

148149
SoxFormat sf(sox_open_write(
149150
path.c_str(),
@@ -158,7 +159,7 @@ void save_audio_file(
158159
}
159160

160161
torchaudio::sox_effects_chain::SoxEffectsChain chain(
161-
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype()),
162+
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
162163
/*output_encoding=*/sf->encoding);
163164
chain.addInputTensor(&tensor, sample_rate, channels_first);
164165
chain.addOutputFile(sf);
@@ -281,7 +282,8 @@ void save_audio_fileobj(
281282
}
282283
const auto signal_info =
283284
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
284-
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);
285+
const auto encoding_info =
286+
get_encodinginfo_for_save(filetype, tgt_dtype, compression);
285287

286288
AutoReleaseBuffer buffer;
287289

@@ -299,7 +301,7 @@ void save_audio_fileobj(
299301
}
300302

301303
torchaudio::sox_effects_chain::SoxEffectsChain chain(
302-
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype()),
304+
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
303305
/*output_encoding=*/sf->encoding);
304306
chain.addInputTensor(&tensor, sample_rate, channels_first);
305307
chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj);

torchaudio/csrc/sox/utils.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,20 +291,40 @@ sox_signalinfo_t get_signalinfo(
291291
/*length=*/static_cast<uint64_t>(waveform->numel())};
292292
}
293293

294-
sox_encodinginfo_t get_encodinginfo(
295-
const std::string filetype,
296-
const caffe2::TypeMeta dtype) {
294+
sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
295+
sox_encoding_t encoding = [&]() {
296+
if (dtype == torch::kUInt8)
297+
return SOX_ENCODING_UNSIGNED;
298+
if (dtype == torch::kInt16)
299+
return SOX_ENCODING_SIGN2;
300+
if (dtype == torch::kInt32)
301+
return SOX_ENCODING_SIGN2;
302+
if (dtype == torch::kFloat32)
303+
return SOX_ENCODING_FLOAT;
304+
throw std::runtime_error("Unsupported dtype.");
305+
}();
306+
unsigned bits_per_sample = [&]() {
307+
if (dtype == torch::kUInt8)
308+
return 8;
309+
if (dtype == torch::kInt16)
310+
return 16;
311+
if (dtype == torch::kInt32)
312+
return 32;
313+
if (dtype == torch::kFloat32)
314+
return 32;
315+
throw std::runtime_error("Unsupported dtype.");
316+
}();
297317
return sox_encodinginfo_t{
298-
/*encoding=*/get_encoding(filetype, dtype),
299-
/*bits_per_sample=*/get_precision(filetype, dtype),
318+
/*encoding=*/encoding,
319+
/*bits_per_sample=*/bits_per_sample,
300320
/*compression=*/HUGE_VAL,
301321
/*reverse_bytes=*/sox_option_default,
302322
/*reverse_nibbles=*/sox_option_default,
303323
/*reverse_bits=*/sox_option_default,
304324
/*opposite_endian=*/sox_false};
305325
}
306326

307-
sox_encodinginfo_t get_encodinginfo(
327+
sox_encodinginfo_t get_encodinginfo_for_save(
308328
const std::string filetype,
309329
const caffe2::TypeMeta dtype,
310330
c10::optional<double>& compression) {

torchaudio/csrc/sox/utils.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,11 @@ sox_signalinfo_t get_signalinfo(
108108
const std::string filetype,
109109
const bool channels_first);
110110

111-
/// Get sox_encofinginfo_t for saving audoi file
112-
sox_encodinginfo_t get_encodinginfo(
113-
const std::string filetype,
114-
const caffe2::TypeMeta dtype);
111+
/// Get sox_encodinginfo_t for Tensor I/O
112+
sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype);
115113

116-
sox_encodinginfo_t get_encodinginfo(
114+
/// Get sox_encodinginfo_t for saving to file/file object
115+
sox_encodinginfo_t get_encodinginfo_for_save(
117116
const std::string filetype,
118117
const caffe2::TypeMeta dtype,
119118
c10::optional<double>& compression);

0 commit comments

Comments
 (0)