Skip to content

Commit d48089a

Browse files
parmeetcpuhrsch
authored andcommitted
Import torchaudio #1412 c0bfb03
Summary: Import latest from github to fbcode Pass: 951 Skip: 19 Omit: 1 ListingSuccess: 26 Result available at: https://www.internalfb.com/intern/testinfra/testrun/8444249336935844 Reviewed By: mthrok Differential Revision: D27448988 fbshipit-source-id: 61f63ffa1295a31b4452abaf2c74ebfefb827dcf
1 parent f88723c commit d48089a

File tree

12 files changed

+153
-103
lines changed

12 files changed

+153
-103
lines changed

docs/source/functional.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ apply_codec
5555
-----------
5656

5757
.. autofunction:: apply_codec
58-
58+
5959
:hidden:`Complex Utility`
6060
~~~~~~~~~~~~~~~~~~~~~~~~~
6161

@@ -230,3 +230,8 @@ vad
230230
---------------------------
231231

232232
.. autofunction:: spectral_centroid
233+
234+
:hidden:`resample`
235+
---------------------------
236+
237+
.. autofunction:: resample

test/torchaudio_unittest/common_utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
skipIfNoModule,
1818
skipIfNoKaldi,
1919
skipIfNoSox,
20-
skipIfNoSoxBackend,
2120
)
2221
from .wav_utils import (
2322
get_wav_data,

test/torchaudio_unittest/common_utils/case_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import torch
99
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
10-
import torchaudio
1110
from torchaudio._internal.module_utils import (
1211
is_module_available,
1312
is_sox_available,
@@ -96,8 +95,6 @@ def skipIfNoModule(module, display_name=None):
9695
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
9796

9897

99-
skipIfNoSoxBackend = unittest.skipIf(
100-
'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
10198
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
10299
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')
103100
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')

test/torchaudio_unittest/compliance_kaldi_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
4747

4848
@common_utils.skipIfNoSox
4949
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
50-
backend = 'sox_io'
5150

5251
kaldi_output_dir = common_utils.get_asset_path('kaldi')
5352
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
@@ -113,7 +112,7 @@ def _create_data_set(self):
113112
# clear the last 16 bits because they aren't used anyways
114113
y = ((y >> 16) << 16).float()
115114
torchaudio.save(self.test_filepath, y, sr)
116-
sound, sample_rate = torchaudio.load(self.test_filepath, normalization=False)
115+
sound, sample_rate = common_utils.load_wav(self.test_filepath, normalize=False)
117116
print(y >> 16)
118117
self.assertTrue(sample_rate == sr)
119118
self.assertEqual(y, sound)

test/torchaudio_unittest/functional/librosa_compatibility_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,25 @@ def test_amplitude_to_DB(self):
109109

110110
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
111111

112+
def test_resample(self):
113+
input_path = common_utils.get_asset_path('sinewave.wav')
114+
waveform, sample_rate = common_utils.load_wav(input_path)
115+
116+
upsample_rate = sample_rate * 2
117+
downsample_rate = sample_rate // 2
118+
119+
ta_upsampled = F.resample(waveform, sample_rate, upsample_rate)
120+
lr_upsampled = librosa.resample(waveform.squeeze(0).numpy(), sample_rate, upsample_rate)
121+
lr_upsampled = torch.from_numpy(lr_upsampled).unsqueeze(0)
122+
123+
self.assertEqual(ta_upsampled, lr_upsampled, atol=1e-2, rtol=1e-5)
124+
125+
ta_downsampled = F.resample(waveform, sample_rate, downsample_rate)
126+
lr_downsampled = librosa.resample(waveform.squeeze(0).numpy(), sample_rate, downsample_rate)
127+
lr_downsampled = torch.from_numpy(lr_downsampled).unsqueeze(0)
128+
129+
self.assertEqual(ta_downsampled, lr_downsampled, atol=1e-2, rtol=1e-5)
130+
112131

113132
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
114133
class TestPhaseVocoder(common_utils.TorchaudioTestCase):

test/torchaudio_unittest/functional/sox_compatibility_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torchaudio.functional as F
33

44
from torchaudio_unittest.common_utils import (
5-
skipIfNoSoxBackend,
5+
skipIfNoSox,
66
skipIfNoExec,
77
TempDirMixin,
88
TorchaudioTestCase,
@@ -14,7 +14,7 @@
1414
)
1515

1616

17-
@skipIfNoSoxBackend
17+
@skipIfNoSox
1818
@skipIfNoExec('sox')
1919
class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
2020
def run_sox_effect(self, input_file, effect):

test/torchaudio_unittest/transforms/sox_compatibility_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22
from parameterized import parameterized
33

44
from torchaudio_unittest.common_utils import (
5-
skipIfNoSoxBackend,
5+
skipIfNoSox,
66
skipIfNoExec,
77
TempDirMixin,
88
TorchaudioTestCase,
99
get_asset_path,
1010
sox_utils,
1111
load_wav,
12+
save_wav,
13+
get_whitenoise,
1214
)
1315

1416

15-
@skipIfNoSoxBackend
17+
@skipIfNoSox
1618
@skipIfNoExec('sox')
1719
class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
1820
def run_sox_effect(self, input_file, effect):
@@ -24,6 +26,14 @@ def assert_sox_effect(self, result, input_path, effects, atol=1e-04, rtol=1e-5):
2426
expected, _ = self.run_sox_effect(input_path, effects)
2527
self.assertEqual(result, expected, atol=atol, rtol=rtol)
2628

29+
def get_whitenoise(self, sample_rate=8000):
30+
noise = get_whitenoise(
31+
sample_rate=sample_rate, duration=3, scale_factor=0.9,
32+
)
33+
path = self.get_temp_path("whitenoise.wav")
34+
save_wav(path, noise, sample_rate)
35+
return noise, path
36+
2737
@parameterized.expand([
2838
('q', 'quarter_sine'),
2939
('h', 'half_sine'),

torchaudio/compliance/kaldi.py

Lines changed: 4 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import math
44
import torch
55
from torch import Tensor
6-
from torch.nn import functional as F
76

87
import torchaudio
98
import torchaudio._internal.fft
@@ -753,71 +752,16 @@ def mfcc(
753752
return feature
754753

755754

756-
def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
757-
device: torch.device, dtype: torch.dtype):
758-
assert lowpass_filter_width > 0
759-
kernels = []
760-
base_freq = min(orig_freq, new_freq)
761-
# This will perform antialiasing filtering by removing the highest frequencies.
762-
# At first I thought I only needed this when downsampling, but when upsampling
763-
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
764-
# which will add high freq artifacts.
765-
base_freq *= 0.99
766-
767-
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
768-
# using the sinc interpolation formula:
769-
# x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
770-
# We can then sample the function x(t) with a different sample rate:
771-
# y[j] = x(j / new_freq)
772-
# or,
773-
# y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
774-
775-
# We see here that y[j] is the convolution of x[i] with a specific filter, for which
776-
# we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
777-
# But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
778-
# Indeed:
779-
# y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
780-
# = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
781-
# = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
782-
# so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
783-
# This will explain the F.conv1d after, with a stride of orig_freq.
784-
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
785-
# If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
786-
# they will have a lot of almost zero values to the left or to the right...
787-
# There is probably a way to evaluate those filters more efficiently, but this is kept for
788-
# future work.
789-
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
790-
791-
for i in range(new_freq):
792-
t = (-i / new_freq + idx / orig_freq) * base_freq
793-
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
794-
t *= math.pi
795-
# we do not use torch.hann_window here as we need to evaluate the window
796-
# at specific positions, not over a regular grid.
797-
window = torch.cos(t / lowpass_filter_width / 2)**2
798-
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
799-
kernel.mul_(window)
800-
kernels.append(kernel)
801-
802-
scale = base_freq / orig_freq
803-
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
804-
805-
806755
def resample_waveform(waveform: Tensor,
807756
orig_freq: float,
808757
new_freq: float,
809758
lowpass_filter_width: int = 6) -> Tensor:
810-
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
811-
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
812-
a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
813-
the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
814-
upsample/downsample the signal.
759+
r"""Resamples the waveform at the new frequency.
815760
816-
https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
817-
https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
761+
This is a wrapper around ``torchaudio.functional.resample``.
818762
819763
Args:
820-
waveform (Tensor): The input signal of size (c, n)
764+
waveform (Tensor): The input signal of size (..., time)
821765
orig_freq (float): The original frequency of the signal
822766
new_freq (float): The desired frequency
823767
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
@@ -826,21 +770,4 @@ def resample_waveform(waveform: Tensor,
826770
Returns:
827771
Tensor: The waveform at the new frequency
828772
"""
829-
assert waveform.dim() == 2
830-
assert orig_freq > 0.0 and new_freq > 0.0
831-
832-
orig_freq = int(orig_freq)
833-
new_freq = int(new_freq)
834-
gcd = math.gcd(orig_freq, new_freq)
835-
orig_freq = orig_freq // gcd
836-
new_freq = new_freq // gcd
837-
838-
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
839-
waveform.device, waveform.dtype)
840-
841-
num_wavs, length = waveform.shape
842-
waveform = F.pad(waveform, (width, width + orig_freq))
843-
resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
844-
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
845-
target_length = int(math.ceil(new_freq * length / orig_freq))
846-
return resampled[..., :target_length]
773+
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width)

torchaudio/datasets/yesno.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"release1": {
1717
"folder_in_archive": "waves_yesno",
1818
"url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
19-
"checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
19+
"checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73",
2020
}
2121
}
2222

@@ -54,7 +54,7 @@ def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, downloa
5454
if not os.path.isdir(self._path):
5555
if not os.path.isfile(archive):
5656
checksum = _RELEASE_CONFIGS["release1"]["checksum"]
57-
download_url(url, root, hash_value=checksum, hash_type="md5")
57+
download_url(url, root, hash_value=checksum)
5858
extract_archive(archive)
5959

6060
if not os.path.isdir(self._path):

torchaudio/functional/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
spectrogram,
2020
spectral_centroid,
2121
apply_codec,
22+
resample,
2223
)
2324
from .filtering import (
2425
allpass_biquad,
@@ -85,5 +86,6 @@
8586
'riaa_biquad',
8687
'treble_biquad',
8788
'vad',
88-
'apply_codec'
89+
'apply_codec',
90+
'resample',
8991
]

0 commit comments

Comments
 (0)