@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(
1299
1299
1300
1300
1301
1301
def _get_sinc_resample_kernel (
1302
- orig_freq : int ,
1303
- new_freq : int ,
1302
+ orig_freq : float ,
1303
+ new_freq : float ,
1304
+ gcd : int ,
1304
1305
lowpass_filter_width : int ,
1305
- rolloff : float ,
1306
- device : torch .device ,
1307
- dtype : torch .dtype ):
1306
+ rolloff : float ):
1307
+
1308
+ if not (int (orig_freq ) == orig_freq and int (new_freq ) == new_freq ):
1309
+ warnings .warn (
1310
+ "Non-integer frequencies are being cast to ints and may result in poor resampling quality "
1311
+ "because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
1312
+ "Using non-integer valued frequencies will throw an error in the next release. "
1313
+ "To work around this issue, manually convert both frequencies to integer values "
1314
+ "that maintain their resampling rate ratio before passing them into the function "
1315
+ "Example: To downsample a 44100 hz waveform by a factor of 8, use "
1316
+ "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
1317
+ "For more information or to leave feedback about this change, please refer to "
1318
+ "https://github.com/pytorch/audio/issues/1487."
1319
+ )
1320
+
1321
+ orig_freq = int (orig_freq ) // gcd
1322
+ new_freq = int (new_freq ) // gcd
1323
+
1308
1324
assert lowpass_filter_width > 0
1309
1325
kernels = []
1310
1326
base_freq = min (orig_freq , new_freq )
@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
1336
1352
# they will have a lot of almost zero values to the left or to the right...
1337
1353
# There is probably a way to evaluate those filters more efficiently, but this is kept for
1338
1354
# future work.
1339
- idx = torch .arange (- width , width + orig_freq , device = device , dtype = dtype )
1355
+ idx = torch .arange (- width , width + orig_freq )
1340
1356
1341
1357
for i in range (new_freq ):
1342
1358
t = (- i / new_freq + idx / orig_freq ) * base_freq
@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
1353
1369
return torch .stack (kernels ).view (new_freq , 1 , - 1 ).mul_ (scale ), width
1354
1370
1355
1371
1372
+ def _apply_sinc_resample_kernel (
1373
+ waveform : Tensor ,
1374
+ orig_freq : float ,
1375
+ new_freq : float ,
1376
+ gcd : int ,
1377
+ kernel : Tensor ,
1378
+ width : int ,
1379
+ ):
1380
+ orig_freq = int (orig_freq ) // gcd
1381
+ new_freq = int (new_freq ) // gcd
1382
+
1383
+ # pack batch
1384
+ shape = waveform .size ()
1385
+ waveform = waveform .view (- 1 , shape [- 1 ])
1386
+ kernel = kernel .to (device = waveform .device , dtype = waveform .dtype )
1387
+
1388
+ num_wavs , length = waveform .shape
1389
+ waveform = torch .nn .functional .pad (waveform , (width , width + orig_freq ))
1390
+ resampled = torch .nn .functional .conv1d (waveform [:, None ], kernel , stride = orig_freq )
1391
+ resampled = resampled .transpose (1 , 2 ).reshape (num_wavs , - 1 )
1392
+ target_length = int (math .ceil (new_freq * length / orig_freq ))
1393
+ resampled = resampled [..., :target_length ]
1394
+
1395
+ # unpack batch
1396
+ resampled = resampled .view (shape [:- 1 ] + resampled .shape [- 1 :])
1397
+ return resampled
1398
+
1399
+
1356
1400
def resample (
1357
1401
waveform : Tensor ,
1358
1402
orig_freq : float ,
@@ -1380,42 +1424,15 @@ def resample(
1380
1424
1381
1425
Returns:
1382
1426
Tensor: The waveform at the new frequency of dimension (..., time).
1427
+
1428
+ Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
1429
+ more efficient computation if resampling multiple waveforms with the same resampling parameters.
1383
1430
"""
1384
- # pack batch
1385
- shape = waveform .size ()
1386
- waveform = waveform .view (- 1 , shape [- 1 ])
1387
1431
1388
1432
assert orig_freq > 0.0 and new_freq > 0.0
1389
1433
1390
- if not (int (orig_freq ) == orig_freq and int (new_freq ) == new_freq ):
1391
- warnings .warn (
1392
- "Non-integer frequencies are being cast to ints and may result in poor resampling quality "
1393
- "because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
1394
- "Using non-integer valued frequencies will throw an error in the next release. "
1395
- "To work around this issue, manually convert both frequencies to integer values "
1396
- "that maintain their resampling rate ratio before passing them into the function "
1397
- "Example: To downsample a 44100 hz waveform by a factor of 8, use "
1398
- "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
1399
- "For more information or to leave feedback about this change, please refer to "
1400
- "https://github.com/pytorch/audio/issues/1487."
1401
- )
1402
-
1403
- orig_freq = int (orig_freq )
1404
- new_freq = int (new_freq )
1405
- gcd = math .gcd (orig_freq , new_freq )
1406
- orig_freq = orig_freq // gcd
1407
- new_freq = new_freq // gcd
1408
-
1409
- kernel , width = _get_sinc_resample_kernel (orig_freq , new_freq , lowpass_filter_width ,
1410
- rolloff , waveform .device , waveform .dtype )
1411
-
1412
- num_wavs , length = waveform .shape
1413
- waveform = torch .nn .functional .pad (waveform , (width , width + orig_freq ))
1414
- resampled = torch .nn .functional .conv1d (waveform [:, None ], kernel , stride = orig_freq )
1415
- resampled = resampled .transpose (1 , 2 ).reshape (num_wavs , - 1 )
1416
- target_length = int (math .ceil (new_freq * length / orig_freq ))
1417
- resampled = resampled [..., :target_length ]
1434
+ gcd = math .gcd (int (orig_freq ), int (new_freq ))
1418
1435
1419
- # unpack batch
1420
- resampled = resampled . view ( shape [: - 1 ] + resampled . shape [ - 1 :] )
1436
+ kernel , width = _get_sinc_resample_kernel ( orig_freq , new_freq , gcd , lowpass_filter_width , rolloff )
1437
+ resampled = _apply_sinc_resample_kernel ( waveform , orig_freq , new_freq , gcd , kernel , width )
1421
1438
return resampled
0 commit comments