Skip to content

Commit ad534c1

Browse files
authored
Add autograd test to T.TimeStretch (and F.phase_vocoder) (#1420)
1 parent 5c696b5 commit ad534c1

File tree

1 file changed

+54
-2
lines changed

1 file changed

+54
-2
lines changed

test/torchaudio_unittest/transforms/autograd_test_impl.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List
2+
import unittest
23

34
from parameterized import parameterized
45
import torch
@@ -35,10 +36,16 @@ def assert_grad(
3536
):
3637
transform = transform.to(dtype=torch.float64, device=self.device)
3738

39+
# gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or
40+
# `torch.cdouble`, when the default eps and tolerance values are used.
3841
inputs_ = []
3942
for i in inputs:
40-
i.requires_grad = True
41-
inputs_.append(i.to(dtype=torch.float64, device=self.device))
43+
if torch.is_tensor(i):
44+
i = i.to(
45+
dtype=torch.cdouble if i.is_complex() else torch.double,
46+
device=self.device)
47+
i.requires_grad = True
48+
inputs_.append(i)
4249
assert gradcheck(transform, inputs_)
4350
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
4451

@@ -129,3 +136,48 @@ def test_amplitude_to_db(self):
129136
transform = T.AmplitudeToDB()
130137
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
131138
self.assert_grad(transform, [waveform])
139+
140+
@unittest.expectedFailure
141+
def test_timestretch_zeros_fail(self):
142+
"""Test that ``T.TimeStretch`` fails gradcheck at 0
143+
144+
This is because ``F.phase_vocoder`` converts data from cartesian to polar coordinate,
145+
which performs ``atan2(img, real)``, and gradient is not defined at 0.
146+
"""
147+
n_fft = 16
148+
transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99)
149+
waveform = torch.zeros(2, 40)
150+
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
151+
self.assert_grad(transform, [spectrogram])
152+
153+
@nested_params(
154+
[0.7, 0.8, 0.9, 1.0, 1.3],
155+
[False, True],
156+
)
157+
def test_timestretch_non_zero(self, rate, test_pseudo_complex):
158+
"""Verify that ``T.TimeStretch`` does not fail if it's not close to 0
159+
160+
``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability
161+
for cases where input is not zero.
162+
163+
As tested above, when spectrogram contains values close to zero, the gradients are unstable
164+
and gradcheck fails.
165+
166+
In this test, we generate spectrogram from random signal, then we push the points around
167+
zero away from the origin.
168+
169+
This process does not reflect the real use-case, and it is not practical for users, but
170+
this helps us understand to what degree the function is differentiable and when not.
171+
"""
172+
n_fft = 16
173+
transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate)
174+
waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2)
175+
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
176+
177+
# 1e-3 is too small (on CPU)
178+
epsilon = 1e-2
179+
too_close = spectrogram.abs() < epsilon
180+
spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs()
181+
if test_pseudo_complex:
182+
spectrogram = torch.view_as_real(spectrogram)
183+
self.assert_grad(transform, [spectrogram])

0 commit comments

Comments
 (0)