
Commit 3b614e0

Update ULMFiT model
Fix errors in training; correct the code for the text classifier; remove GPU error.
1 parent 42a0e06 commit 3b614e0

10 files changed (+129, -103 lines)

Project.toml

Lines changed: 4 additions & 0 deletions
@@ -6,12 +6,16 @@ version = "0.1.1"
 
 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+CorpusLoaders = "214a0ac2-f95b-54f7-a80b-442ed9c2c9e8"
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 Languages = "8ef0a80b-9436-5d2c-a485-80b904378c43"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TextAnalysis = "a2db99b7-8b79-58f8-94bf-bbc811eef33d"
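
For a local checkout, the new [deps] entries can also be pulled in through Pkg; a minimal sketch (the checkout path is only a placeholder):

    using Pkg

    # Activate the package environment of a local TextModels.jl checkout
    # (replace the path with wherever the repository is cloned).
    Pkg.activate("path/to/TextModels.jl")

    # Add the dependencies introduced by this commit; Pkg records them
    # in the [deps] section of Project.toml shown above.
    Pkg.add(["CUDA", "CorpusLoaders", "DelimitedFiles", "Random"])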

src/TextModels.jl

Lines changed: 17 additions & 16 deletions
@@ -39,27 +39,28 @@ module TextModels
 
 
 # ULMFiT
-#module ULMFiT
-#    using ..TextAnalysis
-#    using DataDeps
-#    using Flux
-#    using Tracker
-#    using BSON
-#    include("ULMFiT/utils.jl")
-#    include("ULMFiT/datadeps.jl")
-#    include("ULMFiT/data_loaders.jl")
-#    include("ULMFiT/custom_layers.jl")
-#    include("ULMFiT/pretrain_lm.jl")
-#    include("ULMFiT/fine_tune_lm.jl")
-#    include("ULMFiT/train_text_classifier.jl")
-#end
-#export ULMFiT
+module ULMFiT
+    using TextAnalysis
+    using DataDeps
+    using Flux
+    using Zygote
+    using BSON
+    using CorpusLoaders
+    include("ULMFiT/utils.jl")
+    include("ULMFiT/datadeps.jl")
+    include("ULMFiT/data_loaders.jl")
+    include("ULMFiT/custom_layers.jl")
+    include("ULMFiT/pretrain_lm.jl")
+    include("ULMFiT/fine_tune_lm.jl")
+    include("ULMFiT/train_text_classifier.jl")
+end
+export ULMFiT
 
 function __init__()
     pos_tagger_datadep_register()
     ner_datadep_register()
     pos_datadep_register()
-    #ULMFiT.ulmfit_datadep_register()
+    ULMFiT.ulmfit_datadep_register()
 
     global sentiment_model = artifact"sentiment_model"
 end
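
With the ULMFiT module re-enabled and exported, it becomes reachable from TextModels directly. A minimal usage sketch, assuming the LanguageModel constructor from ULMFiT/pretrain_lm.jl and the fine_tune_lm! defaults shown further down in this commit:

    using TextModels

    # ULMFiT.ulmfit_datadep_register() now runs in __init__, so the ULMFiT
    # data dependencies can be fetched lazily on first use.
    lm = ULMFiT.LanguageModel()        # AWD-LSTM language model

    # Fine-tune on the IMDB loader (see fine_tune_lm.jl below for the defaults).
    ULMFiT.fine_tune_lm!(lm; epochs=1)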

src/ULMFiT/custom_layers.jl

Lines changed: 32 additions & 23 deletions
@@ -8,7 +8,7 @@ This file contains the custom layers defined for this model:
     PooledDense
 """
 
-import Flux: gate, _testmode!, _dropout_kernel
+import Flux: gate, testmode!, _dropout_kernel
 
 reset_masks!(entity) = nothing
 reset_probability!(entity) = nothing
@@ -44,12 +44,12 @@ Moreover this also follows the Vartional DropOut citeria, that is,
 the drop mask is remains same for a whole training pass.
 This is done by saving the masks in 'maskWi' and 'maskWh' fields
 """
-mutable struct WeightDroppedLSTMCell{A, V, M}
+mutable struct WeightDroppedLSTMCell{A, V, S, M}
     Wi::A
     Wh::A
     b::V
-    h::V
-    c::V
+    h::S
+    c::S
     p::Float64
     maskWi::M
     maskWh::M
@@ -60,17 +60,17 @@ function WeightDroppedLSTMCell(in::Integer, out::Integer, p::Float64=0.0;
         init = Flux.glorot_uniform)
     @assert 0 ≤ p ≤ 1
     cell = WeightDroppedLSTMCell(
-        param(init(out*4, in)),
-        param(init(out*4, out)),
-        param(init(out*4)),
-        param(zeros(Float32, out)),
-        param(zeros(Float32, out)),
+        init(out*4, in),
+        init(out*4, out),
+        init(out*4),
+        reshape(zeros(Float32, out), out, 1),
+        reshape(zeros(Float32, out), out, 1),
         p,
         drop_mask((out*4, in), p),
        drop_mask((out*4, out), p),
        true
     )
-    cell.b.data[gate(out, 2)] .= 1
+    cell.b[gate(out, 2)] .= 1
     return cell
 end
 
@@ -88,9 +88,12 @@ function (m::WeightDroppedLSTMCell)((h, c), x)
     return (h′, c), h′
 end
 
-Flux.@treelike WeightDroppedLSTMCell
+Flux.@functor WeightDroppedLSTMCell
 
-_testmode!(m::WeightDroppedLSTMCell, test) = (m.active = !test)
+Flux.trainable(m::WeightDroppedLSTMCell) = (m.Wi, m.Wh, m.b, m.h, m.c)
+
+testmode!(m::WeightDroppedLSTMCell, mode=true) =
+    (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 """
     WeightDroppedLSTM(in::Integer, out::Integer, p::Float64=0.0)
@@ -106,7 +109,7 @@ julia> wd = WeightDroppedLSTM(4, 5, 0.3);
 function WeightDroppedLSTM(a...; kw...)
     cell = WeightDroppedLSTMCell(a...;kw...)
     hidden = (cell.h, cell.c)
-    return Flux.Recur(cell, hidden, hidden)
+    return Flux.Recur(cell, hidden)
 end
 
 """
@@ -155,7 +158,9 @@ end
 
 AWD_LSTM(in::Integer, out::Integer, p::Float64=0.0; kw...) = AWD_LSTM(WeightDroppedLSTM(in, out, p; kw...), -1, [])
 
-Flux.@treelike AWD_LSTM
+Flux.@functor AWD_LSTM
+
+Flux.trainable(m::AWD_LSTM) = (m.layer,)
 
 (m::AWD_LSTM)(in) = m.layer(in)
 
@@ -184,12 +189,12 @@ function asgd_step!(iter::Integer, layer::AWD_LSTM)
         p = get_trainable_params([layer])
         avg_fact = 1/max(iter - layer.T + 1, 1)
         if avg_fact != 1
-            layer.accum = layer.accum .+ Tracker.data.(p)
+            layer.accum = layer.accum .+ p
             for (ps, accum) in zip(p, layer.accum)
-                Tracker.data(ps) .= avg_fact*accum
+                ps .= avg_fact*accum
             end
         else
-            layer.accum = deepcopy(Tracker.data.(p))   # Accumulator for ASGD
+            layer.accum = deepcopy(p)   # Accumulator for ASGD
         end
     end
     return
@@ -230,7 +235,8 @@ function (vd::VarDrop)(x)
     return (x .* vd.mask)
 end
 
-_testmode!(vd::VarDrop, test) = (vd.active = !test)
+testmode!(m::VarDrop, mode=true) =
+    (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 # method for reseting mask of VarDrop
 reset_masks!(vd::VarDrop) = (vd.reset = true)
@@ -270,7 +276,7 @@ end
 function DroppedEmbeddings(in::Integer, embed_size::Integer, p::Float64=0.0;
         init = Flux.glorot_uniform)
     de = DroppedEmbeddings{AbstractArray, typeof(p)}(
-        param(init(in, embed_size)),
+        init(in, embed_size),
         p,
         drop_mask((in,), p),
         true
@@ -283,9 +289,12 @@ function (de::DroppedEmbeddings)(x::AbstractArray, tying::Bool=false)
     return tying ? dropped * x : transpose(dropped[x, :])
 end
 
-Flux.@treelike DroppedEmbeddings
+Flux.@functor DroppedEmbeddings
+
+Flux.trainable(m::DroppedEmbeddings) = (m.emb)
 
-_testmode!(de::DroppedEmbeddings, test) = (de.active = !test)
+testmode!(m::DroppedEmbeddings, mode=true) =
+    (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 function reset_masks!(de::DroppedEmbeddings)
     de.mask = drop_mask(de.mask, de.p)
@@ -324,10 +333,10 @@ PooledDense(W, b) = PooledDense(W, b, identity)
 
 function PooledDense(hidden_sz::Integer, out::Integer, σ = identity;
         initW = Flux.glorot_uniform, initb = (dims...) -> zeros(Float32, dims...))
-    return PooledDense(param(initW(out, hidden_sz*3)), param(initb(out)), σ)
+    return PooledDense(initW(out, hidden_sz*3), initb(out), σ)
 end
 
-Flux.@treelike PooledDense
+Flux.@functor PooledDense
 
 function (a::PooledDense)(x)
     W, b, σ = a.W, a.b, a.σ
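
The recurring pattern in this file — Flux.@functor plus a Flux.trainable overload instead of Flux.@treelike/param, and a testmode! method instead of the removed _testmode! hook — is shown below in isolation on a hypothetical MyDrop layer (not part of this commit), as a sketch against the Zygote-era Flux API:

    using Flux

    # Toy dropout-style layer with one trainable weight and a non-trainable
    # `active` flag, mirroring the WeightDroppedLSTMCell/VarDrop changes above.
    mutable struct MyDrop{A}
        W::A
        p::Float64
        active::Union{Bool, Nothing}
    end

    MyDrop(n::Integer, p::Float64=0.5) = MyDrop(Flux.glorot_uniform(n, n), p, nothing)

    # Plain arrays plus @functor replace the old param(...)/@treelike combination.
    Flux.@functor MyDrop

    # Only W receives gradient updates; p and active are excluded.
    Flux.trainable(m::MyDrop) = (m.W,)

    # Newer Flux dispatches on testmode!/trainmode! rather than _testmode!.
    function Flux.testmode!(m::MyDrop, mode = true)
        m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode
        return m
    end

    function (m::MyDrop)(x)
        m.active === false && return m.W * x       # evaluation: no dropout
        mask = rand(Float32, size(x)) .> m.p       # fresh Bernoulli mask
        return (m.W * x) .* mask ./ (1 - m.p)
    end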

src/ULMFiT/data_loaders.jl

Lines changed: 16 additions & 16 deletions
@@ -27,29 +27,29 @@ function imdb_preprocess(doc::AbstractDocument)
         length(word) == 1 && return [word]
         return split(word, symbol)
     end
-    text = text(doc)
-    remove_corrupt_utf8!(text)
-    remove_case!(text)
-    prepare!(text, strip_html_tags)
-    tokens = tokens(text)
+    text_ = doc
+    remove_corrupt_utf8!(text_)
+    remove_case!(text_)
+    prepare!(text_, strip_html_tags)
+    tokens_ = tokens(text_)
     for symbol in [',', '.', '-', '/', "'s"]
-        tokens = split_word.(tokens, symbol)
+        tokens_ = split_word.(tokens_, symbol)
         temp = []
-        for token in tokens
+        for token_ in tokens_
             try
-                append!(temp, put(token, symbol))
+                append!(temp, put(token_, symbol))
             catch
-                append!(temp, token)
+                append!(temp, token_)
             end
         end
-        tokens = temp
+        tokens_ = temp
     end
-    deleteat!(tokens, findall(x -> isequal(x, "")||isequal(x, " "), tokens))
-    return tokens
+    deleteat!(tokens_, findall(x -> isequal(x, "")||isequal(x, " "), tokens_))
+    return tokens_
 end
 
 # Loads WikiText-103 corpus and output a Channel to give a mini-batch at each call
-function load_wikitext_103(batchsize::Integer, bptt::Integer; type = "train")
+function load_wikitext_103(batchsize::Integer=16, bptt::Integer=70; type = "train")
     corpuspath = joinpath(datadep"WikiText-103", "wiki.$(type).tokens")
     corpus = read(open(corpuspath, "r"), String)
     corpus = tokenize(corpus)
@@ -58,13 +58,13 @@ end
 
 # IMDB Data loaders for Sentiment Analysis specifically
 # IMDB data loader for fine-tuning Language Model
-function imdb_fine_tune_data(batchsize::Integer, bptt::Integer, num_examples::Integer=50000)
+function imdb_fine_tune_data(batchsize::Integer=16, bptt::Integer=70, num_examples::Integer=50000)
     imdb_dataset = IMDB("train_unsup")
     dataset = []
-    for path in imdb_dataset.filepaths   #extract data from the files in directory and put into channel
+    for path in imdb_dataset.filepaths[1:num_examples]   #extract data from the files in directory and put into channel
        open(path) do fileio
            cur_text = read(fileio, String)
-            append!(dataset, imdb_preprocess(cur_text))
+            append!(dataset, imdb_preprocess(StringDocument(cur_text)))
        end #open
    end #for
    return Channel(x -> generator(x, dataset; batchsize=batchsize, bptt=bptt))
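
Both loaders hand back a Channel; following the fine-tuning loop in fine_tune_lm.jl below, the first take! yields the number of iterations and each later take! yields one mini-batch. A minimal consumption sketch (it assumes the WikiText-103 DataDep is already downloaded and that the same Channel protocol applies to this loader):

    using TextModels
    using TextModels.ULMFiT: load_wikitext_103

    gen = load_wikitext_103(16, 70)      # batchsize = 16, bptt = 70
    num_of_iters = take!(gen)            # first take!: number of batches

    for i in 1:num_of_iters
        batch = take!(gen)               # one mini-batch of tokenised WikiText-103
        # ... feed `batch` to the language model here ...
    end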

src/ULMFiT/fine_tune_lm.jl

Lines changed: 11 additions & 15 deletions
@@ -24,17 +24,17 @@ opts : `Vector` of optimizers used to update weights for corresponding layers
 
 NOTE: length(opts) == length(layers)
 """
-function discriminative_step!(layers, ηL::Float64, l, opts::Vector)
+function discriminative_step!(layers, lm::LanguageModel, gen, ηL::Float64, opts::Vector)
     @assert length(opts) == length(layers)
     # Gradient calculation
-    grads = Tracker.gradient(() -> l, get_trainable_params(layers))
+    grads = Zygote.gradient(() -> loss(lm, gen), get_trainable_params(layers))
 
     # discriminative step
     ηl = ηL/(2.6^(length(layers)-1))
     for (layer, opt) in zip(layers, opts)
         opt.eta = ηl
         for ps in get_trainable_params([layer])
-            Tracker.update!(opt, ps, grads[ps])
+            Flux.Optimise.update!(opt, ps, grads[ps])
         end
         ηl *= 2.6
     end
@@ -50,32 +50,28 @@ This function contains main training loops for fine-tuning the language model.
 To use this funciton, an instance of LanguageModel and a data loader is needed.
 Read the docs for more info about arguments
 """
-function fine_tune_lm!(lm::LanguageModel, data_loader::Channel=imdb_fine_tune_data,
-    stlr_cut_frac::Float64=0.1, stlr_ratio::Float32=32, stlr_η_max::Float64=4e-3;
+function fine_tune_lm!(lm=LanguageModel(), data_loader=imdb_fine_tune_data,
+    stlr_cut_frac::Float64=0.1, stlr_ratio::Float32=Float32(32), stlr_η_max::Float64=4e-3;
     epochs::Integer=1, checkpoint_itvl::Integer=5000)
 
     opts = [ADAM(0.001, (0.7, 0.99)) for i=1:4]
-    cut = num_of_iters * epochs * stlr_cut_frac
-
+
     # Fine-Tuning loops
     for epoch=1:epochs
         println("\nEpoch: $epoch")
-        gen = data_loader()
-        num_of_iters = take!(gen)
+        gen = data_loader()
+        num_of_iters = take!(gen)
+        cut = num_of_iters * epochs * stlr_cut_frac
         T = num_of_iters-Int(floor((num_of_iters*2)/100))
         set_trigger!.(T, lm.layers)
         for i=1:num_of_iters
-
-            # FORWARD
-            l = loss(lm, gen)
-
             # Slanted triangular learning rate step
             t = i + (epoch-1)*num_of_iters
             p_frac = (i < cut) ? i/cut : (1 - ((i-cut)/(cut*(1/stlr_cut_frac-1))))
             ηL = stlr_η_max*((1+p_frac*(stlr_ratio-1))/stlr_ratio)
 
             # Backprop with discriminative fine-tuning step
-            discriminative_step!(lm.layers[[1, 3, 5, 7]], ηL, l, opts)
+            discriminative_step!(lm.layers[[1, 3, 5, 7]], lm, gen, ηL, opts)
 
             # Resets dropout masks for all the layers with DropOut or DropConnect
             reset_masks!.(lm.layers)
@@ -121,7 +117,7 @@ julia> insert!(vocab, 2, "_pad_")
 function set_vocab!(lm::LanguageModel, vocab::Vector)
     idxs = indices(vocab, lm.vocab)
     lm.vocab = vocab
-    lm.layers[1].emb = param(Tracker.data(lm.layers[1].emb)[idxs, :])
+    lm.layers[1].emb = param(lm.layers[1].emb[idxs, :])
     lm.layers[1].mask = gpu(drop_mask((length(vocab),), lm.layers[1].p))
     return
 end
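
The slanted triangular learning rate computed inside the loop above can be written out on its own; a minimal sketch of the schedule for a single epoch, using the same formula and the default hyper-parameters from fine_tune_lm!:

    # Slanted triangular learning rate (Howard & Ruder, ULMFiT):
    # linear warm-up until iteration `cut`, then linear decay.
    function stlr(t, num_of_iters; stlr_cut_frac=0.1, stlr_ratio=32f0, stlr_η_max=4e-3)
        cut = num_of_iters * stlr_cut_frac
        p_frac = t < cut ? t/cut : 1 - (t - cut)/(cut*(1/stlr_cut_frac - 1))
        return stlr_η_max * (1 + p_frac*(stlr_ratio - 1)) / stlr_ratio
    end

    # Example: over 1000 iterations the peak 4e-3 is reached 10% of the way in,
    # and the rate decays back to 4e-3/32 by the final iteration.
    ηs = [stlr(t, 1000) for t in 1:1000]

discriminative_step! then assigns this ηL to the last layer in the passed list and divides it by 2.6 for each earlier layer, which is the discriminative fine-tuning part of the update.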

0 commit comments
