Skip to content

Commit bb7432f

Browse files
discort authored and vincentqb committed
batch_first in model output for better performance
1 parent 7a4a38c commit bb7432f

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

torchaudio/models/deepspeech.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,15 @@ def __init__(self,
1818
relu_max_clip: int = 20) -> None:
1919
super(FullyConnected, self).__init__()
2020
self.fc = nn.Linear(in_features, hidden_size, bias=True)
21-
self.nonlinearity = nn.Sequential(*[
22-
nn.ReLU(),
23-
nn.Hardtanh(0, relu_max_clip)
24-
])
25-
if dropout:
26-
self.nonlinearity = nn.Sequential(*[
27-
self.nonlinearity,
28-
nn.Dropout(dropout)
29-
])
21+
self.relu_max_clip = relu_max_clip
22+
self.dropout = dropout
3023

3124
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the linear layer followed by a clipped ReLU and optional dropout.

    Args:
        x (torch.Tensor): input tensor whose last dimension equals the
            ``in_features`` of ``self.fc``.

    Returns:
        torch.Tensor: activations clipped to ``[0, self.relu_max_clip]``,
        with dropout applied when ``self.dropout`` is non-zero and the
        module is in training mode.
    """
    x = self.fc(x)
    # Clipped ReLU: relu then hardtanh(0, max) is min(relu(x), max).
    x = torch.nn.functional.relu(x)
    x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
    # Functional dropout respects self.training and is skipped entirely
    # when dropout == 0 (falsy), matching the pre-commit Sequential setup.
    if self.dropout:
        x = torch.nn.functional.dropout(x, self.dropout, self.training)
    return x
3531

3632

@@ -58,19 +54,15 @@ def __init__(self,
5854
self.fc3 = FullyConnected(hidden_size, hidden_size, dropout)
5955
self.bi_rnn = nn.RNN(
6056
hidden_size, hidden_size, num_layers=1, nonlinearity='relu', bidirectional=True)
61-
self.nonlinearity = nn.ReLU()
6257
self.fc4 = FullyConnected(hidden_size, hidden_size, dropout)
63-
self.out = nn.Sequential(*[
64-
nn.Linear(hidden_size, num_classes),
65-
nn.LogSoftmax(dim=2)
66-
])
58+
self.out = nn.Linear(hidden_size, num_classes)
6759

6860
def forward(self, x: torch.Tensor) -> torch.Tensor:
6961
"""
7062
Args:
7163
x (torch.Tensor): Tensor of dimension (batch_size, num_channels, input_length, num_features).
7264
Returns:
73-
Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes).
65+
Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes).
7466
"""
7567
# N x C x T x F
7668
x = self.fc1(x)
@@ -91,4 +83,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
9183
# T x N x H
9284
x = self.out(x)
9385
# T x N x num_classes
86+
x = x.permute(1, 0, 2)
87+
# N x T x num_classes
88+
x = torch.nn.functional.log_softmax(x, dim=2)
89+
# N x T x num_classes
9490
return x

0 commit comments

Comments
 (0)