
Commit 54f5746: fix errors in training
1 parent: 2977425

File tree

4 files changed, +40 -27 lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ version = "0.1.1"
 
 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 CorpusLoaders = "214a0ac2-f95b-54f7-a80b-442ed9c2c9e8"
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
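
The new CUDA entry registers GPU support as a direct dependency of the package. A minimal sketch of what it is used for together with Flux (the model and batch below are illustrative, not part of this commit):

using Flux, CUDA

model = Dense(10, 2) |> gpu          # no-op on machines without a CUDA device
x = rand(Float32, 10, 4) |> gpu      # move the batch alongside the model
y = model(x)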

src/ULMFiT/fine_tune_lm.jl

Lines changed: 8 additions & 11 deletions
@@ -24,17 +24,17 @@ opts : `Vector` of optimizers used to update weights for corresponding la
 
 NOTE: length(opts) == length(layers)
 """
-function discriminative_step!(layers, ηL::Float64, l, opts::Vector)
+function discriminative_step!(layers, lm::LanguageModel, gen, ηL::Float64, opts::Vector)
     @assert length(opts) == length(layers)
     # Gradient calculation
-    grads = Zygote.gradient(() -> l, get_trainable_params(layers))
+    grads = Zygote.gradient(() -> loss(lm, gen), get_trainable_params(layers))
 
     # discriminative step
     ηl = ηL/(2.6^(length(layers)-1))
     for (layer, opt) in zip(layers, opts)
         opt.eta = ηl
         for ps in get_trainable_params([layer])
-            Flux.Optimise.update!(opt, ps, grads)
+            Flux.Optimise.update!(opt, ps, grads[ps])
         end
         ηl *= 2.6
     end
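
Both fixes in this hunk follow the usual Zygote implicit-parameter pattern: the forward pass has to run inside the gradient closure so Zygote can trace it (a precomputed loss value carries no gradient information), and the returned Grads object is indexed per parameter array when updating. A minimal sketch of that pattern, with an illustrative model and loss rather than the package's own:

using Flux, Zygote

m  = Dense(4, 2)
x  = rand(Float32, 4, 8)
y  = rand(Float32, 2, 8)
ps = Flux.params(m)

# forward pass inside the closure, so Zygote records it
grads = Zygote.gradient(() -> sum(abs2, m(x) .- y), ps)

opt = ADAM(0.01)
for p in ps
    Flux.Optimise.update!(opt, p, grads[p])   # index Grads by each parameter array
end
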
@@ -55,26 +55,23 @@ function fine_tune_lm!(lm=LanguageModel(), data_loader=imdb_fine_tune_data,
         epochs::Integer=1, checkpoint_itvl::Integer=5000)
 
     opts = [ADAM(0.001, (0.7, 0.99)) for i=1:4]
-    gen = data_loader()
-    num_of_iters = take!(gen)
-    cut = num_of_iters * epochs * stlr_cut_frac
+
     # Fine-Tuning loops
     for epoch=1:epochs
         println("\nEpoch: $epoch")
+        gen = data_loader()
+        num_of_iters = take!(gen)
+        cut = num_of_iters * epochs * stlr_cut_frac
         T = num_of_iters-Int(floor((num_of_iters*2)/100))
         set_trigger!.(T, lm.layers)
         for i=1:num_of_iters
-
-            # FORWARD
-            l = loss(lm, gen)
-
             # Slanted triangular learning rate step
             t = i + (epoch-1)*num_of_iters
             p_frac = (i < cut) ? i/cut : (1 - ((i-cut)/(cut*(1/stlr_cut_frac-1))))
             ηL = stlr_η_max*((1+p_frac*(stlr_ratio-1))/stlr_ratio)
 
             # Backprop with discriminative fine-tuning step
-            discriminative_step!(lm.layers[[1, 3, 5, 7]], ηL, l, opts)
+            discriminative_step!(lm.layers[[1, 3, 5, 7]], lm, gen, ηL, opts)
 
             # Resets dropout masks for all the layers with DropOut or DropConnect
             reset_masks!.(lm.layers)
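
For reference, the ηL computed in this loop is the slanted triangular learning-rate schedule from the ULMFiT paper (Howard & Ruder, 2018): a short linear warm-up over the first cut_frac of all iterations, then a linear decay. A standalone sketch of the schedule, written in terms of the global step t as in the paper (the code above uses the within-epoch index i):

# Slanted triangular learning rate: peaks at η_max when t reaches cut,
# then decays back towards η_max/ratio by the end of training.
function stlr(t, num_of_iters, epochs; cut_frac=0.1, ratio=32, η_max=0.01)
    cut = num_of_iters * epochs * cut_frac
    p = t < cut ? t/cut : 1 - (t - cut)/(cut*(1/cut_frac - 1))
    return η_max * (1 + p*(ratio - 1)) / ratio
end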

src/ULMFiT/pretrain_lm.jl

Lines changed: 4 additions & 7 deletions
@@ -107,10 +107,10 @@ function loss(lm, gen)
 end
 
 # Backpropagation step while training
-function backward!(layers, l, opt)
+function backward!(layers, lm, gen, opt)
     # Calulating gradients and weights updation
     p = get_trainable_params(layers)
-    grads = Zygote.gradient(() -> l, p)
+    grads = Zygote.gradient(() -> loss(lm, gen), p)
     Flux.Optimise.update!(opt, p, grads)
     return
 end
@@ -138,11 +138,8 @@ function pretrain_lm!(lm::LanguageModel=LanguageModel(), data_loader::Channel=lo
     set_trigger!.(T, lm.layers)    # Setting triggers for AWD_LSTM layers
     for i=1:num_of_batches
 
-        # FORWARD PASS
-        l = loss(lm, gen)
-
         # REVERSE PASS
-        backward!(lm.layers, l, opt)
+        backward!(lm.layers, lm, gen, opt)
 
         # ASGD Step, works after Triggering
         asgd_step!.(i, lm.layers)
@@ -158,7 +155,7 @@ end
 
 # To save model
 function save_model!(m::LanguageModel, filepath::String)
-    weights = cpu.(Tracker.data.(params(m)))
+    weights = cpu.(params(m))
     BSON.@save filepath weights
 end
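
With Flux now on Zygote, params(m) already returns plain arrays, so the Tracker.data unwrapping removed above is no longer needed; cpu.() still brings any GPU arrays back to host memory before serialisation. A hedged sketch of the matching save/restore round trip (the model variable and file name are illustrative):

using BSON, Flux

weights = cpu.(Flux.params(model))
BSON.@save "ulmfit_lm.bson" weights

# later, restoring into a freshly constructed model of the same shape
BSON.@load "ulmfit_lm.bson" weights
Flux.loadparams!(model, weights)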

src/ULMFiT/train_text_classifier.jl

Lines changed: 27 additions & 9 deletions
@@ -95,11 +95,13 @@ function forward(tc::TextClassifier, gen::Channel, tracked_steps::Integer=32)
     X = take!(gen)
     l = length(X)
     # Truncated Backprop through time
-    for i=1:ceil(l/now_per_pass)-1   # Tracking is swiched off inside this loop
-        (i == 1 && l%now_per_pass != 0) ? (last_idx = l%now_per_pass) : (last_idx = now_per_pass)
-        H = broadcast(x -> indices(x, classifier.vocab, "_unk_"), X[1:last_idx])
-        H = classifier.rnn_layers.(H)
-        X = X[last_idx+1:end]
+    Zygote.ignore do
+        for i=1:ceil(l/tracked_steps)-1   # Tracking is swiched off inside this loop
+            (i == 1 && l%tracked_steps != 0) ? (last_idx = l%tracked_steps) : (last_idx = tracked_steps)
+            H = broadcast(x -> indices(x, classifier.vocab, "_unk_"), X[1:last_idx])
+            H = classifier.rnn_layers.(H)
+            X = X[last_idx+1:end]
+        end
     end
     # set the lated hidden states to original model
     for (t_layer, unt_layer) in zip(tc.rnn_layers[2:end], classifier.rnn_layers[2:end])
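
Zygote.ignore runs its block normally but keeps it off the gradient tape, which is how the warm-up passes over the earlier time-steps above can update hidden state without being backpropagated through. A small self-contained illustration (not package code):

using Zygote

function f(x)
    h = Zygote.ignore() do
        x^2            # executed, but excluded from differentiation
    end
    return h + 3x      # only this term contributes to the gradient
end

Zygote.gradient(f, 2.0)   # (3.0,)
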
@@ -130,7 +132,7 @@ Arguments:
 
 classifier : Instance of TextClassifier
 gen : 'Channel' [data loader], to give a mini-batch
-tracked_words : specifies the number of time-steps for which tracking is on
+tracked_steps : specifies the number of time-steps for which tracking is on
 """
 function loss(classifier::TextClassifier, gen::Channel, tracked_steps::Integer=32)
     H = forward(classifier, gen, tracked_steps)
@@ -140,6 +142,23 @@ function loss(classifier::TextClassifier, gen::Channel, tracked_steps::Integer=3
     return l
 end
 
+function discriminative_step!(layers, classifier::TextClassifier, gen, tracked_steps::Integer=32, ηL::Float64, opts::Vector)
+    @assert length(opts) == length(layers)
+    # Gradient calculation
+    grads = Zygote.gradient(() -> loss(classifier, gen, tracked_steps = tracked_steps), get_trainable_params(layers))
+
+    # discriminative step
+    ηl = ηL/(2.6^(length(layers)-1))
+    for (layer, opt) in zip(layers, opts)
+        opt.eta = ηl
+        for ps in get_trainable_params([layer])
+            Flux.Optimise.update!(opt, ps, grads[ps])
+        end
+        ηl *= 2.6
+    end
+    return
+end
+
 """
     train_classifier!(classifier::TextClassifier=TextClassifier(), classes::Integer=1,
             data_loader::Channel=imdb_classifier_data, hidden_layer_size::Integer=50;kw...)
@@ -151,7 +170,7 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:
         data_loader::Channel=imdb_classifier_data, hidden_layer_size::Integer=50;
         stlr_cut_frac::Float64=0.1, stlr_ratio::Number=32, stlr_η_max::Float64=0.01,
         val_loader::Channel=nothing, cross_val_batches::Union{Colon, Integer}=:,
-        epochs::Integer=1, checkpoint_itvl=5000)
+        epochs::Integer=1, checkpoint_itvl=5000, tracked_steps::Integer=32)
 
     trainable = []
     append!(trainable, [classifier.rnn_layers[[1, 3, 5, 7]]...])
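
With tracked_steps now exposed as a keyword, a caller can cap how many time-steps of the classifier's forward pass stay on the gradient tape. A hedged usage sketch of the new signature, assuming a classifier built with the package defaults and the bundled IMDB loader:

classifier = TextClassifier()
train_classifier!(classifier, 2; epochs=1, tracked_steps=16)
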
@@ -166,7 +185,6 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:
         num_of_iters = take!(gen)
         cut = num_of_iters * epochs * stlr_cut_frac
         for iter=1:num_of_iters
-            l = loss(classifier, gen, now_per_pass = now_per_pass)
 
             # Slanted triangular learning rates
             t = iter + (epoch-1)*num_of_iters
@@ -175,7 +193,7 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:
 
             # Gradual-unfreezing Step with discriminative fine-tuning
             unfreezed_layers, cur_opts = (epoch < length(trainable)) ? (trainable[end-epoch+1:end], opts[end-epoch+1:end]) : (trainable, opts)
-            discriminative_step!(unfreezed_layers, ηL, l, cur_opts)
+            discriminative_step!(unfreezed_layers, classifier, gen, tracked_steps,ηL, cur_opts)
 
             reset_masks!.(classifier.rnn_layers)    # reset all dropout masks
         end
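
The gradual-unfreezing step above trains one additional layer group per epoch, and discriminative_step! then gives each unfrozen group its own learning rate, dividing by 2.6 per group going down the stack. A small worked example of the resulting rates (ηL and the group count are illustrative):

ηL = 0.01                 # rate for the top-most layer group
ngroups = 4
ηs = [ηL / 2.6^(ngroups - l) for l in 1:ngroups]
# ≈ [0.00057, 0.00148, 0.00385, 0.01]   (lowest group → top-most group)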
