Skip to content

Commit 1b52e72

Browse files
authored
Add Tacotron2 loss function (#1625)
1 parent 37dbf29 commit 1b52e72

File tree

6 files changed

+242
-0
lines changed

6 files changed

+242
-0
lines changed

examples/pipeline_tacotron2/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This is an example pipeline for text-to-speech using Tacotron2.

examples/pipeline_tacotron2/loss.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of the NVIDIA CORPORATION nor the
12+
# names of its contributors may be used to endorse or promote products
13+
# derived from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20+
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22+
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
#
26+
# *****************************************************************************
27+
28+
from typing import Tuple
29+
30+
from torch import nn, Tensor
31+
32+
33+
class Tacotron2Loss(nn.Module):
34+
"""Tacotron2 loss function modified from:
35+
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py
36+
"""
37+
38+
def __init__(self):
39+
super().__init__()
40+
41+
self.mse_loss = nn.MSELoss(reduction="mean")
42+
self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
43+
44+
def forward(
45+
self,
46+
model_outputs: Tuple[Tensor, Tensor, Tensor],
47+
targets: Tuple[Tensor, Tensor],
48+
) -> Tuple[Tensor, Tensor, Tensor]:
49+
r"""Pass the input through the Tacotron2 loss.
50+
51+
The original implementation was introduced in
52+
*Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
53+
[:footcite:`shen2018natural`].
54+
55+
Args:
56+
model_outputs (tuple of three Tensors): The outputs of the
57+
Tacotron2. These outputs should include three items:
58+
(1) the predicted mel spectrogram before the postnet (``mel_specgram``)
59+
with shape (batch, mel, time).
60+
(2) predicted mel spectrogram after the postnet (``mel_specgram_postnet``)
61+
with shape (batch, mel, time), and
62+
(3) the stop token prediction (``gate_out``) with shape (batch, ).
63+
targets (tuple of two Tensors): The ground truth mel spectrogram (batch, mel, time) and
64+
stop token with shape (batch, ).
65+
66+
Returns:
67+
mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram
68+
with shape ``torch.Size([])``.
69+
mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and
70+
ground truth mel spectrogram with shape ``torch.Size([])``.
71+
gate_loss (Tensor): The mean binary cross entropy loss of
72+
the prediction on the stop token with shape ``torch.Size([])``.
73+
"""
74+
mel_target, gate_target = targets[0], targets[1]
75+
gate_target = gate_target.view(-1, 1)
76+
77+
mel_specgram, mel_specgram_postnet, gate_out = model_outputs
78+
gate_out = gate_out.view(-1, 1)
79+
mel_loss = self.mse_loss(mel_specgram, mel_target)
80+
mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target)
81+
gate_loss = self.bce_loss(gate_out, gate_target)
82+
return mel_loss, mel_postnet_loss, gate_loss

