Skip to content

Commit 3365afa

Browse files
committed
Refactor the code
1 parent b217afd commit 3365afa

File tree

4 files changed

+220
-107
lines changed

4 files changed

+220
-107
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .loss_function import Tacotron2Loss

examples/pipeline_tacotron2/loss_function.py renamed to examples/pipeline_tacotron2/loss/loss_function.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,27 +38,34 @@ class Tacotron2Loss(nn.Module):
3838
def __init__(self):
3939
super().__init__()
4040

41-
def forward(self,
42-
model_outputs: Tuple[Tensor, Tensor, Tensor],
43-
targets: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
41+
def forward(
42+
self,
43+
model_outputs: Tuple[Tensor, Tensor, Tensor],
44+
targets: Tuple[Tensor, Tensor],
45+
) -> Tuple[Tensor, Tensor, Tensor]:
4446
r"""Pass the input through the Tacotron2 loss.
4547
48+
The original implementation was introduced in
49+
*Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
50+
[:footcite:`shen2018natural`].
51+
4652
Args:
4753
model_outputs (tuple of three Tensors): The outputs of the
4854
Tacotron2. These outputs should include three items:
49-
(1) the predicted mel spectrogram before the postnet (mel_specgram) with shape (n_batch, n_mel, n_time),
50-
(2) predicted mel spectrogram after the postnet (mel_specgram_postnet)
51-
with shape (n_batch, n_mel, n_time), and
52-
(3) the stop token prediction (gate_out) with shape (n_batch).
53-
targets (tuple of two Tensors): The ground truth mel spectrogram (n_batch, n_mel, n_time) and
54-
stop token with shape (n_batch).
55+
(1) the predicted mel spectrogram before the postnet (``mel_specgram``)
56+
with shape (batch, mel, time).
57+
(2) predicted mel spectrogram after the postnet (``mel_specgram_postnet``)
58+
with shape (batch, mel, time), and
59+
(3) the stop token prediction (``gate_out``) with shape (batch).
60+
targets (tuple of two Tensors): The ground truth mel spectrogram (batch, mel, time) and
61+
stop token with shape (batch).
5562
5663
Returns:
57-
mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram with shape (n_batch, ).
64+
mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram with shape (batch, ).
5865
mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and
59-
ground truth mel spectrogram with shape (n_batch, ).
66+
ground truth mel spectrogram with shape (batch, ).
6067
gate_loss (Tensor): The mean binary cross entropy loss of
61-
the prediction on the stop token with shape (n_batch, ).
68+
the prediction on the stop token with shape (batch, ).
6269
"""
6370
mel_target, gate_target = targets[0], targets[1]
6471
mel_target.requires_grad = False
@@ -68,6 +75,8 @@ def forward(self,
6875
mel_specgram, mel_specgram_postnet, gate_out = model_outputs
6976
gate_out = gate_out.view(-1, 1)
7077
mel_loss = nn.MSELoss(reduction="mean")(mel_specgram, mel_target)
71-
mel_postnet_loss = nn.MSELoss(reduction="mean")(mel_specgram_postnet, mel_target)
78+
mel_postnet_loss = nn.MSELoss(reduction="mean")(
79+
mel_specgram_postnet, mel_target
80+
)
7281
gate_loss = nn.BCEWithLogitsLoss(reduction="mean")(gate_out, gate_target)
7382
return mel_loss, mel_postnet_loss, gate_loss
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import os
2+
import unittest
3+
import tempfile
4+
5+
import torch
6+
from torch.autograd import gradcheck, gradgradcheck
7+
8+
from loss_function import Tacotron2Loss
9+
10+
11+
def skipIfNoCuda(test_item):
12+
if torch.cuda.is_available():
13+
return test_item
14+
force_cuda_test = os.environ.get("TORCHAUDIO_TEST_FORCE_CUDA", "0")
15+
if force_cuda_test not in ["0", "1"]:
16+
raise ValueError('"TORCHAUDIO_TEST_FORCE_CUDA" must be either "0" or "1".')
17+
if force_cuda_test == "1":
18+
raise RuntimeError(
19+
'"TORCHAUDIO_TEST_FORCE_CUDA" is set but CUDA is not available.'
20+
)
21+
return unittest.skip("CUDA is not available.")(test_item)
22+
23+
24+
class TempDirMixin:
25+
"""Mixin to provide easy access to temp dir"""
26+
27+
temp_dir_ = None
28+
29+
@classmethod
30+
def get_base_temp_dir(cls):
31+
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
32+
# this is handy for debugging.
33+
key = "TORCHAUDIO_TEST_TEMP_DIR"
34+
if key in os.environ:
35+
return os.environ[key]
36+
if cls.temp_dir_ is None:
37+
cls.temp_dir_ = tempfile.TemporaryDirectory()
38+
return cls.temp_dir_.name
39+
40+
@classmethod
41+
def tearDownClass(cls):
42+
super().tearDownClass()
43+
if cls.temp_dir_ is not None:
44+
cls.temp_dir_.cleanup()
45+
cls.temp_dir_ = None
46+
47+
def get_temp_path(self, *paths):
48+
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
49+
path = os.path.join(temp_dir, *paths)
50+
os.makedirs(os.path.dirname(path), exist_ok=True)
51+
return path
52+
53+
54+
def _get_inputs(dtype, device):
55+
n_mel, n_batch, max_mel_specgram_length = 3, 2, 4
56+
mel_specgram = torch.rand(
57+
n_batch, n_mel, max_mel_specgram_length, dtype=dtype, device=device
58+
)
59+
mel_specgram_postnet = torch.rand(
60+
n_batch, n_mel, max_mel_specgram_length, dtype=dtype, device=device
61+
)
62+
gate_out = torch.rand(n_batch, dtype=dtype, device=device)
63+
truth_mel_specgram = torch.rand(
64+
n_batch, n_mel, max_mel_specgram_length, dtype=dtype, device=device
65+
)
66+
truth_gate_out = torch.rand(n_batch, dtype=dtype, device=device)
67+
68+
return (
69+
mel_specgram,
70+
mel_specgram_postnet,
71+
gate_out,
72+
truth_mel_specgram,
73+
truth_gate_out,
74+
)
75+
76+
77+
class Tacotron2LossTest(unittest.TestCase, TempDirMixin):
78+
79+
dtype = torch.float64
80+
device = "cpu"
81+
82+
def _assert_torchscript_consistency(self, fn, tensors):
83+
path = self.get_temp_path("func.zip")
84+
torch.jit.script(fn).save(path)
85+
ts_func = torch.jit.load(path)
86+
87+
torch.random.manual_seed(40)
88+
output = fn(*tensors)
89+
90+
torch.random.manual_seed(40)
91+
ts_output = ts_func(*tensors)
92+
93+
self.assertEqual(ts_output, output)
94+
95+
def test_cpu_torchscript_consistency(self):
96+
f"""Validate the torchscript consistency of Tacotron2Loss."""
97+
dtype = torch.float32
98+
device = torch.device("cpu")
99+
100+
def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
101+
loss_fn = Tacotron2Loss()
102+
return loss_fn(
103+
(mel_specgram, mel_specgram_postnet, gate_out),
104+
(truth_mel_specgram, truth_gate_out),
105+
)
106+
107+
self._assert_torchscript_consistency(_fn, _get_inputs(dtype, device))
108+
109+
@skipIfNoCuda
110+
def test_gpu_torchscript_consistency(self):
111+
f"""Validate the torchscript consistency of Tacotron2Loss."""
112+
dtype = torch.float32
113+
device = torch.device("cuda")
114+
115+
def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
116+
loss_fn = Tacotron2Loss()
117+
return loss_fn(
118+
(mel_specgram, mel_specgram_postnet, gate_out),
119+
(truth_mel_specgram, truth_gate_out),
120+
)
121+
122+
self._assert_torchscript_consistency(_fn, self._get_inputs(dtype, device))
123+
124+
def test_cpu_gradcheck(self):
125+
f"""Performing gradient check on Tacotron2Loss."""
126+
dtype = torch.float64 # gradcheck needs a higher numerical accuracy
127+
device = torch.device("cuda")
128+
129+
(
130+
mel_specgram,
131+
mel_specgram_postnet,
132+
gate_out,
133+
truth_mel_specgram,
134+
truth_gate_out,
135+
) = _get_inputs(dtype, device)
136+
137+
mel_specgram.requires_grad_(True)
138+
mel_specgram_postnet.requires_grad_(True)
139+
gate_out.requires_grad_(True)
140+
141+
def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
142+
loss_fn = Tacotron2Loss()
143+
return loss_fn(
144+
(mel_specgram, mel_specgram_postnet, gate_out),
145+
(truth_mel_specgram, truth_gate_out),
146+
)
147+
148+
gradcheck(
149+
_fn,
150+
(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
151+
fast_mode=True,
152+
)
153+
gradgradcheck(
154+
_fn,
155+
(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
156+
fast_mode=True,
157+
)
158+
159+
@skipIfNoCuda
160+
def test_gpu_gradcheck(self):
161+
f"""Performing gradient check on Tacotron2Loss."""
162+
dtype = torch.float64 # gradcheck needs a higher numerical accuracy
163+
device = torch.device("cuda")
164+
165+
(
166+
mel_specgram,
167+
mel_specgram_postnet,
168+
gate_out,
169+
truth_mel_specgram,
170+
truth_gate_out,
171+
) = _get_inputs(dtype, device)
172+
173+
mel_specgram.requires_grad_(True)
174+
mel_specgram_postnet.requires_grad_(True)
175+
gate_out.requires_grad_(True)
176+
177+
def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
178+
loss_fn = Tacotron2Loss()
179+
return loss_fn(
180+
(mel_specgram, mel_specgram_postnet, gate_out),
181+
(truth_mel_specgram, truth_gate_out),
182+
)
183+
184+
gradcheck(
185+
_fn,
186+
(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
187+
fast_mode=True,
188+
)
189+
gradgradcheck(
190+
_fn,
191+
(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
192+
fast_mode=True,
193+
)
194+
195+
196+
if __name__ == "__main__":
197+
unittest.main()

examples/pipeline_tacotron2/test_tacotron2_loss.py

Lines changed: 0 additions & 94 deletions
This file was deleted.

0 commit comments

Comments
 (0)