Skip to content

Commit b6288bc

Browse files
committed
Add the test for tacotron2 loss to CI
1 parent 7c522f0 commit b6288bc

File tree

8 files changed

+37
-103
lines changed

8 files changed

+37
-103
lines changed

examples/pipeline_tacotron2/README.md

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1 @@
11
This is an example pipeline for text-to-speech using Tacotron2.
2-
3-
4-
## Instructions for running tests
5-
6-
#### Install required the package for testing
7-
8-
```bash
9-
pip install pytest
10-
```
11-
12-
#### Run tests
13-
14-
Execute the following command in the directory `examples/pipeline_tacotron2/`
15-
16-
```bash
17-
pytest .
18-
```

examples/pipeline_tacotron2/loss/loss_function.py renamed to examples/pipeline_tacotron2/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
class Tacotron2Loss(nn.Module):
34-
"""Tacotron2 loss function adapted from:
34+
"""Tacotron2 loss function modified from:
3535
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py
3636
"""
3737

examples/pipeline_tacotron2/loss/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/pipeline_tacotron2/loss/tacotron2_loss_gpu_test.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

test/torchaudio_unittest/example/tacotron2/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
import torch
2-
import unittest
32

43
from .tacotron2_loss_impl import (
54
Tacotron2LossShapeTests,
65
Tacotron2LossTorchscriptTests,
76
Tacotron2LossGradcheckTests,
87
)
8+
from torchaudio_unittest.common_utils import PytorchTestCase
99

1010

11-
class TestTacotron2LossShapeFloat32CPU(unittest.TestCase, Tacotron2LossShapeTests):
11+
class TestTacotron2LossShapeFloat32CPU(PytorchTestCase, Tacotron2LossShapeTests):
1212
dtype = torch.float32
1313
device = torch.device("cpu")
1414

1515

16-
class TestTacotron2TorchsciptFloat32CPU(unittest.TestCase, Tacotron2LossTorchscriptTests):
16+
class TestTacotron2TorchsciptFloat32CPU(PytorchTestCase, Tacotron2LossTorchscriptTests):
1717
dtype = torch.float32
1818
device = torch.device("cpu")
1919

2020

21-
class TestTacotron2GradcheckFloat64CPU(unittest.TestCase, Tacotron2LossGradcheckTests):
21+
class TestTacotron2GradcheckFloat64CPU(PytorchTestCase, Tacotron2LossGradcheckTests):
2222
dtype = torch.float64 # gradcheck needs a higher numerical accuracy
2323
device = torch.device("cpu")
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch
2+
3+
from .tacotron2_loss_impl import (
4+
Tacotron2LossShapeTests,
5+
Tacotron2LossTorchscriptTests,
6+
Tacotron2LossGradcheckTests,
7+
)
8+
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
9+
10+
11+
@skipIfNoCuda
12+
class TestTacotron2LossShapeFloat32CUDA(PytorchTestCase, Tacotron2LossShapeTests):
13+
dtype = torch.float32
14+
device = torch.device("cuda")
15+
16+
17+
@skipIfNoCuda
18+
class TestTacotron2TorchsciptFloat32CUDA(PytorchTestCase, Tacotron2LossTorchscriptTests):
19+
dtype = torch.float32
20+
device = torch.device("cuda")
21+
22+
23+
@skipIfNoCuda
24+
class TestTacotron2GradcheckFloat64CUDA(PytorchTestCase, Tacotron2LossGradcheckTests):
25+
dtype = torch.float64 # gradcheck needs a higher numerical accuracy
26+
device = torch.device("cuda")

examples/pipeline_tacotron2/loss/tacotron2_loss_impl.py renamed to test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,8 @@
1-
import os
2-
import unittest
3-
import tempfile
4-
51
import torch
62
from torch.autograd import gradcheck, gradgradcheck
73

8-
from .loss_function import Tacotron2Loss
9-
10-
11-
class TempDirMixin:
12-
"""Mixin to provide easy access to temp dir"""
13-
14-
temp_dir_ = None
15-
16-
@classmethod
17-
def get_base_temp_dir(cls):
18-
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
19-
# this is handy for debugging.
20-
key = "TORCHAUDIO_TEST_TEMP_DIR"
21-
if key in os.environ:
22-
return os.environ[key]
23-
if cls.temp_dir_ is None:
24-
cls.temp_dir_ = tempfile.TemporaryDirectory()
25-
return cls.temp_dir_.name
26-
27-
@classmethod
28-
def tearDownClass(cls):
29-
super().tearDownClass()
30-
if cls.temp_dir_ is not None:
31-
cls.temp_dir_.cleanup()
32-
cls.temp_dir_ = None
33-
34-
def get_temp_path(self, *paths):
35-
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
36-
path = os.path.join(temp_dir, *paths)
37-
os.makedirs(os.path.dirname(path), exist_ok=True)
38-
return path
4+
from pipeline_tacotron2.loss import Tacotron2Loss
5+
from torchaudio_unittest.common_utils import TempDirMixin
396

407

418
class Tacotron2LossInputMixin(TempDirMixin):
@@ -68,7 +35,7 @@ def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300):
6835
class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
6936

7037
def test_tacotron2_loss_shape(self):
71-
f"""Validate the output shape of Tacotron2Loss."""
38+
"""Validate the output shape of Tacotron2Loss."""
7239
n_batch = 16
7340

7441
(
@@ -79,7 +46,7 @@ def test_tacotron2_loss_shape(self):
7946
truth_gate_out,
8047
) = self._get_inputs(n_batch=n_batch)
8148

82-
mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()(
49+
mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()(
8350
(mel_specgram, mel_specgram_postnet, gate_out),
8451
(truth_mel_specgram, truth_gate_out)
8552
)
@@ -102,7 +69,7 @@ def _assert_torchscript_consistency(self, fn, tensors):
10269
self.assertEqual(ts_output, output)
10370

10471
def test_tacotron2_loss_torchscript_consistency(self):
105-
f"""Validate the torchscript consistency of Tacotron2Loss."""
72+
"""Validate the torchscript consistency of Tacotron2Loss."""
10673

10774
loss_fn = Tacotron2Loss()
10875
self._assert_torchscript_consistency(loss_fn, self._get_inputs())
@@ -111,7 +78,7 @@ def test_tacotron2_loss_torchscript_consistency(self):
11178
class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin):
11279

11380
def test_tacotron2_loss_gradcheck(self):
114-
f"""Performing gradient check on Tacotron2Loss."""
81+
"""Performing gradient check on Tacotron2Loss."""
11582
(
11683
mel_specgram,
11784
mel_specgram_postnet,

0 commit comments

Comments
 (0)