
Commit 1996fec
refactor tacotron2 loss tests
1 parent: 3365afa

5 files changed: +219 −209 lines


examples/pipeline_tacotron2/loss/loss_function.py

Lines changed: 12 additions & 12 deletions

@@ -38,6 +38,9 @@ class Tacotron2Loss(nn.Module):
     def __init__(self):
         super().__init__()

+        self.mse_loss = nn.MSELoss(reduction="mean")
+        self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
+
     def forward(
         self,
         model_outputs: Tuple[Tensor, Tensor, Tensor],
@@ -56,27 +59,24 @@ def forward(
                     with shape (batch, mel, time).
                 (2) predicted mel spectrogram after the postnet (``mel_specgram_postnet``)
                     with shape (batch, mel, time), and
-                (3) the stop token prediction (``gate_out``) with shape (batch).
+                (3) the stop token prediction (``gate_out``) with shape (batch, ).
             targets (tuple of two Tensors): The ground truth mel spectrogram (batch, mel, time) and
-                stop token with shape (batch).
+                stop token with shape (batch, ).

         Returns:
-            mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram with shape (batch, ).
+            mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram
+                with shape ``torch.Size([])``.
             mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and
-                ground truth mel spectrogram with shape (batch, ).
+                ground truth mel spectrogram with shape ``torch.Size([])``.
             gate_loss (Tensor): The mean binary cross entropy loss of
-                the prediction on the stop token with shape (batch, ).
+                the prediction on the stop token with shape ``torch.Size([])``.
         """
         mel_target, gate_target = targets[0], targets[1]
-        mel_target.requires_grad = False
-        gate_target.requires_grad = False
         gate_target = gate_target.view(-1, 1)

         mel_specgram, mel_specgram_postnet, gate_out = model_outputs
         gate_out = gate_out.view(-1, 1)
-        mel_loss = nn.MSELoss(reduction="mean")(mel_specgram, mel_target)
-        mel_postnet_loss = nn.MSELoss(reduction="mean")(
-            mel_specgram_postnet, mel_target
-        )
-        gate_loss = nn.BCEWithLogitsLoss(reduction="mean")(gate_out, gate_target)
+        mel_loss = self.mse_loss(mel_specgram, mel_target)
+        mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target)
+        gate_loss = self.bce_loss(gate_out, gate_target)
         return mel_loss, mel_postnet_loss, gate_loss
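
For orientation (not part of the commit): a minimal sketch of how the refactored loss might be called, with hypothetical sizes that match the docstring shapes. `Tacotron2Loss` is assumed to be imported from the module diffed above.

    # Hypothetical usage sketch; shapes follow the docstring above.
    import torch

    loss_fn = Tacotron2Loss()  # assumed import from the diffed module

    n_batch, n_mel, n_time = 16, 80, 300  # illustrative sizes
    model_outputs = (
        torch.rand(n_batch, n_mel, n_time),  # mel_specgram
        torch.rand(n_batch, n_mel, n_time),  # mel_specgram_postnet
        torch.rand(n_batch),                 # gate_out, shape (batch, )
    )
    targets = (
        torch.rand(n_batch, n_mel, n_time),  # ground truth mel spectrogram
        torch.rand(n_batch),                 # ground truth stop token
    )

    mel_loss, mel_postnet_loss, gate_loss = loss_fn(model_outputs, targets)
    # Each result is a scalar tensor (torch.Size([])) since both criteria
    # use reduction="mean".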
Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+import torch
+import unittest
+
+from .tacotron2_loss_impl import (
+    Tacotron2LossShapeTests,
+    Tacotron2LossTorchscriptTests,
+    Tacotron2LossGradcheckTests,
+)
+
+
+class TestTacotron2LossShapeFloat32CPU(unittest.TestCase, Tacotron2LossShapeTests):
+    dtype = torch.float32
+    device = torch.device("cpu")
+
+
+class TestTacotron2TorchscriptFloat32CPU(unittest.TestCase, Tacotron2LossTorchscriptTests):
+    dtype = torch.float32
+    device = torch.device("cpu")
+
+
+class TestTacotron2GradcheckFloat64CPU(unittest.TestCase, Tacotron2LossGradcheckTests):
+    dtype = torch.float64  # gradcheck needs a higher numerical accuracy
+    device = torch.device("cpu")
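
This new file pins the shared test mixins to concrete `dtype`/`device` class attributes. As a hedged sketch, covering one more configuration would only take another subclass (the class below is hypothetical, not part of the commit):

    # Hypothetical extra configuration; the mixins only read `dtype` and `device`.
    class TestTacotron2LossShapeFloat64CPU(unittest.TestCase, Tacotron2LossShapeTests):
        dtype = torch.float64
        device = torch.device("cpu")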
Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+import os
+import unittest
+
+import torch
+
+from .tacotron2_loss_impl import (
+    Tacotron2LossShapeTests,
+    Tacotron2LossTorchscriptTests,
+    Tacotron2LossGradcheckTests,
+)
+
+
+def skipIfNoCuda(test_item):
+    if torch.cuda.is_available():
+        return test_item
+    force_cuda_test = os.environ.get("TORCHAUDIO_TEST_FORCE_CUDA", "0")
+    if force_cuda_test not in ["0", "1"]:
+        raise ValueError('"TORCHAUDIO_TEST_FORCE_CUDA" must be either "0" or "1".')
+    if force_cuda_test == "1":
+        raise RuntimeError(
+            '"TORCHAUDIO_TEST_FORCE_CUDA" is set but CUDA is not available.'
+        )
+    return unittest.skip("CUDA is not available.")(test_item)
+
+
+@skipIfNoCuda
+class TestTacotron2LossShapeFloat32CUDA(unittest.TestCase, Tacotron2LossShapeTests):
+    dtype = torch.float32
+    device = torch.device("cuda")
+
+
+@skipIfNoCuda
+class TestTacotron2TorchscriptFloat32CUDA(unittest.TestCase, Tacotron2LossTorchscriptTests):
+    dtype = torch.float32
+    device = torch.device("cuda")
+
+
+@skipIfNoCuda
+class TestTacotron2GradcheckFloat64CUDA(unittest.TestCase, Tacotron2LossGradcheckTests):
+    dtype = torch.float64  # gradcheck needs a higher numerical accuracy
+    device = torch.device("cuda")
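
`skipIfNoCuda` defaults to skipping on machines without CUDA, and setting `TORCHAUDIO_TEST_FORCE_CUDA=1` escalates that silent skip into a hard `RuntimeError`, which is useful on CI hosts that are expected to have a GPU. A minimal sketch of guarding a standalone test with it (the test class is hypothetical):

    # Hypothetical test guarded by the decorator defined above.
    # CUDA available                              -> runs normally
    # no CUDA, TORCHAUDIO_TEST_FORCE_CUDA unset/0 -> skipped
    # no CUDA, TORCHAUDIO_TEST_FORCE_CUDA=1       -> RuntimeError at import time
    @skipIfNoCuda
    class TestNeedsGpu(unittest.TestCase):
        def test_cuda_tensor(self):
            self.assertTrue(torch.rand(1, device="cuda").is_cuda)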
Lines changed: 143 additions & 0 deletions

@@ -0,0 +1,143 @@
+import os
+import unittest
+import tempfile
+
+import torch
+from torch.autograd import gradcheck, gradgradcheck
+
+from .loss_function import Tacotron2Loss
+
+
+class TempDirMixin:
+    """Mixin to provide easy access to temp dir"""
+
+    temp_dir_ = None
+
+    @classmethod
+    def get_base_temp_dir(cls):
+        # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of a temporary directory.
+        # This is handy for debugging.
+        key = "TORCHAUDIO_TEST_TEMP_DIR"
+        if key in os.environ:
+            return os.environ[key]
+        if cls.temp_dir_ is None:
+            cls.temp_dir_ = tempfile.TemporaryDirectory()
+        return cls.temp_dir_.name
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        if cls.temp_dir_ is not None:
+            cls.temp_dir_.cleanup()
+            cls.temp_dir_ = None
+
+    def get_temp_path(self, *paths):
+        temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
+        path = os.path.join(temp_dir, *paths)
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        return path
+
+
+class Tacotron2LossInputMixin(TempDirMixin):
+
+    def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300):
+        mel_specgram = torch.rand(
+            n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device
+        )
+        mel_specgram_postnet = torch.rand(
+            n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device
+        )
+        gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
+        truth_mel_specgram = torch.rand(
+            n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device
+        )
+        truth_gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
+
+        truth_mel_specgram.requires_grad = False
+        truth_gate_out.requires_grad = False
+
+        return (
+            mel_specgram,
+            mel_specgram_postnet,
+            gate_out,
+            truth_mel_specgram,
+            truth_gate_out,
+        )
+
+
+class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
+
+    def test_tacotron2_loss_shape(self):
+        """Validate the output shape of Tacotron2Loss."""
+        n_batch = 16
+
+        (
+            mel_specgram,
+            mel_specgram_postnet,
+            gate_out,
+            truth_mel_specgram,
+            truth_gate_out,
+        ) = self._get_inputs(n_batch=n_batch)
+
+        mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()(
+            (mel_specgram, mel_specgram_postnet, gate_out),
+            (truth_mel_specgram, truth_gate_out)
+        )
+
+        self.assertEqual(mel_loss.size(), torch.Size([]))
+        self.assertEqual(mel_postnet_loss.size(), torch.Size([]))
+        self.assertEqual(gate_loss.size(), torch.Size([]))
+
+
+class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin):
+
+    def _assert_torchscript_consistency(self, fn, tensors):
+        path = self.get_temp_path("func.zip")
+        torch.jit.script(fn).save(path)
+        ts_func = torch.jit.load(path)
+
+        output = fn(tensors[:3], tensors[3:])
+        ts_output = ts_func(tensors[:3], tensors[3:])
+
+        self.assertEqual(ts_output, output)
+
+    def test_tacotron2_loss_torchscript_consistency(self):
+        """Validate the torchscript consistency of Tacotron2Loss."""
+
+        loss_fn = Tacotron2Loss()
+        self._assert_torchscript_consistency(loss_fn, self._get_inputs())
+
+
+class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin):
+
+    def test_tacotron2_loss_gradcheck(self):
+        """Perform gradient check on Tacotron2Loss."""
+        (
+            mel_specgram,
+            mel_specgram_postnet,
+            gate_out,
+            truth_mel_specgram,
+            truth_gate_out,
+        ) = self._get_inputs()
+
+        mel_specgram.requires_grad_(True)
+        mel_specgram_postnet.requires_grad_(True)
+        gate_out.requires_grad_(True)
+
+        def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
+            loss_fn = Tacotron2Loss()
+            return loss_fn(
+                (mel_specgram, mel_specgram_postnet, gate_out),
+                (truth_mel_specgram, truth_gate_out),
+            )
+
+        gradcheck(
+            _fn,
+            (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
+            fast_mode=True,
+        )
+        gradgradcheck(
+            _fn,
+            (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
+            fast_mode=True,
+        )
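
The gradcheck classes run in float64 because `torch.autograd.gradcheck` compares analytic gradients against finite differences, and float32 rounding easily exceeds the check's tolerances. A self-contained toy example of the same kind of check (illustrative only, not from the commit):

    import torch
    from torch.autograd import gradcheck

    # gradcheck numerically verifies the gradient of (x ** 2).sum();
    # double precision keeps finite-difference noise below tolerance.
    x = torch.rand(4, dtype=torch.float64, requires_grad=True)
    assert gradcheck(lambda t: (t ** 2).sum(), (x,), fast_mode=True)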