test/torchaudio_unittest/example/tacotron2/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
3+
from .tacotron2_loss_impl import (
4+
Tacotron2LossShapeTests,
5+
Tacotron2LossTorchscriptTests,
6+
Tacotron2LossGradcheckTests,
7+
)
8+
from torchaudio_unittest.common_utils import PytorchTestCase
9+
10+
11+
class TestTacotron2LossShapeFloat32CPU(PytorchTestCase, Tacotron2LossShapeTests):
12+
dtype = torch.float32
13+
device = torch.device("cpu")
14+
15+
16+
class TestTacotron2TorchsciptFloat32CPU(PytorchTestCase, Tacotron2LossTorchscriptTests):
17+
dtype = torch.float32
18+
device = torch.device("cpu")
19+
20+
21+
class TestTacotron2GradcheckFloat64CPU(PytorchTestCase, Tacotron2LossGradcheckTests):
22+
dtype = torch.float64 # gradcheck needs a higher numerical accuracy
23+
device = torch.device("cpu")
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch
2+
3+
from .tacotron2_loss_impl import (
4+
Tacotron2LossShapeTests,
5+
Tacotron2LossTorchscriptTests,
6+
Tacotron2LossGradcheckTests,
7+
)
8+
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
9+
10+
11+
@skipIfNoCuda
12+
class TestTacotron2LossShapeFloat32CUDA(PytorchTestCase, Tacotron2LossShapeTests):
13+
dtype = torch.float32
14+
device = torch.device("cuda")
15+
16+
17+
@skipIfNoCuda
18+
class TestTacotron2TorchsciptFloat32CUDA(PytorchTestCase, Tacotron2LossTorchscriptTests):
19+
dtype = torch.float32
20+
device = torch.device("cuda")
21+
22+
23+
@skipIfNoCuda
24+
class TestTacotron2GradcheckFloat64CUDA(PytorchTestCase, Tacotron2LossGradcheckTests):
25+
dtype = torch.float64 # gradcheck needs a higher numerical accuracy
26+
device = torch.device("cuda")
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import torch
2+
from torch.autograd import gradcheck, gradgradcheck
3+
4+
from pipeline_tacotron2.loss import Tacotron2Loss
5+
from torchaudio_unittest.common_utils import TempDirMixin
6+
7+
8+
class Tacotron2LossInputMixin(TempDirMixin):
9+
10+
def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300):
11+
mel_specgram = torch.rand(
12+
n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device
13+
)
14+
mel_specgram_postnet = torch.rand(
15+
n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device
16+
)
17+
gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
18+
truth_mel_specgram = torch.rand(
19+
n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device
20+
)
21+
truth_gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
22+
23+
truth_mel_specgram.requires_grad = False
24+
truth_gate_out.requires_grad = False
25+
26+
return (
27+
mel_specgram,
28+
mel_specgram_postnet,
29+
gate_out,
30+
truth_mel_specgram,
31+
truth_gate_out,
32+
)
33+
34+
35+
class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
36+
37+
def test_tacotron2_loss_shape(self):
38+
"""Validate the output shape of Tacotron2Loss."""
39+
n_batch = 16
40+
41+
(
42+
mel_specgram,
43+
mel_specgram_postnet,
44+
gate_out,
45+
truth_mel_specgram,
46+
truth_gate_out,
47+
) = self._get_inputs(n_batch=n_batch)
48+
49+
mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()(
50+
(mel_specgram, mel_specgram_postnet, gate_out),
51+
(truth_mel_specgram, truth_gate_out)
52+
)
53+
54+
self.assertEqual(mel_loss.size(), torch.Size([]))
55+
self.assertEqual(mel_postnet_loss.size(), torch.Size([]))
56+
self.assertEqual(gate_loss.size(), torch.Size([]))
57+
58+
59+
class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin):
60+
61+
def _assert_torchscript_consistency(self, fn, tensors):
62+
path = self.get_temp_path("func.zip")
63+
torch.jit.script(fn).save(path)
64+
ts_func = torch.jit.load(path)
65+
66+
output = fn(tensors[:3], tensors[3:])
67+
ts_output = ts_func(tensors[:3], tensors[3:])
68+
69+
self.assertEqual(ts_output, output)
70+
71+
def test_tacotron2_loss_torchscript_consistency(self):
72+
"""Validate the torchscript consistency of Tacotron2Loss."""
73+
74+
loss_fn = Tacotron2Loss()
75+
self._assert_torchscript_consistency(loss_fn, self._get_inputs())
76+
77+
78+
class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin):
79+
80+
def test_tacotron2_loss_gradcheck(self):
81+
"""Performing gradient check on Tacotron2Loss."""
82+
(
83+
mel_specgram,
84+
mel_specgram_postnet,
85+
gate_out,
86+
truth_mel_specgram,
87+
truth_gate_out,
88+
) = self._get_inputs()
89+
90+
mel_specgram.requires_grad_(True)
91+
mel_specgram_postnet.requires_grad_(True)
92+
gate_out.requires_grad_(True)
93+
94+
def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
95+
loss_fn = Tacotron2Loss()
96+
return loss_fn(
97+
(mel_specgram, mel_specgram_postnet, gate_out),
98+
(truth_mel_specgram, truth_gate_out),
99+
)
100+
101+
gradcheck(
102+
_fn,
103+
(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
104+
fast_mode=True,
105+
)
106+
gradgradcheck(
107+
_fn,
108+
(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
109+
fast_mode=True,
110+
)

0 commit comments

Comments
 (0)