Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import org.jetbrains.kotlinx.dl.api.core.layer.activation.PReLU
import org.jetbrains.kotlinx.dl.api.core.layer.activation.LeakyReLU
import org.jetbrains.kotlinx.dl.api.core.layer.activation.Softmax
import org.jetbrains.kotlinx.dl.api.core.layer.activation.ThresholdedReLU
import org.jetbrains.kotlinx.dl.api.core.layer.activation.ReLU
import org.jetbrains.kotlinx.dl.api.core.layer.activation.ELU
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.*
import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
Expand Down Expand Up @@ -88,6 +90,8 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i
is BatchNorm -> createKerasBatchNorm(layer, isKerasFullyCompatible)
is ActivationLayer -> createKerasActivationLayer(layer)
is PReLU -> createKerasPReLULayer(layer, isKerasFullyCompatible)
is ReLU -> createKerasReLU(layer)
is ELU -> createKerasELU(layer)
is LeakyReLU -> createKerasLeakyReLU(layer)
is ThresholdedReLU -> createKerasThresholdedReLULayer(layer)
is Add -> createKerasAddLayer(layer)
Expand Down Expand Up @@ -242,6 +246,24 @@ private fun createKerasSoftmaxLayer(layer: Softmax): KerasLayer {
return KerasLayer(class_name = LAYER_SOFTMAX, config = configX)
}

/**
 * Converts a [ReLU] activation layer into its Keras layer description,
 * carrying over the clamp ceiling, the slope for negative inputs and the
 * activation threshold.
 */
private fun createKerasReLU(layer: ReLU): KerasLayer = KerasLayer(
    class_name = LAYER_RELU,
    config = LayerConfig(
        dtype = DATATYPE_FLOAT32,
        max_value = layer.maxValue?.toDouble(),
        negative_slope = layer.negativeSlope.toDouble(),
        threshold = layer.threshold.toDouble()
    )
)

/**
 * Converts an [ELU] activation layer into its Keras layer description,
 * carrying over the `alpha` scale used for negative inputs.
 */
private fun createKerasELU(layer: ELU): KerasLayer = KerasLayer(
    class_name = LAYER_ELU,
    config = LayerConfig(
        dtype = DATATYPE_FLOAT32,
        alpha = layer.alpha.toDouble()
    )
)

