diff --git a/docs/source/models.rst b/docs/source/models.rst index ea86d8b73b..2030eefd28 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -17,6 +17,14 @@ The models subpackage contains definitions of models for addressing common audio .. automethod:: forward +:hidden:`DeepSpeech` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeepSpeech + + .. automethod:: forward + + :hidden:`Wav2Letter` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/torchaudio_unittest/models_test.py b/test/torchaudio_unittest/models_test.py index 484ec6c10c..4db4895b8f 100644 --- a/test/torchaudio_unittest/models_test.py +++ b/test/torchaudio_unittest/models_test.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN +from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN from torchaudio.models.wavernn import MelResNet, UpsampleNetwork from torchaudio_unittest import common_utils @@ -174,3 +174,20 @@ def test_paper_configuration(self, num_sources, model_params): output = model(tensor) assert output.shape == (batch_size, num_sources, num_frames) + + +class TestDeepSpeech(common_utils.TorchaudioTestCase): + + def test_deepspeech(self): + n_batch = 2 + n_feature = 1 + n_channel = 1 + n_class = 40 + n_time = 320 + + model = DeepSpeech(n_feature=n_feature, n_class=n_class) + + x = torch.rand(n_batch, n_channel, n_time, n_feature) + out = model(x) + + assert out.size() == (n_batch, n_time, n_class) diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py index 5b134345af..6696d8ded2 100644 --- a/torchaudio/models/__init__.py +++ b/torchaudio/models/__init__.py @@ -1,9 +1,11 @@ from .wav2letter import Wav2Letter from .wavernn import WaveRNN from .conv_tasnet import ConvTasNet +from .deepspeech import DeepSpeech __all__ = [ 'Wav2Letter', 'WaveRNN', 'ConvTasNet', + 'DeepSpeech', ] diff --git a/torchaudio/models/deepspeech.py b/torchaudio/models/deepspeech.py new 
import torch

__all__ = ["DeepSpeech"]


class FullyConnected(torch.nn.Module):
    """Linear layer followed by a clipped ReLU and optional dropout.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size (number of output features).
        dropout: Dropout probability applied after the activation. A falsy
            value (e.g. ``0.0``) skips the dropout call entirely.
        relu_max_clip: Upper bound of the clipped ReLU. (Default: ``20``)
    """

    def __init__(self,
                 n_feature: int,
                 n_hidden: int,
                 dropout: float,
                 relu_max_clip: int = 20) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
        self.relu_max_clip = relu_max_clip
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor whose last dimension is ``n_feature``.

        Returns:
            Tensor: Same leading dimensions as ``x``, last dimension ``n_hidden``,
            with values clamped to ``[0, relu_max_clip]``.
        """
        x = self.fc(x)
        # hardtanh with a lower bound of 0 computes min(max(x, 0), clip), i.e. the
        # clipped ReLU in a single op; a separate relu beforehand would be redundant.
        x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
        if self.dropout:
            x = torch.nn.functional.dropout(x, self.dropout, self.training)
        return x


class DeepSpeech(torch.nn.Module):
    """
    DeepSpeech model architecture from
    `"Deep Speech: Scaling up end-to-end speech recognition"`
    paper.

    The network is three clipped-ReLU fully connected layers, one bidirectional
    ReLU RNN layer whose forward and backward outputs are summed, a fourth fully
    connected layer, a linear projection to ``n_class`` and a log-softmax.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size. (Default: ``2048``)
        n_class: Number of output classes. (Default: ``40``)
        dropout: Dropout probability passed to every fully connected layer.
            (Default: ``0.0``)
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int = 2048,
        n_class: int = 40,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.n_hidden = n_hidden
        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
        self.bi_rnn = torch.nn.RNN(
            n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True
        )
        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
        self.out = torch.nn.Linear(n_hidden, n_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
        Returns:
            Tensor: Predictor tensor of dimension (batch, time, class).
        """
        # N x C x T x F
        x = self.fc1(x)
        # N x C x T x H
        x = self.fc2(x)
        # N x C x T x H
        x = self.fc3(x)
        # N x C x T x H
        # squeeze(1) only drops the channel axis when it has size 1
        x = x.squeeze(1)
        # N x T x H
        # the RNN is not batch_first, so time has to lead
        x = x.transpose(0, 1)
        # T x N x H
        x, _ = self.bi_rnn(x)
        # T x N x 2H (forward units first, backward units second)
        # The fifth (non-recurrent) layer takes both the forward and backward units as inputs
        x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:]
        # T x N x H
        x = self.fc4(x)
        # T x N x H
        x = self.out(x)
        # T x N x n_class
        x = x.permute(1, 0, 2)
        # N x T x n_class
        x = torch.nn.functional.log_softmax(x, dim=2)
        # N x T x n_class
        return x