@@ -18,19 +18,15 @@ def __init__(self,
                  relu_max_clip: int = 20) -> None:
         super(FullyConnected, self).__init__()
         self.fc = nn.Linear(in_features, hidden_size, bias=True)
-        self.nonlinearity = nn.Sequential(*[
-            nn.ReLU(),
-            nn.Hardtanh(0, relu_max_clip)
-        ])
-        if dropout:
-            self.nonlinearity = nn.Sequential(*[
-                self.nonlinearity,
-                nn.Dropout(dropout)
-            ])
+        self.relu_max_clip = relu_max_clip
+        self.dropout = dropout

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc(x)
-        x = self.nonlinearity(x)
+        x = torch.nn.functional.relu(x)
+        x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
+        if self.dropout:
+            x = torch.nn.functional.dropout(x, self.dropout, self.training)
         return x

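A quick sanity check of this refactor, as a minimal sketch (assumes only torch; the tensor, shapes, and clip value below are illustrative, not from this PR). Both the old Sequential(ReLU, Hardtanh(0, clip)) pair and the new functional chain compute clamp(x, 0, clip), so dropout aside the two forwards agree:

import torch
import torch.nn.functional as F

clip = 20                               # matches the relu_max_clip default
x = torch.randn(4, 8) * 30              # values outside [0, clip] exercise the clamp
old = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Hardtanh(0, clip))(x)
new = F.hardtanh(F.relu(x), 0, clip)
assert torch.equal(old, new)                      # identical outputs
assert torch.equal(new, torch.clamp(x, 0, clip))  # both are a clipped ReLU

The dropout path is also unchanged in effect: nn.Dropout is a no-op in eval mode, and the functional call gets the same behavior by passing self.training explicitly.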
@@ -58,19 +54,15 @@ def __init__(self,
         self.fc3 = FullyConnected(hidden_size, hidden_size, dropout)
         self.bi_rnn = nn.RNN(
             hidden_size, hidden_size, num_layers=1, nonlinearity='relu', bidirectional=True)
-        self.nonlinearity = nn.ReLU()
         self.fc4 = FullyConnected(hidden_size, hidden_size, dropout)
-        self.out = nn.Sequential(*[
-            nn.Linear(hidden_size, num_classes),
-            nn.LogSoftmax(dim=2)
-        ])
+        self.out = nn.Linear(hidden_size, num_classes)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
             x (torch.Tensor): Tensor of dimension (batch_size, num_channels, input_length, num_features).

         Returns:
-            Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes).
+            Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes).
         """
         # N x C x T x F
         x = self.fc1(x)
@@ -91,4 +83,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # T x N x H
         x = self.out(x)
         # T x N x num_classes
+        x = x.permute(1, 0, 2)
+        # N x T x num_classes
+        x = torch.nn.functional.log_softmax(x, dim=2)
+        # N x T x num_classes
         return x
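To make the new tail of forward concrete, here is a minimal standalone sketch of the permute + log_softmax step on a dummy tensor (shape values are illustrative): self.out produces T x N x num_classes, the permute makes the output batch-first, and log_softmax normalizes over the class dimension without changing the shape:

import torch

T, N, num_classes = 50, 4, 29           # illustrative sizes
x = torch.randn(T, N, num_classes)      # T x N x num_classes, as after self.out
x = x.permute(1, 0, 2)                  # N x T x num_classes
x = torch.nn.functional.log_softmax(x, dim=2)
assert x.shape == (N, T, num_classes)   # still batch-first
assert torch.allclose(x.exp().sum(dim=2), torch.ones(N, T))  # valid log-probs

This matches the corrected docstring: the predictor tensor now comes back as (batch_size, input_length, number_of_classes).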