Skip to content

Commit 90aa125

Browse files
committed
Add passing modelKindDescription to the copy functions (Kotlin#368)
1 parent 4398dd1 commit 90aa125

File tree

11 files changed

+41
-15
lines changed

11 files changed

+41
-15
lines changed

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/summary/ModelSummary.kt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package org.jetbrains.kotlinx.dl.api.summary
22

3-
import org.jetbrains.kotlinx.dl.api.inference.loaders.ModelType
4-
53
/**
64
* Common interface for model summary.
75
*/
@@ -36,14 +34,14 @@ public class EmptySummary : ModelSummary {
3634

3735
/**
3836
* The summary for the models from ModelHub.
39-
* It appends corresponding [ModelType] to the header of the model summary.
37+
* It appends corresponding [modelKindDescription] to the header of the model summary.
4038
*
4139
* @property [modelType] type of the model, aka. model architecture. E.g. VGG16, ResNet50, etc.
4240
* @property [internalSummary] summary of the internal model used for inference
4341
*/
4442
public class ModelHubModelSummary(
4543
private val internalSummary: ModelSummary,
46-
private val modelKindDescription: String?
44+
private val modelKindDescription: String? = null
4745
) : ModelSummary {
4846
override fun format(
4947
columnSeparator: String,
@@ -61,7 +59,7 @@ public class ModelHubModelSummary(
6159
val separator = thickLineSeparatorSymbol.toString().repeat(tableWidth)
6260

6361
val tableWithHeader = mutableListOf(separator, modelTypeHeader)
64-
table.forEach(tableWithHeader::add)
62+
tableWithHeader.addAll(table)
6563

6664
return tableWithHeader
6765
}

impl/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/impl/inference/imagerecognition/ImageRecognitionModelBase.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ import org.jetbrains.kotlinx.dl.api.summary.ModelSummary
1616
/**
1717
* Base class for image classification models.
1818
* @property [internalModel] model used for prediction
19-
* @property [modelKindDescription] High-level description of the model. Used for model summary printing. For the models from [OnnxModels] it equals to the string representation of [OnnxModelType]
19+
* @property [modelKindDescription] High-level description of the model. Used for model summary printing.
20+
* For the models from ModelHub it equals to the string representation of [ModelType].
2021
*/
2122
public abstract class ImageRecognitionModelBase<I>(
2223
protected val internalModel: InferenceModel,

onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/facealignment/FaceDetectionModel.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ public class FaceDetectionModel(
3838
.call(ONNXModels.FaceDetection.defaultPreprocessor)
3939

4040
override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
41-
return FaceDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
41+
return FaceDetectionModel(
42+
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
43+
modelKindDescription
44+
)
4245
}
4346
}
4447

onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/facealignment/Fan2D106FaceAlignmentModel.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ public class Fan2D106FaceAlignmentModel(
4242
.toFloatArray { layout = TensorLayout.NCHW }
4343

4444
override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
45-
return Fan2D106FaceAlignmentModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
45+
return Fan2D106FaceAlignmentModel(
46+
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
47+
modelKindDescription
48+
)
4649
}
4750
}
4851

onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/facealignment/FaceDetectionModel.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ public class FaceDetectionModel(
3939
.call(ONNXModels.FaceDetection.defaultPreprocessor)
4040

4141
override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
42-
return FaceDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
42+
return FaceDetectionModel(
43+
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
44+
modelKindDescription
45+
)
4346
}
4447
}

onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/facealignment/Fan2D106FaceAlignmentModel.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ public class Fan2D106FaceAlignmentModel(
6262
saveOptimizerState: Boolean,
6363
copyWeights: Boolean
6464
): Fan2D106FaceAlignmentModel {
65-
return Fan2D106FaceAlignmentModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
65+
return Fan2D106FaceAlignmentModel(
66+
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
67+
modelKindDescription
68+
)
6669
}
6770
}
6871

onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/objectdetection/EfficientDetObjectDetectionModel.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ public class EfficientDetObjectDetectionModel(
7070
saveOptimizerState: Boolean,
7171
copyWeights: Boolean
7272
): EfficientDetObjectDetectionModel {
73-
return EfficientDetObjectDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
73+
return EfficientDetObjectDetectionModel(
74+
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
75+
modelKindDescription
76+
)
7477
}
7578
}

onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ public class SSDMobileNetV1ObjectDetectionModel(
7979
saveOptimizerState: Boolean,
8080
copyWeights: Boolean
8181
): SSDMobileNetV1ObjectDetectionModel {
82-
return SSDMobileNetV1ObjectDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
82+
return SSDMobileNetV1ObjectDetectionModel(
83+
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
84+
modelKindDescription
85+
)
8386
}
8487
}

onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/objectdetection/SSDObjectDetectionModel.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ public class SSDObjectDetectionModel(
110110
saveOptimizerState: Boolean,
111111
copyWeights: Boolean
112112
): SSDObjectDetectionModel {
113-
return SSDObjectDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
113+
return SSDObjectDetectionModel(
114+
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
115+
modelKindDescription
116+
)
114117
}
115118
}

onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/posedetection/MultiPoseDetectionModel.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ public class MultiPoseDetectionModel(
6666
saveOptimizerState: Boolean,
6767
copyWeights: Boolean
6868
): MultiPoseDetectionModel {
69-
return MultiPoseDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
69+
return MultiPoseDetectionModel(
70+
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
71+
modelKindDescription
72+
)
7073
}
7174
}

0 commit comments

Comments
 (0)