Skip to content

Commit 34ae33f

Browse files
authored
Add RepeatVector layer (#139)
* Added missing saving functions for ReLU and ELU activation layers (JetBrains#78) * Reverted changes to the imports * Added RepeatVector layer #123 * Added serialisation support for RepeatVector layer #123 * Wrote test for RepeatVector #123 * Made changed requested by avan (see desc.) - added missing require check in init block of RepeatVector - updated docs - reformatted code - housekeeping * Removed redundant Obs.repeat ext fun * Made changed requested by avan (see desc.) - change require message in computeOutputShape - used inputShape.size(...) for creating shape - removed author tag * Used `=` instead of `return` block, added TODO * Implemented changes requested by zaleslaw - save trainability status - renamed tests * Added test for negative `n` #123 * Added missing newline
1 parent f7bebf9 commit 34ae33f

File tree

6 files changed

+146
-0
lines changed

6 files changed

+146
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
3+
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
4+
*/
5+
6+
package org.jetbrains.kotlinx.dl.api.core.layer.reshaping
7+
8+
import org.jetbrains.kotlinx.dl.api.core.KGraph
9+
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
10+
import org.tensorflow.Operand
11+
import org.tensorflow.Shape
12+
import org.tensorflow.op.Ops
13+
14+
/**
15+
* Layer that repeats the input [n] times.
16+
*
17+
* Input shape: `2D tensor of shape (num_samples, features)`.
18+
*
19+
* Output shape: `3D tensor of shape (num_samples, n, features)`.
20+
*
21+
* @property n Repetition factor.
22+
* @property [name] Custom layer name.
23+
* @constructor Creates [RepeatVector] object.
24+
*
25+
* @since 0.3
26+
*/
27+
public class RepeatVector(
28+
public val n: Int,
29+
name: String = ""
30+
) : Layer(name) {
31+
32+
init {
33+
require(n >= 1) { "Number of repetitions (n) in RepeatVector should be positive but got $n" }
34+
}
35+
36+
override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit
37+
38+
override fun computeOutputShape(inputShape: Shape): Shape {
39+
require(inputShape.numDimensions() == 2) {
40+
"Input tensor must have 2 dimensions but got ${inputShape.numDimensions()}"
41+
}
42+
return Shape.make(inputShape.size(0), n.toLong(), inputShape.size(1))
43+
}
44+
45+
override fun forward(
46+
tf: Ops,
47+
input: Operand<Float>,
48+
isTraining: Operand<Boolean>,
49+
numberOfLosses: Operand<Float>?
50+
): Operand<Float> {
51+
val x = tf.expandDims(input, tf.constant(1))
52+
val pattern = tf.stack(listOf(tf.constant(1), tf.constant(n), tf.constant(1)))
53+
return tf.tile(x, pattern)
54+
}
55+
56+
override var weights: Map<String, Array<*>>
57+
get() = emptyMap()
58+
set(value) = assignWeights(value)
59+
60+
override val hasActivation: Boolean get() = false
61+
62+
override val paramCount: Int get() = 0
63+
64+
override fun toString(): String {
65+
return "RepeatVector"
66+
}
67+
}

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ internal const val LAYER_DROPOUT: String = "Dropout"
3838
// Attention layers
3939
// Reshaping layers
4040
internal const val LAYER_FLATTEN: String = "Flatten"
41+
internal const val LAYER_REPEAT_VECTOR: String = "RepeatVector"
4142
internal const val LAYER_RESHAPE: String = "Reshape"
4243
internal const val LAYER_ZERO_PADDING_2D = "ZeroPadding2D"
4344
internal const val LAYER_CROPPING_2D = "Cropping2D"

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.pooling.*
2222
import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout
2323
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Cropping2D
2424
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
25+
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
2526
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Reshape
2627
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D
2728
import org.jetbrains.kotlinx.dl.api.core.regularizer.L1
@@ -142,6 +143,7 @@ private fun convertToLayer(
142143
// Attention layers
143144
// Reshaping layers
144145
LAYER_FLATTEN -> createFlattenLayer(kerasLayer.config!!.name!!)
146+
LAYER_REPEAT_VECTOR -> createRepeatVectorLayer(kerasLayer.config!!, kerasLayer.config.name!!)
145147
LAYER_RESHAPE -> createReshapeLayer(kerasLayer.config!!, kerasLayer.config.name!!)
146148
LAYER_CROPPING_2D -> createCropping2DLayer(kerasLayer.config!!, kerasLayer.config.name!!)
147149
LAYER_ZERO_PADDING_2D -> createZeroPadding2DLayer(kerasLayer.config!!, kerasLayer.config.name!!)
@@ -722,6 +724,10 @@ private fun createFlattenLayer(name: String): Layer {
722724
return Flatten(name = name)
723725
}
724726

727+
private fun createRepeatVectorLayer(config: LayerConfig, name: String): Layer {
728+
return RepeatVector(name = name, n = config.n!!)
729+
}
730+
725731
private fun createReshapeLayer(config: LayerConfig, name: String): Layer {
726732
return Reshape(name = name, targetShape = config.target_shape!!)
727733
}

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.merge.*
2121
import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm
2222
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.*
2323
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
24+
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
2425
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D
2526
import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1
2627
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
@@ -98,6 +99,7 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i
9899
// Attention layers
99100
// Reshaping layers
100101
is Flatten -> createKerasFlattenLayer(layer)
102+
is RepeatVector -> createKerasRepeatVectorLayer(layer)
101103
is ZeroPadding2D -> createKerasZeroPadding2DLayer(layer)
102104
// Merging layers
103105
is Add -> createKerasAddLayer(layer)
@@ -584,6 +586,17 @@ private fun createKerasFlattenLayer(layer: Flatten): KerasLayer {
584586
return KerasLayer(class_name = LAYER_FLATTEN, config = configX)
585587
}
586588

589+
private fun createKerasRepeatVectorLayer(layer: RepeatVector): KerasLayer {
590+
val configX = LayerConfig(
591+
data_format = CHANNELS_LAST,
592+
dtype = DATATYPE_FLOAT32,
593+
trainable = layer.isTrainable,
594+
name = layer.name,
595+
n = layer.n
596+
)
597+
return KerasLayer(class_name = LAYER_REPEAT_VECTOR, config = configX)
598+
}
599+
587600
private fun createKerasConcatenateLayer(layer: Concatenate): KerasLayer {
588601
val configX = LayerConfig(
589602
dtype = DATATYPE_FLOAT32,

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ internal data class LayerConfig(
8585
@Json(serializeNull = false)
8686
val moving_variance_initializer: KerasInitializer? = null,
8787
@Json(serializeNull = false)
88+
val n: Int? = null,
89+
@Json(serializeNull = false)
8890
val name: String? = null,
8991
@Json(serializeNull = false)
9092
val negative_slope: Double? = null,
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
3+
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
4+
*/
5+
6+
package org.jetbrains.kotlinx.dl.api.core.layer
7+
8+
import org.jetbrains.kotlinx.dl.api.core.KGraph
9+
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
10+
import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray
11+
import org.junit.jupiter.api.Assertions
12+
import org.junit.jupiter.api.Test
13+
import org.tensorflow.Graph
14+
import org.tensorflow.Output
15+
import org.tensorflow.Shape
16+
import org.tensorflow.op.Ops
17+
18+
internal class RepeatVectorLayerTest {
19+
20+
@Test
21+
fun testIllegalRepetitions() {
22+
Assertions.assertThrows(IllegalArgumentException::class.java) {
23+
RepeatVector(n = -10)
24+
}
25+
}
26+
27+
@Test
28+
fun testOutputShape() {
29+
val layer = RepeatVector(n = 2)
30+
val x = Array(10) { FloatArray(10) { 1F } }
31+
val y = layer(x)
32+
Assertions.assertArrayEquals(intArrayOf(10, layer.n, 10), y.shape().toIntArray())
33+
}
34+
35+
@Test
36+
fun testOutput() {
37+
val layer = RepeatVector(n = 2)
38+
val x = Array(3) { FloatArray(3) { it.toFloat() } }
39+
val y = layer(x)
40+
val actual = y.tensor().copyTo(Array(3) { Array(layer.n) { FloatArray(3) } })
41+
val expected = arrayOf(
42+
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)),
43+
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)),
44+
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F))
45+
)
46+
Assertions.assertArrayEquals(expected, actual)
47+
}
48+
49+
// TODO: generalise this for Layer, see https://github.com/JetBrains/KotlinDL/issues/145
50+
private operator fun RepeatVector.invoke(input: Array<FloatArray>): Output<Float> = Ops.create().let { tf ->
51+
build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10))
52+
val inputOp = tf.constant(input)
53+
val isTraining = tf.constant(true)
54+
val numberOfLosses = tf.constant(1.0f)
55+
forward(tf, inputOp, isTraining, numberOfLosses).asOutput()
56+
}
57+
}

0 commit comments

Comments
 (0)