From 9b359a172d06428062a39aa97fff7099b96b713c Mon Sep 17 00:00:00 2001 From: discort Date: Thu, 18 Mar 2021 17:56:52 +0200 Subject: [PATCH 1/6] #446 add vanilla deepspeech model --- docs/source/models.rst | 7 ++ test/torchaudio_unittest/models_test.py | 19 +++++- torchaudio/models/__init__.py | 2 + torchaudio/models/deepspeech.py | 88 +++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 torchaudio/models/deepspeech.py diff --git a/docs/source/models.rst b/docs/source/models.rst index ea86d8b73b..1cf282c573 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -31,3 +31,10 @@ The models subpackage contains definitions of models for addressing common audio .. autoclass:: WaveRNN .. automethod:: forward + +:hidden:`DeepSpeech` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeepSpeech + + .. automethod:: forward diff --git a/test/torchaudio_unittest/models_test.py b/test/torchaudio_unittest/models_test.py index 484ec6c10c..76ee682c10 100644 --- a/test/torchaudio_unittest/models_test.py +++ b/test/torchaudio_unittest/models_test.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN +from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN, DeepSpeech from torchaudio.models.wavernn import MelResNet, UpsampleNetwork from torchaudio_unittest import common_utils @@ -174,3 +174,20 @@ def test_paper_configuration(self, num_sources, model_params): output = model(tensor) assert output.shape == (batch_size, num_sources, num_frames) + + +class TestDeepSpeech(common_utils.TorchaudioTestCase): + + def test_deepspeech(self): + batch_size = 2 + num_features = 1 + num_channels = 1 + num_classes = 40 + input_length = 320 + + model = DeepSpeech(in_features=1, num_classes=num_classes) + + x = torch.rand(batch_size, num_channels, input_length, num_features) + out = model(x) + + assert out.size() == (input_length, batch_size, num_classes) diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py index 5b134345af..6696d8ded2 100644 --- a/torchaudio/models/__init__.py +++ b/torchaudio/models/__init__.py @@ -1,9 +1,11 @@ from .wav2letter import Wav2Letter from .wavernn import WaveRNN from .conv_tasnet import ConvTasNet +from .deepspeech import DeepSpeech __all__ = [ 'Wav2Letter', 'WaveRNN', 'ConvTasNet', + 'DeepSpeech', ] diff --git a/torchaudio/models/deepspeech.py b/torchaudio/models/deepspeech.py new file mode 100644 index 0000000000..d0f4292c6b --- /dev/null +++ b/torchaudio/models/deepspeech.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn + +__all__ = ["DeepSpeech"] + + +class FullyConnected(nn.Module): + """ + Args: + in_features: Number of input features + hidden_size: Internal hidden unit size. + """ + + def __init__(self, + in_features: int, + hidden_size: int, + dropout: float, + relu_max_clip: int = 20) -> None: + super(FullyConnected, self).__init__() + self.fc = nn.Linear(in_features, hidden_size, bias=True) + self.nonlinearity = nn.Sequential(*[ + nn.ReLU(), + nn.Hardtanh(0, relu_max_clip) + ]) + if dropout: + self.nonlinearity = nn.Sequential(*[ + self.nonlinearity, + nn.Dropout(dropout) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + x = self.nonlinearity(x) + return x + + +class DeepSpeech(nn.Module): + """ + DeepSpeech model architecture from + `"Deep Speech: Scaling up end-to-end speech recognition"` + paper. + + Args: + in_features: Number of input features + hidden_size: Internal hidden unit size. + num_classes: Number of output classes + """ + + def __init__(self, + in_features: int, + hidden_size: int = 2048, + num_classes: int = 40, + dropout: float = 0.0) -> None: + super(DeepSpeech, self).__init__() + self.hidden_size = hidden_size + self.fc1 = FullyConnected(in_features, hidden_size, dropout) + self.fc2 = FullyConnected(hidden_size, hidden_size, dropout) + self.fc3 = FullyConnected(hidden_size, hidden_size, dropout) + self.bi_rnn = nn.RNN( + hidden_size, hidden_size, num_layers=1, nonlinearity='relu', bidirectional=True) + self.nonlinearity = nn.ReLU() + self.fc4 = FullyConnected(hidden_size, hidden_size, dropout) + self.out = nn.Sequential(*[ + nn.Linear(hidden_size, num_classes), + nn.LogSoftmax(dim=2) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # N x C x T x F + x = self.fc1(x) + # N x C x T x H + x = self.fc2(x) + # N x C x T x H + x = self.fc3(x) + # N x C x T x H + x = x.squeeze(1) + # N x T x H + x = x.transpose(0, 1) + # T x N x H + x, _ = self.bi_rnn(x) + # The fifth (non-recurrent) layer takes both the forward and backward units as inputs + x = x[:, :, :self.hidden_size] + x[:, :, self.hidden_size:] + # T x N x H + x = self.fc4(x) + # T x N x H + x = self.out(x) + # T x N x num_classes + return x From 7a4a38cb72fc3aca15db03606d991060dc84c83f Mon Sep 17 00:00:00 2001 From: discort Date: Wed, 21 Apr 2021 20:57:29 +0300 Subject: [PATCH 2/6] docstring for deepspeech forward --- torchaudio/models/deepspeech.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchaudio/models/deepspeech.py b/torchaudio/models/deepspeech.py index d0f4292c6b..7dd5c4fdb4 100644 --- a/torchaudio/models/deepspeech.py +++ b/torchaudio/models/deepspeech.py @@ -66,6 +66,12 @@ def __init__(self, ]) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Tensor of dimension (batch_size, num_channels, input_length, num_features). + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes). + """ # N x C x T x F x = self.fc1(x) # N x C x T x H From bb7432fe3e818c7b5ce5cf6cb0b3321082e2cb28 Mon Sep 17 00:00:00 2001 From: discort Date: Fri, 30 Apr 2021 12:24:19 +0300 Subject: [PATCH 3/6] batch_first in model output for better performance --- torchaudio/models/deepspeech.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/torchaudio/models/deepspeech.py b/torchaudio/models/deepspeech.py index 7dd5c4fdb4..99a2ee9d6e 100644 --- a/torchaudio/models/deepspeech.py +++ b/torchaudio/models/deepspeech.py @@ -18,19 +18,15 @@ def __init__(self, relu_max_clip: int = 20) -> None: super(FullyConnected, self).__init__() self.fc = nn.Linear(in_features, hidden_size, bias=True) - self.nonlinearity = nn.Sequential(*[ - nn.ReLU(), - nn.Hardtanh(0, relu_max_clip) - ]) - if dropout: - self.nonlinearity = nn.Sequential(*[ - self.nonlinearity, - nn.Dropout(dropout) - ]) + self.relu_max_clip = relu_max_clip + self.dropout = dropout def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc(x) - x = self.nonlinearity(x) + x = torch.nn.functional.relu(x) + x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip) + if self.dropout: + x = torch.nn.functional.dropout(x, self.dropout, self.training) return x @@ -58,19 +54,15 @@ def __init__(self, self.fc3 = FullyConnected(hidden_size, hidden_size, dropout) self.bi_rnn = nn.RNN( hidden_size, hidden_size, num_layers=1, nonlinearity='relu', bidirectional=True) - self.nonlinearity = nn.ReLU() self.fc4 = FullyConnected(hidden_size, hidden_size, dropout) - self.out = nn.Sequential(*[ - nn.Linear(hidden_size, num_classes), - nn.LogSoftmax(dim=2) - ]) + self.out = nn.Linear(hidden_size, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x (torch.Tensor): Tensor of dimension (batch_size, num_channels, input_length, num_features). Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes). + Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes). """ # N x C x T x F x = self.fc1(x) @@ -91,4 +83,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # T x N x H x = self.out(x) # T x N x num_classes + x = x.permute(1, 0, 2) + # N x T x num_classes + x = torch.nn.functional.log_softmax(x, dim=2) + # T x N x num_classes return x From 0a1647a9c51f96d352243c2d9f8023328584c231 Mon Sep 17 00:00:00 2001 From: discort Date: Sat, 1 May 2021 10:51:45 +0300 Subject: [PATCH 4/6] fixed tests for deepspeech --- test/torchaudio_unittest/models_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/models_test.py b/test/torchaudio_unittest/models_test.py index 76ee682c10..3e5c9ef3ec 100644 --- a/test/torchaudio_unittest/models_test.py +++ b/test/torchaudio_unittest/models_test.py @@ -190,4 +190,4 @@ def test_deepspeech(self): x = torch.rand(batch_size, num_channels, input_length, num_features) out = model(x) - assert out.size() == (input_length, batch_size, num_classes) + assert out.size() == (batch_size, input_length, num_classes) From 7b59009089dad2851a97f0c3b41bdc7763efb42f Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 11 May 2021 14:21:41 -0400 Subject: [PATCH 5/6] use naming convention from readme. --- test/torchaudio_unittest/models_test.py | 16 ++++----- torchaudio/models/deepspeech.py | 48 ++++++++++++------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/test/torchaudio_unittest/models_test.py b/test/torchaudio_unittest/models_test.py index 3e5c9ef3ec..9df4fb8c6b 100644 --- a/test/torchaudio_unittest/models_test.py +++ b/test/torchaudio_unittest/models_test.py @@ -179,15 +179,15 @@ def test_paper_configuration(self, num_sources, model_params): class TestDeepSpeech(common_utils.TorchaudioTestCase): def test_deepspeech(self): - batch_size = 2 - num_features = 1 - num_channels = 1 - num_classes = 40 - input_length = 320 + n_batch = 2 + n_feature = 1 + n_channel = 1 + n_class = 40 + n_time = 320 - model = DeepSpeech(in_features=1, num_classes=num_classes) + model = DeepSpeech(n_feature=n_feature, n_class=n_class) - x = torch.rand(batch_size, num_channels, input_length, num_features) + x = torch.rand(n_batch, n_channel, n_time, n_feature) out = model(x) - assert out.size() == (batch_size, input_length, num_classes) + assert out.size() == (n_batch, n_time, n_class) diff --git a/torchaudio/models/deepspeech.py b/torchaudio/models/deepspeech.py index 99a2ee9d6e..eb7e68b66c 100644 --- a/torchaudio/models/deepspeech.py +++ b/torchaudio/models/deepspeech.py @@ -7,17 +7,17 @@ class FullyConnected(nn.Module): """ Args: - in_features: Number of input features - hidden_size: Internal hidden unit size. + n_feature: Number of input features + n_hidden: Internal hidden unit size. """ def __init__(self, - in_features: int, - hidden_size: int, + n_feature: int, + n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None: super(FullyConnected, self).__init__() - self.fc = nn.Linear(in_features, hidden_size, bias=True) + self.fc = nn.Linear(n_feature, n_hidden, bias=True) self.relu_max_clip = relu_max_clip self.dropout = dropout @@ -37,32 +37,32 @@ class DeepSpeech(nn.Module): paper. Args: - in_features: Number of input features - hidden_size: Internal hidden unit size. - num_classes: Number of output classes + n_feature: Number of input features + n_hidden: Internal hidden unit size. + n_class: Number of output classes """ def __init__(self, - in_features: int, - hidden_size: int = 2048, - num_classes: int = 40, + n_feature: int, + n_hidden: int = 2048, + n_class: int = 40, dropout: float = 0.0) -> None: super(DeepSpeech, self).__init__() - self.hidden_size = hidden_size - self.fc1 = FullyConnected(in_features, hidden_size, dropout) - self.fc2 = FullyConnected(hidden_size, hidden_size, dropout) - self.fc3 = FullyConnected(hidden_size, hidden_size, dropout) + self.n_hidden = n_hidden + self.fc1 = FullyConnected(n_feature, n_hidden, dropout) + self.fc2 = FullyConnected(n_hidden, n_hidden, dropout) + self.fc3 = FullyConnected(n_hidden, n_hidden, dropout) self.bi_rnn = nn.RNN( - hidden_size, hidden_size, num_layers=1, nonlinearity='relu', bidirectional=True) - self.fc4 = FullyConnected(hidden_size, hidden_size, dropout) - self.out = nn.Linear(hidden_size, num_classes) + n_hidden, n_hidden, num_layers=1, nonlinearity='relu', bidirectional=True) + self.fc4 = FullyConnected(n_hidden, n_hidden, dropout) + self.out = nn.Linear(n_hidden, n_class) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x (torch.Tensor): Tensor of dimension (batch_size, num_channels, input_length, num_features). + x (torch.Tensor): Tensor of dimension (batch, channel, time, feature). Returns: - Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes). + Tensor: Predictor tensor of dimension (batch, time, class). """ # N x C x T x F x = self.fc1(x) @@ -77,14 +77,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # T x N x H x, _ = self.bi_rnn(x) # The fifth (non-recurrent) layer takes both the forward and backward units as inputs - x = x[:, :, :self.hidden_size] + x[:, :, self.hidden_size:] + x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:] # T x N x H x = self.fc4(x) # T x N x H x = self.out(x) - # T x N x num_classes + # T x N x n_class x = x.permute(1, 0, 2) - # N x T x num_classes + # N x T x n_class x = torch.nn.functional.log_softmax(x, dim=2) - # T x N x num_classes + # N x T x n_class return x From 087430650a5d036476565475dec0e5be4e9fea09 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 11 May 2021 16:37:55 -0400 Subject: [PATCH 6/6] lint. --- docs/source/models.rst | 13 +++++++------ test/torchaudio_unittest/models_test.py | 2 +- torchaudio/models/deepspeech.py | 26 +++++++++++++------------ 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 1cf282c573..2030eefd28 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -17,24 +17,25 @@ The models subpackage contains definitions of models for addressing common audio .. automethod:: forward -:hidden:`Wav2Letter` +:hidden:`DeepSpeech` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: Wav2Letter +.. autoclass:: DeepSpeech .. automethod:: forward -:hidden:`WaveRNN` +:hidden:`Wav2Letter` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: WaveRNN +.. autoclass:: Wav2Letter .. automethod:: forward -:hidden:`DeepSpeech` + +:hidden:`WaveRNN` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: DeepSpeech +.. autoclass:: WaveRNN .. automethod:: forward diff --git a/test/torchaudio_unittest/models_test.py b/test/torchaudio_unittest/models_test.py index 9df4fb8c6b..4db4895b8f 100644 --- a/test/torchaudio_unittest/models_test.py +++ b/test/torchaudio_unittest/models_test.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN, DeepSpeech +from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN from torchaudio.models.wavernn import MelResNet, UpsampleNetwork from torchaudio_unittest import common_utils diff --git a/torchaudio/models/deepspeech.py b/torchaudio/models/deepspeech.py index eb7e68b66c..477993e411 100644 --- a/torchaudio/models/deepspeech.py +++ b/torchaudio/models/deepspeech.py @@ -1,10 +1,9 @@ import torch -import torch.nn as nn __all__ = ["DeepSpeech"] -class FullyConnected(nn.Module): +class FullyConnected(torch.nn.Module): """ Args: n_feature: Number of input features @@ -17,7 +16,7 @@ def __init__(self, dropout: float, relu_max_clip: int = 20) -> None: super(FullyConnected, self).__init__() - self.fc = nn.Linear(n_feature, n_hidden, bias=True) + self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True) self.relu_max_clip = relu_max_clip self.dropout = dropout @@ -30,7 +29,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class DeepSpeech(nn.Module): +class DeepSpeech(torch.nn.Module): """ DeepSpeech model architecture from `"Deep Speech: Scaling up end-to-end speech recognition"` @@ -42,20 +41,23 @@ class DeepSpeech(nn.Module): n_class: Number of output classes """ - def __init__(self, - n_feature: int, - n_hidden: int = 2048, - n_class: int = 40, - dropout: float = 0.0) -> None: + def __init__( + self, + n_feature: int, + n_hidden: int = 2048, + n_class: int = 40, + dropout: float = 0.0, + ) -> None: super(DeepSpeech, self).__init__() self.n_hidden = n_hidden self.fc1 = FullyConnected(n_feature, n_hidden, dropout) self.fc2 = FullyConnected(n_hidden, n_hidden, dropout) self.fc3 = FullyConnected(n_hidden, n_hidden, dropout) - self.bi_rnn = nn.RNN( - n_hidden, n_hidden, num_layers=1, nonlinearity='relu', bidirectional=True) + self.bi_rnn = torch.nn.RNN( + n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True + ) self.fc4 = FullyConnected(n_hidden, n_hidden, dropout) - self.out = nn.Linear(n_hidden, n_class) + self.out = torch.nn.Linear(n_hidden, n_class) def forward(self, x: torch.Tensor) -> torch.Tensor: """