Skip to content

Commit fc8f349

Browse files
committed
Fix channels ordering for classification models (Kotlin#400)
1 parent 89e4769 commit fc8f349

File tree

3 files changed

+89
-16
lines changed

3 files changed

+89
-16
lines changed

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/imagerecognition/ImageRecognitionModel.kt

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ public class ImageRecognitionModel(
5959

6060
/**
6161
* Predicts [topK] objects for the given [imageFile].
62+
* Default [Preprocessing] is applied to an image.
63+
*
64+
* @param [imageFile] Input image [File].
65+
* @param [topK] Number of top ranked predictions to return
66+
*
67+
* @see preprocessData
6268
*
6369
* @return The list of pairs <label, probability> sorted from the most probable to the lowest probable.
6470
*/
@@ -67,8 +73,22 @@ public class ImageRecognitionModel(
6773
return predictTopKImageNetLabels(internalModel, inputData, imageNetClassLabels, topK)
6874
}
6975

76+
/**
77+
* Predicts [topK] objects for the given [imageFile] with a custom [Preprocessing] provided.
78+
*
79+
* @param [imageFile] Input image [File].
80+
* @param [preprocessing] custom [Preprocessing] instance
81+
* @param [topK] Number of top ranked predictions to return
82+
*
83+
* @return The list of pairs <label, probability> sorted from the most probable to the lowest probable.
84+
*/
85+
public fun predictTopKObjects(imageFile: File, preprocessing: Preprocessing, topK: Int = 5): List<Pair<String, Float>> {
86+
val (inputData, _) = preprocessing(imageFile)
87+
return predictTopKImageNetLabels(internalModel, inputData, imageNetClassLabels, topK)
88+
}
89+
7090
private fun preprocessData(imageFile: File): FloatArray {
71-
val (weight, height) = if (modelType.channelsFirst)
91+
val (width, height) = if (modelType.channelsFirst)
7292
Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
7393
else
7494
Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
@@ -77,10 +97,10 @@ public class ImageRecognitionModel(
7797
transformImage {
7898
resize {
7999
outputHeight = height.toInt()
80-
outputWidth = weight.toInt()
100+
outputWidth = width.toInt()
81101
interpolation = InterpolationType.BILINEAR
82102
}
83-
convert { colorMode = ColorMode.BGR }
103+
convert { colorMode = modelType.inputColorMode }
84104
}
85105
}
86106

@@ -89,11 +109,28 @@ public class ImageRecognitionModel(
89109

90110
/**
91111
* Predicts object for the given [imageFile].
112+
* Default [Preprocessing] is applied to an image.
113+
*
114+
* @param [imageFile] Input image [File].
115+
* @see preprocessData
92116
*
93117
* @return The label of the recognized object with the highest probability.
94118
*/
95119
public fun predictObject(imageFile: File): String {
96120
val inputData = preprocessData(imageFile)
97121
return imageNetClassLabels[internalModel.predict(inputData)]!!
98122
}
123+
124+
/**
125+
* Predicts object for the given [imageFile] with a custom [Preprocessing] provided.
126+
*
127+
* @param [imageFile] Input image [File].
128+
* @param [preprocessing] custom [Preprocessing] instance
129+
*
130+
* @return The label of the recognized object with the highest probability.
131+
*/
132+
public fun predictObject(imageFile: File, preprocessing: Preprocessing): String {
133+
val (inputData, _) = preprocessing(imageFile)
134+
return imageNetClassLabels[internalModel.predict(inputData)]!!
135+
}
99136
}

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/loaders/TFModels.kt

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
1414
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
1515
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModel
1616
import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeights
17+
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
1718
import org.jetbrains.kotlinx.dl.dataset.preprocessor.Preprocessing
1819
import java.io.File
1920

@@ -27,10 +28,10 @@ public object TFModels {
2728
public sealed class CV<T : GraphTrainableModel>(
2829
override val modelRelativePath: String,
2930
override val channelsFirst: Boolean = false,
31+
override val inputColorMode: ColorMode = ColorMode.RGB,
3032
public var inputShape: IntArray? = null,
3133
internal var noTop: Boolean = false
32-
) :
33-
ModelType<T, ImageRecognitionModel> {
34+
) : ModelType<T, ImageRecognitionModel> {
3435

3536
init {
3637
if (inputShape != null) {
@@ -61,7 +62,12 @@ public object TFModels {
6162
* Official VGG16 model from Keras.applications.</a>
6263
*/
6364
public class VGG16(noTop: Boolean = false, inputShape: IntArray? = null) :
64-
CV<Sequential>("models/tensorflow/cv/vgg16", inputShape = inputShape, noTop = noTop) {
65+
CV<Sequential>(
66+
"models/tensorflow/cv/vgg16",
67+
inputShape = inputShape,
68+
noTop = noTop,
69+
inputColorMode = ColorMode.BGR
70+
) {
6571
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
6672
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
6773
}
@@ -85,7 +91,12 @@ public object TFModels {
8591
* Official VGG19 model from Keras.applications.</a>
8692
*/
8793
public class VGG19(noTop: Boolean = false, inputShape: IntArray? = null) :
88-
CV<Sequential>("models/tensorflow/cv/vgg19", inputShape = inputShape, noTop = noTop) {
94+
CV<Sequential>(
95+
"models/tensorflow/cv/vgg19",
96+
inputShape = inputShape,
97+
noTop = noTop,
98+
inputColorMode = ColorMode.BGR
99+
) {
89100
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
90101
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
91102
}
@@ -155,7 +166,12 @@ public object TFModels {
155166
* Official ResNet50 model from Keras.applications.</a>
156167
*/
157168
public class ResNet50(noTop: Boolean = false, inputShape: IntArray? = null) :
158-
CV<Functional>("models/tensorflow/cv/resnet50", inputShape = inputShape, noTop = noTop) {
169+
CV<Functional>(
170+
"models/tensorflow/cv/resnet50",
171+
inputShape = inputShape,
172+
noTop = noTop,
173+
inputColorMode = ColorMode.BGR
174+
) {
159175
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
160176
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
161177
}
@@ -181,7 +197,12 @@ public object TFModels {
181197
* Official ResNet101 model from Keras.applications.</a>
182198
*/
183199
public class ResNet101(noTop: Boolean = false, inputShape: IntArray? = null) :
184-
CV<Functional>("models/tensorflow/cv/resnet101", inputShape = inputShape, noTop = noTop) {
200+
CV<Functional>(
201+
"models/tensorflow/cv/resnet101",
202+
inputShape = inputShape,
203+
noTop = noTop,
204+
inputColorMode = ColorMode.BGR
205+
) {
185206
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
186207
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
187208
}
@@ -207,7 +228,12 @@ public object TFModels {
207228
* Official ResNet152 model from Keras.applications.</a>
208229
*/
209230
public class ResNet152(noTop: Boolean = false, inputShape: IntArray? = null) :
210-
CV<Functional>("models/tensorflow/cv/resnet152", inputShape = inputShape, noTop = noTop) {
231+
CV<Functional>(
232+
"models/tensorflow/cv/resnet152",
233+
inputShape = inputShape,
234+
noTop = noTop,
235+
inputColorMode = ColorMode.BGR
236+
) {
211237
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
212238
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
213239
}
@@ -548,6 +574,8 @@ public interface ModelType<T : InferenceModel, U : InferenceModel> {
548574
*/
549575
public val channelsFirst: Boolean
550576

577+
public val inputColorMode: ColorMode
578+
551579
/**
552580
* Common preprocessing function for the Neural Networks trained on ImageNet and whose weights are available with the keras.application.
553581
*

onnx/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDMobileNetV
1717
import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDObjectDetectionModel
1818
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.MultiPoseDetectionModel
1919
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.SinglePoseDetectionModel
20+
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
2021
import org.jetbrains.kotlinx.dl.dataset.preprocessor.ImageShape
2122
import org.jetbrains.kotlinx.dl.dataset.preprocessor.Transpose
2223

@@ -26,10 +27,10 @@ public object ONNXModels {
2627
public sealed class CV<T : InferenceModel>(
2728
override val modelRelativePath: String,
2829
override val channelsFirst: Boolean,
30+
override val inputColorMode: ColorMode = ColorMode.RGB,
2931
/** If true, model is shipped without last few layers and could be used for transfer learning and fine-tuning with TF Runtime. */
3032
internal var noTop: Boolean = false
31-
) :
32-
ModelType<T, ImageRecognitionModel> {
33+
) : ModelType<T, ImageRecognitionModel> {
3334
override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
3435
return ImageRecognitionModel(modelHub.loadModel(this), this)
3536
}
@@ -323,7 +324,11 @@ public object ONNXModels {
323324
* Official ResNet model from Keras.applications.</a>
324325
*/
325326
public object ResNet50custom :
326-
CV<OnnxInferenceModel>("models/onnx/cv/custom/resnet50", channelsFirst = false) {
327+
CV<OnnxInferenceModel>(
328+
"models/onnx/cv/custom/resnet50",
329+
channelsFirst = false,
330+
inputColorMode = ColorMode.BGR
331+
) {
327332
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
328333
return preprocessInput(
329334
data,
@@ -612,7 +617,8 @@ public object ONNXModels {
612617
/** Object detection models and preprocessing. */
613618
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(
614619
override val modelRelativePath: String,
615-
override val channelsFirst: Boolean = true
620+
override val channelsFirst: Boolean = true,
621+
override val inputColorMode: ColorMode = ColorMode.RGB
616622
) :
617623
ModelType<T, U> {
618624
/**
@@ -964,7 +970,8 @@ public object ONNXModels {
964970
/** Face alignment models and preprocessing. */
965971
public sealed class FaceAlignment<T : InferenceModel, U : InferenceModel>(
966972
override val modelRelativePath: String,
967-
override val channelsFirst: Boolean = true
973+
override val channelsFirst: Boolean = true,
974+
override val inputColorMode: ColorMode = ColorMode.RGB
968975
) :
969976
ModelType<T, U> {
970977
/**
@@ -996,7 +1003,8 @@ public object ONNXModels {
9961003
/** Pose detection models. */
9971004
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(
9981005
override val modelRelativePath: String,
999-
override val channelsFirst: Boolean = true
1006+
override val channelsFirst: Boolean = true,
1007+
override val inputColorMode: ColorMode = ColorMode.RGB
10001008
) :
10011009
ModelType<T, U> {
10021010
/**

0 commit comments

Comments
 (0)