Skip to content

Commit 1f13667

Browse files
discortvincentqb
andauthored
Add vanilla DeepSpeech model (#1399)
Co-authored-by: Vincent Quenneville-Belair <[email protected]>
1 parent 4b2de71 commit 1f13667

File tree

4 files changed

+120
-1
lines changed

4 files changed

+120
-1
lines changed

docs/source/models.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ The models subpackage contains definitions of models for addressing common audio
1717
.. automethod:: forward
1818

1919

20+
:hidden:`DeepSpeech`
21+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
22+
23+
.. autoclass:: DeepSpeech
24+
25+
.. automethod:: forward
26+
27+
2028
:hidden:`Wav2Letter`
2129
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2230

test/torchaudio_unittest/models_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55
from parameterized import parameterized
6-
from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN
6+
from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
77
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
88
from torchaudio_unittest import common_utils
99

@@ -174,3 +174,20 @@ def test_paper_configuration(self, num_sources, model_params):
174174
output = model(tensor)
175175

176176
assert output.shape == (batch_size, num_sources, num_frames)
177+
178+
179+
class TestDeepSpeech(common_utils.TorchaudioTestCase):
180+
181+
def test_deepspeech(self):
182+
n_batch = 2
183+
n_feature = 1
184+
n_channel = 1
185+
n_class = 40
186+
n_time = 320
187+
188+
model = DeepSpeech(n_feature=n_feature, n_class=n_class)
189+
190+
x = torch.rand(n_batch, n_channel, n_time, n_feature)
191+
out = model(x)
192+
193+
assert out.size() == (n_batch, n_time, n_class)

torchaudio/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from .wav2letter import Wav2Letter
22
from .wavernn import WaveRNN
33
from .conv_tasnet import ConvTasNet
4+
from .deepspeech import DeepSpeech
45

56
__all__ = [
67
'Wav2Letter',
78
'WaveRNN',
89
'ConvTasNet',
10+
'DeepSpeech',
911
]

torchaudio/models/deepspeech.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
3+
__all__ = ["DeepSpeech"]
4+
5+
6+
class FullyConnected(torch.nn.Module):
7+
"""
8+
Args:
9+
n_feature: Number of input features
10+
n_hidden: Internal hidden unit size.
11+
"""
12+
13+
def __init__(self,
14+
n_feature: int,
15+
n_hidden: int,
16+
dropout: float,
17+
relu_max_clip: int = 20) -> None:
18+
super(FullyConnected, self).__init__()
19+
self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
20+
self.relu_max_clip = relu_max_clip
21+
self.dropout = dropout
22+
23+
def forward(self, x: torch.Tensor) -> torch.Tensor:
24+
x = self.fc(x)
25+
x = torch.nn.functional.relu(x)
26+
x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
27+
if self.dropout:
28+
x = torch.nn.functional.dropout(x, self.dropout, self.training)
29+
return x
30+
31+
32+
class DeepSpeech(torch.nn.Module):
33+
"""
34+
DeepSpeech model architecture from
35+
`"Deep Speech: Scaling up end-to-end speech recognition"`
36+
<https://arxiv.org/abs/1412.5567> paper.
37+
38+
Args:
39+
n_feature: Number of input features
40+
n_hidden: Internal hidden unit size.
41+
n_class: Number of output classes
42+
"""
43+
44+
def __init__(
45+
self,
46+
n_feature: int,
47+
n_hidden: int = 2048,
48+
n_class: int = 40,
49+
dropout: float = 0.0,
50+
) -> None:
51+
super(DeepSpeech, self).__init__()
52+
self.n_hidden = n_hidden
53+
self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
54+
self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
55+
self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
56+
self.bi_rnn = torch.nn.RNN(
57+
n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True
58+
)
59+
self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
60+
self.out = torch.nn.Linear(n_hidden, n_class)
61+
62+
def forward(self, x: torch.Tensor) -> torch.Tensor:
63+
"""
64+
Args:
65+
x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
66+
Returns:
67+
Tensor: Predictor tensor of dimension (batch, time, class).
68+
"""
69+
# N x C x T x F
70+
x = self.fc1(x)
71+
# N x C x T x H
72+
x = self.fc2(x)
73+
# N x C x T x H
74+
x = self.fc3(x)
75+
# N x C x T x H
76+
x = x.squeeze(1)
77+
# N x T x H
78+
x = x.transpose(0, 1)
79+
# T x N x H
80+
x, _ = self.bi_rnn(x)
81+
# The fifth (non-recurrent) layer takes both the forward and backward units as inputs
82+
x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:]
83+
# T x N x H
84+
x = self.fc4(x)
85+
# T x N x H
86+
x = self.out(x)
87+
# T x N x n_class
88+
x = x.permute(1, 0, 2)
89+
# N x T x n_class
90+
x = torch.nn.functional.log_softmax(x, dim=2)
91+
# N x T x n_class
92+
return x

0 commit comments

Comments
 (0)