
Fix layers import and export #360

Merged: 17 commits, Apr 25, 2022
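For context, the round trip this PR fixes is KotlinDL's Keras-compatible JSON (de)serialization of layer configurations. A minimal sketch of that round trip in Kotlin, assuming the library's save/load helpers (the exact entry points, imports, and file path here are assumptions for illustration, not part of this diff):

import java.io.File
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Reshape

fun main() {
    // A model containing one of the layers whose export this PR adds.
    val model = Sequential.of(
        Input(100),
        Reshape(targetShape = listOf(10, 10), name = "reshape")
    )

    // Export the layer configuration to Keras-style JSON...
    val json = File("model.json")
    model.saveModelConfiguration(json) // assumed API name

    // ...and read it back. Before this PR, Reshape had no export branch,
    // and PReLU/Dot configs could fail on absent optional fields.
    val restored = Sequential.loadModelConfiguration(json)
    restored.layers.forEach { println(it.name) }
}
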
=== PReLU ===
@@ -41,7 +41,7 @@ public class PReLU(
* TODO: support for constraint (alphaConstraint) should be added
*/

- private lateinit var alpha: KVariable
+ internal lateinit var alpha: KVariable
private fun alphaVariableName(): String =
if (name.isNotEmpty()) "${name}_alpha" else "alpha"

=== Cropping2D ===
@@ -1,5 +1,5 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

@@ -67,6 +67,6 @@ public class Cropping2D(
}

override fun toString(): String {
- return "Cropping2D(name = $name, isTrainable=$isTrainable, cropping=${cropping.contentToString()}, hasActivation=$hasActivation)"
+ return "Cropping2D(name = $name, isTrainable=$isTrainable, cropping=${cropping.contentDeepToString()}, hasActivation = $hasActivation)"
}
}
=== Cropping3D ===
@@ -1,5 +1,5 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

@@ -70,6 +70,6 @@ public class Cropping3D(
}

override fun toString(): String {
- return "Cropping3D(name = $name, isTrainable=$isTrainable, cropping=${cropping.contentToString()}, hasActivation=$hasActivation)"
+ return "Cropping3D(name = $name, isTrainable=$isTrainable, cropping=${cropping.contentDeepToString()}, hasActivation=$hasActivation)"
}
}
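
Both toString fixes above matter because cropping is a nested array: contentToString() on an Array<IntArray> prints the inner arrays' object references, while contentDeepToString() recurses into their elements. A quick illustration with hypothetical cropping values:

fun main() {
    val cropping = arrayOf(intArrayOf(1, 1), intArrayOf(2, 2))
    println(cropping.contentToString())     // [[I@1b6d3586, [I@4554617c] (references)
    println(cropping.contentDeepToString()) // [[1, 1], [2, 2]]
}
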
=== Keras layer import ===
@@ -411,7 +411,8 @@ private fun createConcatenateLayer(config: LayerConfig): Layer {

private fun createDotLayer(config: LayerConfig): Layer {
return Dot(
- axis = config.axis!! as IntArray
+ axis = config.axis!! as IntArray,
+ normalize = config.normalize ?: false
)
}

@@ -445,7 +446,7 @@ private fun createPReLULayer(config: LayerConfig): Layer {
return PReLU(
alphaInitializer = convertToInitializer(config.alpha_initializer!!),
alphaRegularizer = convertToRegularizer(config.alpha_regularizer),
- sharedAxes = config.shared_axes!!.toIntArray()
+ sharedAxes = config.shared_axes?.toIntArray()
)
}

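Both loader fixes above (createDotLayer and createPReLULayer) address the same failure mode: optional fields that Keras omits from the JSON config arrive as null and must not be force-unwrapped. A small self-contained sketch of the difference (this Config class is a simplified stand-in for the parsed LayerConfig):

data class Config(val shared_axes: List<Int>? = null, val normalize: Boolean? = null)

fun main() {
    val config = Config() // neither field was present in the JSON

    // config.shared_axes!!.toIntArray() would throw a NullPointerException here;
    // the safe call propagates null instead, and Elvis supplies a default.
    val sharedAxes: IntArray? = config.shared_axes?.toIntArray()
    val normalize: Boolean = config.normalize ?: false

    println("sharedAxes=$sharedAxes, normalize=$normalize") // sharedAxes=null, normalize=false
}
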
=== Keras layer export ===
@@ -97,6 +97,7 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i
is UpSampling1D -> createKerasUpSampling1DLayer(layer)
is UpSampling2D -> createKerasUpSampling2DLayer(layer)
is UpSampling3D -> createKerasUpSampling3DLayer(layer)
+ is Reshape -> createKerasReshapeLayer(layer)
// Merging layers
is Add -> createKerasAddLayer(layer)
is Maximum -> createKerasMaximumLayer(layer)
@@ -432,6 +433,7 @@ private fun createKerasReLULayer(layer: ReLU): KerasLayer {
max_value = layer.maxValue?.toDouble(),
negative_slope = layer.negativeSlope.toDouble(),
threshold = layer.threshold.toDouble(),
+ name = layer.name,
trainable = layer.isTrainable
)
return KerasLayer(class_name = LAYER_RELU, config = configX)
@@ -441,6 +443,7 @@ private fun createKerasELULayer(layer: ELU): KerasLayer {
val configX = LayerConfig(
dtype = DATATYPE_FLOAT32,
alpha = layer.alpha.toDouble(),
+ name = layer.name,
trainable = layer.isTrainable
)
return KerasLayer(class_name = LAYER_ELU, config = configX)
@@ -571,8 +574,8 @@ private fun createKerasAvgPool1DLayer(layer: AvgPool1D): KerasLayer {
}

private fun createKerasMaxPool3DLayer(layer: MaxPool3D): KerasLayer {
- val poolSize = mutableListOf(layer.poolSize[1], layer.poolSize[3])
- val strides = mutableListOf(layer.strides[1], layer.strides[3])
+ val poolSize = mutableListOf(layer.poolSize[1], layer.poolSize[2], layer.poolSize[3])
+ val strides = mutableListOf(layer.strides[1], layer.strides[2], layer.strides[3])
val configX = LayerConfig(
dtype = DATATYPE_FLOAT32,
name = layer.name,
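
A worked example of the createKerasMaxPool3DLayer fix above: pool sizes and strides include the batch and channel dimensions, so a 2x2x2 pool is intArrayOf(1, 2, 2, 2, 1). The old code exported only indices 1 and 3, silently dropping the middle spatial dimension:

// poolSize = intArrayOf(1, 2, 2, 2, 1), a 2x2x2 pool
// old: mutableListOf(poolSize[1], poolSize[3])              // -> [2, 2]    (rank lost)
// new: mutableListOf(poolSize[1], poolSize[2], poolSize[3]) // -> [2, 2, 2] (correct)
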
@@ -657,7 +660,8 @@ private fun createKerasDotLayer(layer: Dot): KerasLayer {
dtype = DATATYPE_FLOAT32,
axis = layer.axis,
name = layer.name,
- trainable = layer.isTrainable
+ trainable = layer.isTrainable,
+ normalize = layer.normalize
)
return KerasLayer(class_name = LAYER_DOT, config = configX)
}
@@ -922,3 +926,12 @@ private fun createKerasUpSampling3DLayer(layer: UpSampling3D): KerasLayer {
)
return KerasLayer(class_name = LAYER_UP_SAMPLING_3D, config = configX)
}

+ private fun createKerasReshapeLayer(layer: Reshape): KerasLayer {
+ val configX = LayerConfig(
+ target_shape = layer.targetShape,
+ name = layer.name,
+ trainable = layer.isTrainable,
+ )
+ return KerasLayer(class_name = LAYER_RESHAPE, config = configX)
+ }
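
With the new is Reshape -> branch in convertToKerasLayer and this export function, a Reshape layer can survive the save/load cycle. A hypothetical test in the style of the ones this PR adds (not part of the diff; the Reshape constructor shape is assumed):

@Test
fun reshape() {
    LayerImportExportTest.run(
        Sequential.of(
            Input(100),
            Reshape(name = "test_reshape", targetShape = listOf(10, 10))
        )
    )
}
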
=== WeightMappings ===
@@ -7,7 +7,8 @@ package org.jetbrains.kotlinx.dl.api.inference.keras

import org.jetbrains.kotlinx.dl.api.core.layer.KVariable
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
- import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D
+ import org.jetbrains.kotlinx.dl.api.core.layer.activation.PReLU
+ import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.AbstractConv
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvTranspose
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.DepthwiseConv2D
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.SeparableConv2D
@@ -26,16 +27,18 @@ internal object WeightMappings {
private const val DEPTHWISE_KERNEL_DATA_PATH_TEMPLATE = "/%s/%s/depthwise_kernel:0"
private const val POINTWISE_KERNEL_DATA_PATH_TEMPLATE = "/%s/%s/pointwise_kernel:0"
private const val DEPTHWISE_BIAS_DATA_PATH_TEMPLATE = "/%s/%s/depthwise_bias:0"
+ private const val PRELU_ALPHA_DATA_PATH_TEMPLATE = "/%s/%s/alpha:0"

// TODO: add loading for all layers with weights from Keras like Conv1D and Conv3D
internal fun getLayerVariables(layer: Layer): Map<String, KVariable>? {
return when (layer) {
is Dense -> getDenseVariables(layer)
- is Conv2D -> getConv2DVariables(layer)
is ConvTranspose -> getConvTransposeVariables(layer)
is DepthwiseConv2D -> getDepthwiseConv2DVariables(layer)
is SeparableConv2D -> getSeparableConv2DVariables(layer)
+ is AbstractConv -> getConvVariables(layer)
is BatchNorm -> getBatchNormVariables(layer)
+ is PReLU -> getPReLUVariables(layer)
else -> null
}
}
@@ -47,20 +50,21 @@ internal fun getLayerVariablePathTemplates(layer: Layer, layerPaths: LayerPaths?): Map<KVariable, String>? {
internal fun getLayerVariablePathTemplates(layer: Layer, layerPaths: LayerPaths?): Map<KVariable, String>? {
return when (layer) {
is Dense -> getDenseVariablesPathTemplates(layer, layerPaths)
- is Conv2D -> getConv2DVariablePathTemplates(layer, layerPaths)
is ConvTranspose -> getConvTransposeVariablePathTemplates(layer, layerPaths)
is DepthwiseConv2D -> getDepthwiseConv2DVariablePathTemplates(layer, layerPaths)
is SeparableConv2D -> getSeparableConv2DVariablePathTemplates(layer, layerPaths)
+ is AbstractConv -> getConvVariablePathTemplates(layer, layerPaths)
is BatchNorm -> getBatchNormVariablePathTemplates(layer, layerPaths)
+ is PReLU -> getPReLUVariablePathTemplates(layer, layerPaths)
else -> null
}
}

- private fun getConv2DVariables(layer: Conv2D): Map<String, KVariable> {
+ private fun getConvVariables(layer: AbstractConv): Map<String, KVariable> {
return mapOfNotNull("kernel:0" to layer.kernel, "bias:0" to layer.bias)
}

- private fun getConv2DVariablePathTemplates(layer: Conv2D, layerPaths: LayerPaths?): Map<KVariable, String> {
+ private fun getConvVariablePathTemplates(layer: AbstractConv, layerPaths: LayerPaths?): Map<KVariable, String> {
val layerConvOrDensePaths = layerPaths as? LayerConvOrDensePaths
?: LayerConvOrDensePaths(layer.name, KERNEL_DATA_PATH_TEMPLATE, BIAS_DATA_PATH_TEMPLATE)
return mapOfNotNull(
@@ -167,6 +171,16 @@ internal object WeightMappings {
layer.beta to layerBatchNormPaths.betaPath
)
}

+ private fun getPReLUVariables(layer: PReLU): Map<String, KVariable> {
+ return mapOfNotNull("alpha:0" to layer.alpha)
+ }
+
+ private fun getPReLUVariablePathTemplates(layer: PReLU, layerPaths: LayerPaths?): Map<KVariable, String> {
+ val layerPReLUPaths = layerPaths as? LayerPReLUPaths
+ ?: LayerPReLUPaths(layer.name, PRELU_ALPHA_DATA_PATH_TEMPLATE)
+ return mapOfNotNull(layer.alpha to layerPReLUPaths.alphaPath)
+ }
}

/**
@@ -226,3 +240,12 @@ public class LayerBatchNormPaths(
/** */
public val movingVariancePath: String
) : LayerPaths(layerName)

+ /**
+ * Contains [layerName], [alphaPath] for [PReLU] layer, found in hdf5 file via
+ * ```
+ * recursivePrintGroupInHDF5File()
+ * ```
+ * function call.
+ */
+ public class LayerPReLUPaths(layerName: String, public val alphaPath: String) : LayerPaths(layerName)
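
The new LayerPReLUPaths class follows the existing LayerConvOrDensePaths/LayerBatchNormPaths pattern: it lets a caller override where a layer's weights live in the HDF5 file. A sketch of supplying a custom alpha location (the commented loading call is an assumed entry point in the weight-loading API, shown for illustration only):

// Hypothetical override: alpha tensors stored under a non-default group.
val paths = listOf(
    LayerPReLUPaths(layerName = "test_prelu", alphaPath = "/custom_group/%s/%s/alpha:0")
)
// model.loadWeightsByPaths(hdfFile, paths) // assumed weight-loader API
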
=== LayerConfig ===
@@ -157,5 +157,7 @@ internal data class LayerConfig(
@Json(serializeNull = false)
val use_bias: Boolean? = null,
@Json(serializeNull = false)
- val dims: IntArray? = null
+ val dims: IntArray? = null,
+ @Json(serializeNull = false)
+ val normalize: Boolean? = null
)
=== ActivationLayersImportExportTest (new file) ===
@@ -0,0 +1,92 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.inference.keras

import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal
import org.jetbrains.kotlinx.dl.api.core.layer.activation.*
import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.regularizer.L2
import org.junit.jupiter.api.Test

class ActivationLayersImportExportTest {
@Test
fun activationLayer() {
LayerImportExportTest.run(
Sequential.of(
Input(100),
ActivationLayer(name = "test_activation_layer", activation = Activations.Exponential)
)
)
}

@Test
fun elu() {
LayerImportExportTest.run(
Sequential.of(
Input(100),
ELU(name = "test_elu", alpha = 0.5f)
)
)
}

@Test
fun leakyRelu() {
LayerImportExportTest.run(
Sequential.of(
Input(100),
LeakyReLU(name = "test_leaky_relu", alpha = 0.5f)
)
)
}

@Test
fun prelu() {
LayerImportExportTest.run(
Sequential.of(
Input(100),
PReLU(
name = "test_prelu",
alphaInitializer = HeNormal(),
alphaRegularizer = L2(),
sharedAxes = intArrayOf(1)
)
)
)
}

@Test
fun relu() {
LayerImportExportTest.run(
Sequential.of(
Input(100),
ReLU(name = "test_relu", maxValue = 2.0f, threshold = 0.1f, negativeSlope = 2.0f)
)
)
}

@Test
fun softmax() {
LayerImportExportTest.run(
Sequential.of(
Input(100),
Softmax(name = "test_softmax", axis = listOf(1))
)
)
}

@Test
fun thresholdedRelu() {
LayerImportExportTest.run(
Sequential.of(
Input(100),
ThresholdedReLU(name = "test_thresholded_relu", theta = 2f)
)
)
}
}
=== ConvolutionalLayersImportExportTest (renamed from ConvolutionalLayersPersistenceTest) ===
@@ -13,10 +13,10 @@ import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.regularizer.L2
import org.junit.jupiter.api.Test

- class ConvolutionalLayersPersistenceTest {
+ class ConvolutionalLayersImportExportTest {
@Test
fun conv1D() {
- LayerPersistenceTest.run(
+ LayerImportExportTest.run(
Sequential.of(
Input(dims = longArrayOf(256, 1)),
Conv1D(
@@ -40,7 +40,7 @@ class ConvolutionalLayersPersistenceTest {

@Test
fun conv2D() {
- LayerPersistenceTest.run(
+ LayerImportExportTest.run(
Sequential.of(
Input(dims = longArrayOf(256, 256, 3)),
Conv2D(
@@ -64,7 +64,7 @@ class ConvolutionalLayersPersistenceTest {

@Test
fun conv3D() {
- LayerPersistenceTest.run(
+ LayerImportExportTest.run(
Sequential.of(
Input(dims = longArrayOf(10, 256, 256, 3)),
Conv3D(
@@ -88,7 +88,7 @@ class ConvolutionalLayersPersistenceTest {

@Test
fun conv1DTranspose() {
- LayerPersistenceTest.run(
+ LayerImportExportTest.run(
Sequential.of(
Input(dims = longArrayOf(3)),
Conv1DTranspose(
@@ -113,7 +113,7 @@ class ConvolutionalLayersPersistenceTest {

@Test
fun conv2DTranspose() {
- LayerPersistenceTest.run(
+ LayerImportExportTest.run(
Sequential.of(
Input(dims = longArrayOf(3, 3)),
Conv2DTranspose(
@@ -138,7 +138,7 @@ class ConvolutionalLayersPersistenceTest {

@Test
fun conv3DTranspose() {
- LayerPersistenceTest.run(
+ LayerImportExportTest.run(
Sequential.of(
Input(dims = longArrayOf(3, 3, 3)),
Conv3DTranspose(
@@ -161,8 +161,8 @@ class ConvolutionalLayersPersistenceTest {
}

@Test
- fun separableConvTest() {
- LayerPersistenceTest.run(
+ fun separableConv() {
+ LayerImportExportTest.run(
Sequential.of(
Input(dims = longArrayOf(30, 30, 3)),
SeparableConv2D(
@@ -188,8 +188,8 @@ class ConvolutionalLayersPersistenceTest {
}

@Test
- fun depthwiseConvTest() {
- LayerPersistenceTest.run(
+ fun depthwiseConv() {
+ LayerImportExportTest.run(
Sequential.of(
Input(dims = longArrayOf(30, 30, 3)),
DepthwiseConv2D(