Skip to content

Commit 19750dd

Browse files
committed
Add loaders for ELU, RELU activation layers (#78)
- Missing activation layer loaders added; two distinct examples with save/load added to the examples folder (trying to reach 0.7 accuracy).
1 parent 218fb57 commit 19750dd

File tree

6 files changed

+131
-4
lines changed

6 files changed

+131
-4
lines changed

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ private fun createActivationLayer(config: LayerConfig, name: String): Layer {
398398

399399
private fun createReLULayer(config: LayerConfig, name: String): Layer {
400400
return ReLU(
401-
maxValue = config.max_value!!.toFloat(),
401+
maxValue = config.max_value?.toFloat(),
402402
negativeSlope = config.negative_slope!!.toFloat(),
403403
threshold = config.threshold!!.toFloat(),
404404
name = name

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import org.jetbrains.kotlinx.dl.api.core.layer.activation.PReLU
1616
import org.jetbrains.kotlinx.dl.api.core.layer.activation.LeakyReLU
1717
import org.jetbrains.kotlinx.dl.api.core.layer.activation.Softmax
1818
import org.jetbrains.kotlinx.dl.api.core.layer.activation.ThresholdedReLU
19+
import org.jetbrains.kotlinx.dl.api.core.layer.activation.ReLU
20+
import org.jetbrains.kotlinx.dl.api.core.layer.activation.ELU
1921
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.*
2022
import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer
2123
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
@@ -87,6 +89,8 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i
8789
is BatchNorm -> createKerasBatchNorm(layer, isKerasFullyCompatible)
8890
is ActivationLayer -> createKerasActivationLayer(layer)
8991
is PReLU -> createKerasPReLULayer(layer, isKerasFullyCompatible)
92+
is ReLU -> createKerasReLU(layer)
93+
is ELU -> createKerasELU(layer)
9094
is LeakyReLU -> createKerasLeakyReLU(layer)
9195
is ThresholdedReLU -> createKerasThresholdedReLULayer(layer)
9296
is Add -> createKerasAddLayer(layer)
@@ -241,6 +245,24 @@ private fun createKerasSoftmaxLayer(layer: Softmax): KerasLayer {
241245
return KerasLayer(class_name = LAYER_SOFTMAX, config = configX)
242246
}
243247

248+
/**
 * Converts a [ReLU] activation layer into its Keras layer description.
 *
 * [ReLU.maxValue] is optional and is serialized as `null` when absent;
 * the slope and threshold are always present and widened to Double.
 */
private fun createKerasReLU(layer: ReLU): KerasLayer = KerasLayer(
    class_name = LAYER_RELU,
    config = LayerConfig(
        dtype = DATATYPE_FLOAT32,
        max_value = layer.maxValue?.toDouble(),
        negative_slope = layer.negativeSlope.toDouble(),
        threshold = layer.threshold.toDouble()
    )
)
257+
258+
/**
 * Converts an [ELU] activation layer into its Keras layer description.
 *
 * Only the `alpha` coefficient is carried over, widened to Double.
 */
private fun createKerasELU(layer: ELU): KerasLayer = KerasLayer(
    class_name = LAYER_ELU,
    config = LayerConfig(
        dtype = DATATYPE_FLOAT32,
        alpha = layer.alpha.toDouble()
    )
)
265+
244266
private fun createKerasLeakyReLU(layer: LeakyReLU): KerasLayer {
245267
val configX = LayerConfig(
246268
dtype = DATATYPE_FLOAT32,

examples/src/main/kotlin/examples/inference/LeNetModel.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ private val biasInitializer = GlorotUniform(SEED)
2727
/**
2828
* Returns the classic LeNet-5 model with minor improvements (Sigmoid activation -> ReLU activation, AvgPool layer -> MaxPool layer).
2929
*/
30-
fun lenet5(): Sequential = Sequential.of(
30+
fun lenet5(sigmoidActivations:Activations = Activations.Relu): Sequential = Sequential.of(
3131
Input(
3232
IMAGE_SIZE,
3333
IMAGE_SIZE,
@@ -38,7 +38,7 @@ fun lenet5(): Sequential = Sequential.of(
3838
filters = 32,
3939
kernelSize = longArrayOf(5, 5),
4040
strides = longArrayOf(1, 1, 1, 1),
41-
activation = Activations.Relu,
41+
activation = sigmoidActivations,
4242
kernelInitializer = kernelInitializer,
4343
biasInitializer = biasInitializer,
4444
padding = ConvPadding.SAME,
@@ -53,7 +53,7 @@ fun lenet5(): Sequential = Sequential.of(
5353
filters = 64,
5454
kernelSize = longArrayOf(5, 5),
5555
strides = longArrayOf(1, 1, 1, 1),
56-
activation = Activations.Relu,
56+
activation = sigmoidActivations,
5757
kernelInitializer = kernelInitializer,
5858
biasInitializer = biasInitializer,
5959
padding = ConvPadding.SAME,
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package examples.inference.saveload
2+
3+
import examples.inference.lenet5
4+
import org.jetbrains.kotlinx.dl.api.core.Sequential
5+
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
6+
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
7+
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
8+
import org.jetbrains.kotlinx.dl.api.core.optimizer.SGD
9+
import org.jetbrains.kotlinx.dl.dataset.mnist
10+
import java.io.File
11+
12+
13+
// Directory the trained ELU LeNet model is written to and re-read from.
private const val PATH_TO_MODEL = "savedmodels/elu_lenet_saveload"

/**
 * This example demonstrates saving a LeNet-5 model with ELU activations after
 * training on the [mnist] dataset, then loading it back and evaluating it.
 */
fun eluLenetOnMnistWithIntermediateSave() {
    val (train, test) = mnist()
    SaveLoadExample.trainAndSave(train, test, lenet5(Activations.Elu), PATH_TO_MODEL, 0.7)
    Sequential.loadDefaultModelConfiguration(File(PATH_TO_MODEL)).use { model ->
        model.compile(
            optimizer = SGD(learningRate = 0.3f),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )
        model.loadWeights(File(PATH_TO_MODEL))
        val accuracy = model.evaluate(test).metrics[Metrics.ACCURACY] ?: 0.0
        println("Accuracy is : $accuracy")
    }
}

fun main(): Unit = eluLenetOnMnistWithIntermediateSave()
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package examples.inference.saveload
2+
3+
import org.jetbrains.kotlinx.dl.api.core.SavingFormat
4+
import org.jetbrains.kotlinx.dl.api.core.Sequential
5+
import org.jetbrains.kotlinx.dl.api.core.WritingMode
6+
import org.jetbrains.kotlinx.dl.api.core.callback.Callback
7+
import org.jetbrains.kotlinx.dl.api.core.history.BatchTrainingEvent
8+
import org.jetbrains.kotlinx.dl.api.core.history.TrainingHistory
9+
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
10+
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
11+
import org.jetbrains.kotlinx.dl.api.core.optimizer.SGD
12+
import org.jetbrains.kotlinx.dl.dataset.Dataset
13+
import java.io.File
14+
15+
/**
 * Shared helper for the save/load examples: trains a model until it reaches a
 * target test accuracy, then saves it to disk.
 */
object SaveLoadExample {

    // Batch sizes for evaluation and training respectively.
    private const val TEST_BATCH_SIZE = 1000
    private const val TRAINING_BATCH_SIZE = 500

    /**
     * Trains [model] on [train] one epoch at a time until its accuracy on
     * [test] reaches [accuracyThreshold], then writes the model (JSON config
     * plus variables) to [path], overwriting any previous save.
     *
     * NOTE(review): if the model never reaches the threshold this loop does
     * not terminate — confirm that is acceptable for the examples.
     */
    fun trainAndSave(train: Dataset, test: Dataset, model: Sequential, path: String, accuracyThreshold: Double = 0.7) {
        model.use { network ->
            network.name = "lenet-accuracy85"
            // Stops training mid-epoch once a single training batch clears the
            // threshold with a 0.1 margin, to avoid wasting the rest of the epoch.
            val earlyStop = object : Callback() {
                override fun onTrainBatchEnd(batch: Int, batchSize: Int, event: BatchTrainingEvent, logs: TrainingHistory) {
                    if (event.metricValue > accuracyThreshold + 0.1) {
                        println("Stopping training at ${event.metricValue} accuracy")
                        network.stopTraining = true
                    }
                }
            }
            network.compile(
                optimizer = SGD(learningRate = 0.3f),
                loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
                metric = Metrics.ACCURACY,
                callback = earlyStop
            )
            network.init()
            var accuracy = 0.0
            // Keep running single epochs until the test-set accuracy target is met.
            while (accuracy < accuracyThreshold) {
                network.fit(dataset = train, epochs = 1, batchSize = TRAINING_BATCH_SIZE)
                accuracy = network.evaluate(dataset = test, batchSize = TEST_BATCH_SIZE).metrics[Metrics.ACCURACY] ?: 0.0
                println("Accuracy: $accuracy")
            }
            network.save(
                modelDirectory = File(path),
                savingFormat = SavingFormat.JSON_CONFIG_CUSTOM_VARIABLES,
                writingMode = WritingMode.OVERRIDE
            )
        }
    }
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package examples.inference.saveload
2+
3+
import examples.inference.lenet5
4+
import org.jetbrains.kotlinx.dl.api.core.Sequential
5+
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
6+
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
7+
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
8+
import org.jetbrains.kotlinx.dl.api.core.optimizer.SGD
9+
import org.jetbrains.kotlinx.dl.dataset.mnist
10+
import java.io.File
11+
12+
13+
// Directory the trained ReLU LeNet model is written to and re-read from.
private const val PATH_TO_MODEL = "savedmodels/relu_lenet_saveload"

/**
 * This example demonstrates saving a LeNet-5 model with ReLU activations after
 * training on the [mnist] dataset, then loading it back and evaluating it.
 */
fun reluLenetOnMnistWithIntermediateSave() {
    val (train, test) = mnist()
    SaveLoadExample.trainAndSave(train, test, lenet5(Activations.Relu), PATH_TO_MODEL, 0.7)
    Sequential.loadDefaultModelConfiguration(File(PATH_TO_MODEL)).use { model ->
        model.compile(
            optimizer = SGD(learningRate = 0.3f),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )
        model.loadWeights(File(PATH_TO_MODEL))
        val accuracy = model.evaluate(test).metrics[Metrics.ACCURACY] ?: 0.0
        println("Accuracy is : $accuracy")
    }
}

fun main(): Unit = reluLenetOnMnistWithIntermediateSave()

0 commit comments

Comments
 (0)