@@ -95,11 +95,13 @@ function forward(tc::TextClassifier, gen::Channel, tracked_steps::Integer=32)
95
95
X = take! (gen)
96
96
l = length (X)
97
97
# Truncated Backprop through time
98
- for i= 1 : ceil (l/ now_per_pass)- 1 # Tracking is swiched off inside this loop
99
- (i == 1 && l% now_per_pass != 0 ) ? (last_idx = l% now_per_pass) : (last_idx = now_per_pass)
100
- H = broadcast (x -> indices (x, classifier. vocab, " _unk_" ), X[1 : last_idx])
101
- H = classifier. rnn_layers .(H)
102
- X = X[last_idx+ 1 : end ]
98
+ Zygote. ignore do
99
+ for i= 1 : ceil (l/ tracked_steps)- 1 # Tracking is switched off inside this loop
100
+ (i == 1 && l% tracked_steps != 0 ) ? (last_idx = l% tracked_steps) : (last_idx = tracked_steps)
101
+ H = broadcast (x -> indices (x, classifier. vocab, " _unk_" ), X[1 : last_idx])
102
+ H = classifier. rnn_layers .(H)
103
+ X = X[last_idx+ 1 : end ]
104
+ end
103
105
end
104
106
# set the latest hidden states to original model
105
107
for (t_layer, unt_layer) in zip (tc. rnn_layers[2 : end ], classifier. rnn_layers[2 : end ])
@@ -130,7 +132,7 @@ Arguments:
130
132
131
133
classifier : Instance of TextClassifier
132
134
gen : 'Channel' [data loader], to give a mini-batch
133
- tracked_words : specifies the number of time-steps for which tracking is on
135
+ tracked_steps : specifies the number of time-steps for which tracking is on
134
136
"""
135
137
function loss (classifier:: TextClassifier , gen:: Channel , tracked_steps:: Integer = 32 )
136
138
H = forward (classifier, gen, tracked_steps)
@@ -140,6 +142,23 @@ function loss(classifier::TextClassifier, gen::Channel, tracked_steps::Integer=3
140
142
return l
141
143
end
142
144
145
"""
    discriminative_step!(layers, classifier::TextClassifier, gen, tracked_steps::Integer, ηL::Float64, opts::Vector)

Apply one discriminative fine-tuning step (ULMFiT): every layer group gets its
own learning rate, decayed by a factor of 2.6 per group going from the last
(top-most) layer towards the input layers.

Arguments:

layers        : layer groups to update (currently unfrozen subset of the classifier's layers)
classifier    : instance of TextClassifier whose loss is differentiated
gen           : 'Channel' [data loader], to give a mini-batch
tracked_steps : specifies the number of time-steps for which tracking is on
ηL            : learning rate of the last (top-most) layer group
opts          : one optimizer per entry in `layers`
"""
function discriminative_step!(layers, classifier::TextClassifier, gen, tracked_steps::Integer, ηL::Float64, opts::Vector)
    # NOTE(fix): `tracked_steps` must not carry a default here — an optional
    # positional argument cannot precede the required `ηL`/`opts` (syntax error).
    length(opts) == length(layers) ||
        throw(ArgumentError("need one optimizer per layer: got $(length(opts)) optimizers for $(length(layers)) layers"))

    # Gradient calculation.
    # NOTE(fix): `loss` declares `tracked_steps` as an optional POSITIONAL
    # argument, so it must be passed positionally; a keyword call
    # (`tracked_steps = tracked_steps`) would raise a MethodError.
    grads = Zygote.gradient(() -> loss(classifier, gen, tracked_steps), get_trainable_params(layers))

    # Discriminative step: the lowest unfrozen group starts at ηL / 2.6^(n-1),
    # and each subsequent group's rate is multiplied by 2.6, ending at ηL.
    ηl = ηL / (2.6^(length(layers) - 1))
    for (layer, opt) in zip(layers, opts)
        opt.eta = ηl
        for ps in get_trainable_params([layer])
            Flux.Optimise.update!(opt, ps, grads[ps])
        end
        ηl *= 2.6
    end
    return
end
161
+
143
162
"""
144
163
train_classifier!(classifier::TextClassifier=TextClassifier(), classes::Integer=1,
145
164
data_loader::Channel=imdb_classifier_data, hidden_layer_size::Integer=50;kw...)
@@ -151,7 +170,7 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:
151
170
data_loader:: Channel = imdb_classifier_data, hidden_layer_size:: Integer = 50 ;
152
171
stlr_cut_frac:: Float64 = 0.1 , stlr_ratio:: Number = 32 , stlr_η_max:: Float64 = 0.01 ,
153
172
val_loader:: Channel = nothing , cross_val_batches:: Union{Colon, Integer} = :,
154
- epochs:: Integer = 1 , checkpoint_itvl= 5000 )
173
+ epochs:: Integer = 1 , checkpoint_itvl= 5000 , tracked_steps :: Integer = 32 )
155
174
156
175
trainable = []
157
176
append! (trainable, [classifier. rnn_layers[[1 , 3 , 5 , 7 ]]. .. ])
@@ -166,7 +185,6 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:
166
185
num_of_iters = take! (gen)
167
186
cut = num_of_iters * epochs * stlr_cut_frac
168
187
for iter= 1 : num_of_iters
169
- l = loss (classifier, gen, now_per_pass = now_per_pass)
170
188
171
189
# Slanted triangular learning rates
172
190
t = iter + (epoch- 1 )* num_of_iters
@@ -175,7 +193,7 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:
175
193
176
194
# Gradual-unfreezing Step with discriminative fine-tuning
177
195
unfreezed_layers, cur_opts = (epoch < length (trainable)) ? (trainable[end - epoch+ 1 : end ], opts[end - epoch+ 1 : end ]) : (trainable, opts)
178
- discriminative_step! (unfreezed_layers, ηL, l , cur_opts)
196
+ discriminative_step! (unfreezed_layers, classifier, gen, tracked_steps,ηL , cur_opts)
179
197
180
198
reset_masks! .(classifier. rnn_layers) # reset all dropout masks
181
199
end
0 commit comments