Skip to content

Commit f65959d

Browse files
committed
Add Tacotron2 loss function
1 parent d6ae55c commit f65959d

File tree

2 files changed

+167
-0
lines changed

2 files changed

+167
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of the NVIDIA CORPORATION nor the
12+
# names of its contributors may be used to endorse or promote products
13+
# derived from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20+
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22+
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
#
26+
# *****************************************************************************
27+
28+
from typing import Tuple
29+
30+
from torch import nn, Tensor
31+
32+
33+
class Tacotron2Loss(nn.Module):
34+
"""Tacotron2 loss function adapted from:
35+
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py
36+
"""
37+
38+
def __init__(self):
39+
super().__init__()
40+
41+
def forward(self,
42+
model_outputs: Tuple[Tensor, Tensor, Tensor],
43+
targets: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
44+
r"""Pass the input through the Tacotron2 loss.
45+
46+
Args:
47+
model_outputs (tuple of three Tensors): The outputs of the
48+
Tacotron2. These outputs should include three items:
49+
(1) the predicted mel spectrogram before the postnet (mel_specgram) with shape (n_batch, n_mel, n_time),
50+
(2) predicted mel spectrogram after the postnet (mel_specgram_postnet)
51+
with shape (n_batch, n_mel, n_time), and
52+
(3) the stop token prediction (gate_out) with shape (n_batch).
53+
targets (tuple of two Tensors): The ground truth mel spectrogram (n_batch, n_mel, n_time) and
54+
stop token with shape (n_batch).
55+
56+
Returns:
57+
mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram with shape (n_batch, ).
58+
mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and
59+
ground truth mel spectrogram with shape (n_batch, ).
60+
gate_loss (Tensor): The mean binary cross entropy loss of
61+
the prediction on the stop token with shape (n_batch, ).
62+
"""
63+
mel_target, gate_target = targets[0], targets[1]
64+
mel_target.requires_grad = False
65+
gate_target.requires_grad = False
66+
gate_target = gate_target.view(-1, 1)
67+
68+
mel_specgram, mel_specgram_postnet, gate_out = model_outputs
69+
gate_out = gate_out.view(-1, 1)
70+
mel_loss = nn.MSELoss(reduction="mean")(mel_specgram, mel_target)
71+
mel_postnet_loss = nn.MSELoss(reduction="mean")(mel_specgram_postnet, mel_target)
72+
gate_loss = nn.BCEWithLogitsLoss(reduction="mean")(gate_out, gate_target)
73+
return mel_loss, mel_postnet_loss, gate_loss
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import os
2+
import unittest
3+
import tempfile
4+
5+
import torch
6+
from torch.autograd import gradcheck, gradgradcheck
7+
8+
from loss_function import Tacotron2Loss
9+
10+
11+
class TempDirMixin:
12+
"""Mixin to provide easy access to temp dir"""
13+
temp_dir_ = None
14+
15+
@classmethod
16+
def get_base_temp_dir(cls):
17+
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
18+
# this is handy for debugging.
19+
key = 'TORCHAUDIO_TEST_TEMP_DIR'
20+
if key in os.environ:
21+
return os.environ[key]
22+
if cls.temp_dir_ is None:
23+
cls.temp_dir_ = tempfile.TemporaryDirectory()
24+
return cls.temp_dir_.name
25+
26+
@classmethod
27+
def tearDownClass(cls):
28+
super().tearDownClass()
29+
if cls.temp_dir_ is not None:
30+
cls.temp_dir_.cleanup()
31+
cls.temp_dir_ = None
32+
33+
def get_temp_path(self, *paths):
34+
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
35+
path = os.path.join(temp_dir, *paths)
36+
os.makedirs(os.path.dirname(path), exist_ok=True)
37+
return path
38+
39+
40+
class Tacotron2LossTest(unittest.TestCase, TempDirMixin):
41+
42+
dtype = torch.float64
43+
device = "cpu"
44+
45+
def _assert_torchscript_consistency(self, fn, tensors):
46+
path = self.get_temp_path('func.zip')
47+
torch.jit.script(fn).save(path)
48+
ts_func = torch.jit.load(path)
49+
50+
torch.random.manual_seed(40)
51+
output = fn(*tensors)
52+
53+
torch.random.manual_seed(40)
54+
ts_output = ts_func(*tensors)
55+
56+
self.assertEqual(ts_output, output)
57+
58+
def _get_inputs(self):
59+
n_mel, n_batch, max_mel_specgram_length = 10, 8, 20
60+
mel_specgram = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
61+
mel_specgram_postnet = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
62+
gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
63+
truth_mel_specgram = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
64+
truth_gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
65+
66+
return mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out
67+
68+
def test_torchscript_consistency(self):
69+
f"""Validate the torchscript consistency of Tacotron2Loss."""
70+
71+
def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
72+
loss_fn = Tacotron2Loss()
73+
return loss_fn((mel_specgram, mel_specgram_postnet, gate_out), (truth_mel_specgram, truth_gate_out))
74+
75+
self._assert_torchscript_consistency(_fn, self._get_inputs())
76+
77+
def test_gradcheck(self):
78+
f"""Performing gradient check on Tacotron2Loss."""
79+
80+
mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out = self._get_inputs()
81+
82+
mel_specgram.requires_grad_(True)
83+
mel_specgram_postnet.requires_grad_(True)
84+
gate_out.requires_grad_(True)
85+
86+
def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
87+
loss_fn = Tacotron2Loss()
88+
return loss_fn((mel_specgram, mel_specgram_postnet, gate_out), (truth_mel_specgram, truth_gate_out))
89+
90+
gradcheck(_fn, (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out))
91+
gradgradcheck(_fn, (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out))
92+
93+
if __name__ == "__main__":
94+
unittest.main()

0 commit comments

Comments
 (0)