private fun createKerasLeakyReLU(layer: LeakyReLU): KerasLayer {
val configX = LayerConfig(
dtype = DATATYPE_FLOAT32,
Expand Down
112 changes: 112 additions & 0 deletions examples/src/main/kotlin/examples/inference/saveload/SaveLoadElu.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package examples.inference.saveload

import examples.inference.lenet5
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotNormal
import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotUniform
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool2D
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.core.optimizer.SGD
import org.jetbrains.kotlinx.dl.dataset.handler.NUMBER_OF_CLASSES
import org.jetbrains.kotlinx.dl.dataset.mnist
import java.io.File


// Directory the trained model configuration and weights are written to.
private const val MODEL_SAVE_PATH = "savedmodels/elu_lenet_saveload"

// Input geometry: 28x28 single-channel images (MNIST).
private const val NUM_CHANNELS = 1L
private const val IMAGE_SIZE = 28L
// Fixed seed so weight initialization is reproducible across runs.
private const val SEED = 12L

// Shared initializers for all trainable layers of the model below.
private val kernelInitializer = GlorotNormal(SEED)
private val biasInitializer = GlorotUniform(SEED)

/**
 * A [lenet5]-like architecture whose convolutional layers use ELU activations
 * instead of ReLU, used to exercise save/load support for the ELU activation.
 */
private fun modifiedLenet5(): Sequential = Sequential.of(
    Input(
        IMAGE_SIZE,
        IMAGE_SIZE,
        NUM_CHANNELS,
        name = "input_0"
    ),
    Conv2D(
        name = "conv2d_1",
        filters = 32,
        kernelSize = longArrayOf(5, 5),
        strides = longArrayOf(1, 1, 1, 1),
        padding = ConvPadding.SAME,
        activation = Activations.Elu,
        kernelInitializer = kernelInitializer,
        biasInitializer = biasInitializer
    ),
    MaxPool2D(
        name = "maxPool_2",
        poolSize = intArrayOf(1, 2, 2, 1),
        strides = intArrayOf(1, 2, 2, 1)
    ),
    Conv2D(
        name = "conv2d_3",
        filters = 64,
        kernelSize = longArrayOf(5, 5),
        strides = longArrayOf(1, 1, 1, 1),
        padding = ConvPadding.SAME,
        activation = Activations.Elu,
        kernelInitializer = kernelInitializer,
        biasInitializer = biasInitializer
    ),
    MaxPool2D(
        name = "maxPool_4",
        poolSize = intArrayOf(1, 2, 2, 1),
        strides = intArrayOf(1, 2, 2, 1)
    ),
    // 28x28 input -> two 2x2 poolings -> 7x7x64 = 3136 features
    Flatten(name = "flatten_5"),
    Dense(
        name = "dense_6",
        outputSize = 120,
        activation = Activations.Relu,
        kernelInitializer = kernelInitializer,
        biasInitializer = biasInitializer
    ),
    Dense(
        name = "dense_7",
        outputSize = 84,
        activation = Activations.Relu,
        kernelInitializer = kernelInitializer,
        biasInitializer = biasInitializer
    ),
    Dense(
        name = "dense_8",
        outputSize = NUMBER_OF_CLASSES,
        activation = Activations.Linear,
        kernelInitializer = kernelInitializer,
        biasInitializer = biasInitializer
    )
)

/**
 * Trains the ELU LeNet variant on [mnist] up to a target accuracy, saves it to
 * disk, then reloads the saved model and evaluates it to verify the round trip.
 */
fun eluLenetOnMnistWithIntermediateSave() {
    val (train, test) = mnist()
    SaveTrainedModelHelper().trainAndSave(
        train = train,
        test = test,
        model = modifiedLenet5(),
        path = MODEL_SAVE_PATH,
        accuracyThreshold = 0.7
    )
    val savedModelDirectory = File(MODEL_SAVE_PATH)
    Sequential.loadDefaultModelConfiguration(savedModelDirectory).use { restored ->
        restored.compile(
            optimizer = SGD(learningRate = 0.3f),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )
        restored.loadWeights(savedModelDirectory)
        val accuracy = restored.evaluate(test).metrics[Metrics.ACCURACY] ?: 0.0
        println("Accuracy is : $accuracy")
    }
}

fun main(): Unit = eluLenetOnMnistWithIntermediateSave()
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package examples.inference.saveload

import examples.inference.lenet5
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.core.optimizer.SGD
import org.jetbrains.kotlinx.dl.dataset.mnist
import java.io.File


private const val MODEL_SAVE_PATH = "savedmodels/relu_lenet_saveload"

/**
 * Trains [lenet5] on [mnist] up to a target accuracy, saves it to disk, then
 * reloads the saved model and evaluates it to verify the round trip.
 */
fun reluLenetOnMnistWithIntermediateSave() {
    val (train, test) = mnist()
    SaveTrainedModelHelper().trainAndSave(
        train = train,
        test = test,
        model = lenet5(),
        path = MODEL_SAVE_PATH,
        accuracyThreshold = 0.7
    )
    val savedModelDirectory = File(MODEL_SAVE_PATH)
    Sequential.loadDefaultModelConfiguration(savedModelDirectory).use { restored ->
        restored.compile(
            optimizer = SGD(learningRate = 0.3f),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )
        restored.loadWeights(savedModelDirectory)
        val accuracy = restored.evaluate(test).metrics[Metrics.ACCURACY] ?: 0.0
        println("Accuracy is : $accuracy")
    }
}

fun main(): Unit = reluLenetOnMnistWithIntermediateSave()
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package examples.inference.saveload

import org.jetbrains.kotlinx.dl.api.core.SavingFormat
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.WritingMode
import org.jetbrains.kotlinx.dl.api.core.callback.Callback
import org.jetbrains.kotlinx.dl.api.core.history.BatchTrainingEvent
import org.jetbrains.kotlinx.dl.api.core.history.TrainingHistory
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.core.optimizer.SGD
import org.jetbrains.kotlinx.dl.dataset.Dataset
import java.io.File

/**
 * Wraps the logic of training a given model until it reaches a target accuracy
 * on a test dataset and then persisting it to a directory on disk.
 *
 * @property trainBatchSize batch size used while fitting.
 * @property testBatchSize batch size used while evaluating.
 */
class SaveTrainedModelHelper(val trainBatchSize: Int = 500, val testBatchSize: Int = 1000) {

    /**
     * Trains [model] on [train], evaluating on [test] after every epoch, until
     * [accuracyThreshold] is reached, then saves the model to the folder [path].
     * The model is closed when this method returns.
     */
    fun trainAndSave(train: Dataset, test: Dataset, model: Sequential, path: String, accuracyThreshold: Double = 0.7) {
        model.use { net ->
            net.name = "lenet-accuracy$accuracyThreshold"
            // Stops fitting mid-epoch as soon as a batch reports a metric value
            // above the threshold, instead of finishing the whole epoch.
            val earlyStop = object : Callback() {
                override fun onTrainBatchEnd(
                    batch: Int,
                    batchSize: Int,
                    event: BatchTrainingEvent,
                    logs: TrainingHistory
                ) {
                    if (event.metricValue > accuracyThreshold) {
                        println("Stopping training at ${event.metricValue} accuracy")
                        net.stopTraining = true
                    }
                }
            }
            net.compile(
                optimizer = SGD(learningRate = 0.3f),
                loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
                metric = Metrics.ACCURACY,
                callback = earlyStop
            )
            net.init()
            // Keep training one epoch at a time until the test-set accuracy
            // crosses the threshold; the callback above may cut an epoch short.
            var testAccuracy = 0.0
            while (testAccuracy < accuracyThreshold) {
                net.fit(dataset = train, epochs = 1, batchSize = trainBatchSize)
                testAccuracy = net.evaluate(dataset = test, batchSize = testBatchSize)
                    .metrics[Metrics.ACCURACY] ?: 0.0
                println("Accuracy: $testAccuracy")
            }
            net.save(
                modelDirectory = File(path),
                savingFormat = SavingFormat.JSON_CONFIG_CUSTOM_VARIABLES,
                writingMode = WritingMode.OVERRIDE
            )
        }
    }
}