@@ -14,6 +14,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
14
14
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
15
15
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModel
16
16
import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeights
17
+ import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
17
18
import org.jetbrains.kotlinx.dl.dataset.preprocessor.Preprocessing
18
19
import java.io.File
19
20
@@ -27,10 +28,10 @@ public object TFModels {
27
28
public sealed class CV <T : GraphTrainableModel >(
28
29
override val modelRelativePath : String ,
29
30
override val channelsFirst : Boolean = false ,
31
+ override val inputColorMode : ColorMode = ColorMode .RGB ,
30
32
public var inputShape : IntArray? = null ,
31
33
internal var noTop : Boolean = false
32
- ) :
33
- ModelType <T , ImageRecognitionModel > {
34
+ ) : ModelType<T, ImageRecognitionModel> {
34
35
35
36
init {
36
37
if (inputShape != null ) {
@@ -61,7 +62,12 @@ public object TFModels {
61
62
* Official VGG16 model from Keras.applications.</a>
62
63
*/
63
64
public class VGG16 (noTop : Boolean = false , inputShape : IntArray? = null ) :
64
- CV <Sequential >(" models/tensorflow/cv/vgg16" , inputShape = inputShape, noTop = noTop) {
65
+ CV <Sequential >(
66
+ " models/tensorflow/cv/vgg16" ,
67
+ inputShape = inputShape,
68
+ noTop = noTop,
69
+ inputColorMode = ColorMode .BGR
70
+ ) {
65
71
override fun preprocessInput (data : FloatArray , tensorShape : LongArray ): FloatArray {
66
72
return preprocessInput(data, tensorShape, inputType = InputType .CAFFE )
67
73
}
@@ -85,7 +91,12 @@ public object TFModels {
85
91
* Official VGG19 model from Keras.applications.</a>
86
92
*/
87
93
public class VGG19 (noTop : Boolean = false , inputShape : IntArray? = null ) :
88
- CV <Sequential >(" models/tensorflow/cv/vgg19" , inputShape = inputShape, noTop = noTop) {
94
+ CV <Sequential >(
95
+ " models/tensorflow/cv/vgg19" ,
96
+ inputShape = inputShape,
97
+ noTop = noTop,
98
+ inputColorMode = ColorMode .BGR
99
+ ) {
89
100
override fun preprocessInput (data : FloatArray , tensorShape : LongArray ): FloatArray {
90
101
return preprocessInput(data, tensorShape, inputType = InputType .CAFFE )
91
102
}
@@ -155,7 +166,12 @@ public object TFModels {
155
166
* Official ResNet50 model from Keras.applications.</a>
156
167
*/
157
168
public class ResNet50 (noTop : Boolean = false , inputShape : IntArray? = null ) :
158
- CV <Functional >(" models/tensorflow/cv/resnet50" , inputShape = inputShape, noTop = noTop) {
169
+ CV <Functional >(
170
+ " models/tensorflow/cv/resnet50" ,
171
+ inputShape = inputShape,
172
+ noTop = noTop,
173
+ inputColorMode = ColorMode .BGR
174
+ ) {
159
175
override fun preprocessInput (data : FloatArray , tensorShape : LongArray ): FloatArray {
160
176
return preprocessInput(data, tensorShape, inputType = InputType .CAFFE )
161
177
}
@@ -181,7 +197,12 @@ public object TFModels {
181
197
* Official ResNet101 model from Keras.applications.</a>
182
198
*/
183
199
public class ResNet101 (noTop : Boolean = false , inputShape : IntArray? = null ) :
184
- CV <Functional >(" models/tensorflow/cv/resnet101" , inputShape = inputShape, noTop = noTop) {
200
+ CV <Functional >(
201
+ " models/tensorflow/cv/resnet101" ,
202
+ inputShape = inputShape,
203
+ noTop = noTop,
204
+ inputColorMode = ColorMode .BGR
205
+ ) {
185
206
override fun preprocessInput (data : FloatArray , tensorShape : LongArray ): FloatArray {
186
207
return preprocessInput(data, tensorShape, inputType = InputType .CAFFE )
187
208
}
@@ -207,7 +228,12 @@ public object TFModels {
207
228
* Official ResNet152 model from Keras.applications.</a>
208
229
*/
209
230
public class ResNet152 (noTop : Boolean = false , inputShape : IntArray? = null ) :
210
- CV <Functional >(" models/tensorflow/cv/resnet152" , inputShape = inputShape, noTop = noTop) {
231
+ CV <Functional >(
232
+ " models/tensorflow/cv/resnet152" ,
233
+ inputShape = inputShape,
234
+ noTop = noTop,
235
+ inputColorMode = ColorMode .BGR
236
+ ) {
211
237
override fun preprocessInput (data : FloatArray , tensorShape : LongArray ): FloatArray {
212
238
return preprocessInput(data, tensorShape, inputType = InputType .CAFFE )
213
239
}
@@ -548,6 +574,8 @@ public interface ModelType<T : InferenceModel, U : InferenceModel> {
548
574
*/
549
575
public val channelsFirst: Boolean
550
576
577
+ public val inputColorMode: ColorMode
578
+
551
579
/* *
552
580
* Common preprocessing function for the Neural Networks trained on ImageNet and whose weights are available with the keras.application.
553
581
*
0 commit comments