|
1 | 1 | from typing import List
|
| 2 | +import unittest |
2 | 3 |
|
3 | 4 | from parameterized import parameterized
|
4 | 5 | import torch
|
@@ -35,10 +36,16 @@ def assert_grad(
|
35 | 36 | ):
|
36 | 37 | transform = transform.to(dtype=torch.float64, device=self.device)
|
37 | 38 |
|
| 39 | + # gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or |
| 40 | + # `torch.cdouble`, when the default eps and tolerance values are used. |
38 | 41 | inputs_ = []
|
39 | 42 | for i in inputs:
|
40 |
| - i.requires_grad = True |
41 |
| - inputs_.append(i.to(dtype=torch.float64, device=self.device)) |
| 43 | + if torch.is_tensor(i): |
| 44 | + i = i.to( |
| 45 | + dtype=torch.cdouble if i.is_complex() else torch.double, |
| 46 | + device=self.device) |
| 47 | + i.requires_grad = True |
| 48 | + inputs_.append(i) |
42 | 49 | assert gradcheck(transform, inputs_)
|
43 | 50 | assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
|
44 | 51 |
|
@@ -129,3 +136,48 @@ def test_amplitude_to_db(self):
|
129 | 136 | transform = T.AmplitudeToDB()
|
130 | 137 | waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
|
131 | 138 | self.assert_grad(transform, [waveform])
|
| 139 | + |
| 140 | + @unittest.expectedFailure |
| 141 | + def test_timestretch_zeros_fail(self): |
| 142 | + """Test that ``T.TimeStretch`` fails gradcheck at 0 |
| 143 | +
|
| 144 | + This is because ``F.phase_vocoder`` converts data from Cartesian to polar coordinates, |
| 145 | + which performs ``atan2(imag, real)``, and the gradient is not defined at 0. |
| 146 | + """ |
| 147 | + n_fft = 16 |
| 148 | + transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99) |
| 149 | + waveform = torch.zeros(2, 40) |
| 150 | + spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) |
| 151 | + self.assert_grad(transform, [spectrogram]) |
| 152 | + |
| 153 | + @nested_params( |
| 154 | + [0.7, 0.8, 0.9, 1.0, 1.3], |
| 155 | + [False, True], |
| 156 | + ) |
| 157 | + def test_timestretch_non_zero(self, rate, test_pseudo_complex): |
| 158 | + """Verify that ``T.TimeStretch`` does not fail gradcheck if the input is not close to 0
| 159 | +
|
| 160 | + ``T.TimeStretch`` is not differentiable around 0, so this test checks the differentiability |
| 161 | + for cases where input is not zero. |
| 162 | +
|
| 163 | + As tested above, when spectrogram contains values close to zero, the gradients are unstable |
| 164 | + and gradcheck fails. |
| 165 | +
|
| 166 | + In this test, we generate spectrogram from random signal, then we push the points around |
| 167 | + zero away from the origin. |
| 168 | +
|
| 169 | + This process does not reflect the real use-case, and it is not practical for users, but |
| 170 | + this helps us understand to what degree the function is differentiable and when not. |
| 171 | + """ |
| 172 | + n_fft = 16 |
| 173 | + transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate) |
| 174 | + waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2) |
| 175 | + spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) |
| 176 | + |
| 177 | + # 1e-3 is too small (on CPU) |
| 178 | + epsilon = 1e-2 |
| 179 | + too_close = spectrogram.abs() < epsilon |
| 180 | + spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs() |
| 181 | + if test_pseudo_complex: |
| 182 | + spectrogram = torch.view_as_real(spectrogram) |
| 183 | + self.assert_grad(transform, [spectrogram]) |
0 commit comments