Skip to content

Commit 52e7bfd

Browse files
author
Caroline Chen
authored
Precompute transforms.Resample kernel (#1499)
1 parent 8a86c46 commit 52e7bfd

File tree

2 files changed

+69
-42
lines changed

2 files changed

+69
-42
lines changed

torchaudio/functional/functional.py

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(
12991299

13001300

13011301
def _get_sinc_resample_kernel(
1302-
orig_freq: int,
1303-
new_freq: int,
1302+
orig_freq: float,
1303+
new_freq: float,
1304+
gcd: int,
13041305
lowpass_filter_width: int,
1305-
rolloff: float,
1306-
device: torch.device,
1307-
dtype: torch.dtype):
1306+
rolloff: float):
1307+
1308+
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
1309+
warnings.warn(
1310+
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
1311+
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
1312+
"Using non-integer valued frequencies will throw an error in the next release. "
1313+
"To work around this issue, manually convert both frequencies to integer values "
1314+
"that maintain their resampling rate ratio before passing them into the function "
1315+
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
1316+
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
1317+
"For more information or to leave feedback about this change, please refer to "
1318+
"https://github.com/pytorch/audio/issues/1487."
1319+
)
1320+
1321+
orig_freq = int(orig_freq) // gcd
1322+
new_freq = int(new_freq) // gcd
1323+
13081324
assert lowpass_filter_width > 0
13091325
kernels = []
13101326
base_freq = min(orig_freq, new_freq)
@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
13361352
# they will have a lot of almost zero values to the left or to the right...
13371353
# There is probably a way to evaluate those filters more efficiently, but this is kept for
13381354
# future work.
1339-
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
1355+
idx = torch.arange(-width, width + orig_freq)
13401356

13411357
for i in range(new_freq):
13421358
t = (-i / new_freq + idx / orig_freq) * base_freq
@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
13531369
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
13541370

13551371

1372+
def _apply_sinc_resample_kernel(
1373+
waveform: Tensor,
1374+
orig_freq: float,
1375+
new_freq: float,
1376+
gcd: int,
1377+
kernel: Tensor,
1378+
width: int,
1379+
):
1380+
orig_freq = int(orig_freq) // gcd
1381+
new_freq = int(new_freq) // gcd
1382+
1383+
# pack batch
1384+
shape = waveform.size()
1385+
waveform = waveform.view(-1, shape[-1])
1386+
kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)
1387+
1388+
num_wavs, length = waveform.shape
1389+
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
1390+
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
1391+
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
1392+
target_length = int(math.ceil(new_freq * length / orig_freq))
1393+
resampled = resampled[..., :target_length]
1394+
1395+
# unpack batch
1396+
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
1397+
return resampled
1398+
1399+
13561400
def resample(
13571401
waveform: Tensor,
13581402
orig_freq: float,
@@ -1380,42 +1424,15 @@ def resample(
13801424
13811425
Returns:
13821426
Tensor: The waveform at the new frequency of dimension (..., time).
1427+
1428+
Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
1429+
more efficient computation if resampling multiple waveforms with the same resampling parameters.
13831430
"""
1384-
# pack batch
1385-
shape = waveform.size()
1386-
waveform = waveform.view(-1, shape[-1])
13871431

13881432
assert orig_freq > 0.0 and new_freq > 0.0
13891433

1390-
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
1391-
warnings.warn(
1392-
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
1393-
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
1394-
"Using non-integer valued frequencies will throw an error in the next release. "
1395-
"To work around this issue, manually convert both frequencies to integer values "
1396-
"that maintain their resampling rate ratio before passing them into the function "
1397-
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
1398-
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
1399-
"For more information or to leave feedback about this change, please refer to "
1400-
"https://github.com/pytorch/audio/issues/1487."
1401-
)
1402-
1403-
orig_freq = int(orig_freq)
1404-
new_freq = int(new_freq)
1405-
gcd = math.gcd(orig_freq, new_freq)
1406-
orig_freq = orig_freq // gcd
1407-
new_freq = new_freq // gcd
1408-
1409-
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
1410-
rolloff, waveform.device, waveform.dtype)
1411-
1412-
num_wavs, length = waveform.shape
1413-
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
1414-
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
1415-
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
1416-
target_length = int(math.ceil(new_freq * length / orig_freq))
1417-
resampled = resampled[..., :target_length]
1434+
gcd = math.gcd(int(orig_freq), int(new_freq))
14181435

1419-
# unpack batch
1420-
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
1436+
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff)
1437+
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
14211438
return resampled

torchaudio/transforms.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from torch import Tensor
99
from torchaudio import functional as F
1010

11+
from .functional.functional import (
12+
_get_sinc_resample_kernel,
13+
_apply_sinc_resample_kernel,
14+
)
1115

1216
__all__ = [
1317
'Spectrogram',
@@ -661,18 +665,23 @@ class Resample(torch.nn.Module):
661665
"""
662666

663667
def __init__(self,
664-
orig_freq: int = 16000,
665-
new_freq: int = 16000,
668+
orig_freq: float = 16000,
669+
new_freq: float = 16000,
666670
resampling_method: str = 'sinc_interpolation',
667671
lowpass_filter_width: int = 6,
668672
rolloff: float = 0.99) -> None:
669673
super(Resample, self).__init__()
674+
670675
self.orig_freq = orig_freq
671676
self.new_freq = new_freq
677+
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
672678
self.resampling_method = resampling_method
673679
self.lowpass_filter_width = lowpass_filter_width
674680
self.rolloff = rolloff
675681

682+
self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
683+
self.lowpass_filter_width, self.rolloff)
684+
676685
def forward(self, waveform: Tensor) -> Tensor:
677686
r"""
678687
Args:
@@ -682,7 +691,8 @@ def forward(self, waveform: Tensor) -> Tensor:
682691
Tensor: Output signal of dimension (..., time).
683692
"""
684693
if self.resampling_method == 'sinc_interpolation':
685-
return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff)
694+
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
695+
self.kernel, self.width)
686696

687697
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
688698

0 commit comments

Comments
 (0)