Skip to content

Commit 69029b1

Browse files
committed
select autograd test from carolineechen#2
1 parent 838e1e0 commit 69029b1

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
from .autograd_impl import Autograd
3+
from torchaudio_unittest import common_utils
4+
from .utils import skipIfNoTransducer
5+
6+
7+
@skipIfNoTransducer
8+
class TestAutograd(Autograd, common_utils.PytorchTestCase):
9+
dtype = torch.float32
10+
device = torch.device('cpu')
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
from .autograd_impl import Autograd
3+
from torchaudio_unittest import common_utils
4+
from .utils import skipIfNoTransducer
5+
6+
7+
@skipIfNoTransducer
8+
class TestAutograd(Autograd, common_utils.PytorchTestCase):
9+
dtype = torch.float32
10+
device = torch.device('cuda')
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import Callable, Tuple
2+
import torch
3+
from torch import Tensor
4+
from torch.autograd import gradcheck
5+
from torchaudio_unittest.common_utils import (
6+
TestBaseMixin,
7+
)
8+
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
9+
from parameterized import parameterized
10+
from .utils import (
11+
numpy_to_torch,
12+
get_B1_T10_U3_D4_data,
13+
get_B1_T10_U3_D4_data,
14+
get_numpy_data_B2_T4_U3_D3,
15+
get_numpy_data_B1_T2_U3_D5
16+
)
17+
from .numpy_transducer import NumpyTransducerLoss
18+
19+
20+
class Autograd(TestBaseMixin):
21+
@staticmethod
22+
def get_data(data_func, device):
23+
data_np = data_func()
24+
if type(data_np) == tuple:
25+
print("reference gradient")
26+
print(data_np[-1])
27+
data_np = data_np[0]
28+
data = numpy_to_torch(
29+
data=data_np, device=device, requires_grad=True
30+
)
31+
return data
32+
33+
def assert_grad(
34+
self,
35+
loss: Callable[..., Tensor],
36+
inputs: Tuple[torch.Tensor],
37+
*,
38+
enable_all_grad: bool = True,
39+
):
40+
# inputs_ = []
41+
# for i in inputs:
42+
# if torch.is_tensor(i):
43+
# i = i.to(dtype=self.dtype, device=self.device)
44+
# if enable_all_grad:
45+
# i.requires_grad = True
46+
# inputs_.append(i)
47+
assert gradcheck(loss, inputs, eps=1e-03, atol=1e-02, rtol=1e-02, nondet_tol=0.)
48+
49+
@parameterized.expand([
50+
# (get_B1_T10_U3_D4_data, ),
51+
(get_numpy_data_B2_T4_U3_D3, ),
52+
(get_numpy_data_B1_T2_U3_D5, ),
53+
])
54+
def test_RNNTLoss_gradcheck(self, data_func):
55+
data = self.get_data(data_func, self.device)
56+
inputs = (
57+
data["logits"].to(self.dtype),
58+
data["targets"],
59+
data["logit_lengths"],
60+
data["target_lengths"],
61+
)
62+
loss = RNNTLoss(blank=data["blank"])
63+
64+
self.assert_grad(loss, inputs, enable_all_grad=False)
65+
66+
@parameterized.expand([
67+
# (get_B1_T10_U3_D4_data, ),
68+
(get_numpy_data_B2_T4_U3_D3, ),
69+
(get_numpy_data_B1_T2_U3_D5, ),
70+
])
71+
def test_np_transducer_gradcheck(self, data_func):
72+
data = self.get_data(data_func, self.device)
73+
inputs = (
74+
data["logits"].to(self.dtype),
75+
data["logit_lengths"],
76+
data["target_lengths"],
77+
data["targets"],
78+
)
79+
loss = NumpyTransducerLoss(blank=data["blank"])
80+
81+
self.assert_grad(loss, inputs, enable_all_grad=False)

0 commit comments

Comments
 (0)