Skip to content

Commit 433567d

Browse files
committed
Refactoring
1 parent e86fa2a commit 433567d

File tree

2 files changed

+60
-82
lines changed
  • utbot-framework-api/src/main/kotlin/org/utbot/framework/plugin/api
  • utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/tree

2 files changed

+60
-82
lines changed

utbot-framework-api/src/main/kotlin/org/utbot/framework/plugin/api/Api.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,12 @@ sealed class UtReferenceModel(
282282
) : UtModel(classId)
283283

284284
/**
285-
* Checks if [UtModel] is a null.
285+
* Checks if [UtModel] is a [UtNullModel].
286286
*/
287287
fun UtModel.isNull() = this is UtNullModel
288288

289289
/**
290-
* Checks if [UtModel] is not a null.
290+
* Checks if [UtModel] is not a [UtNullModel].
291291
*/
292292
fun UtModel.isNotNull() = !isNull()
293293

utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/tree/CgMethodConstructor.kt

Lines changed: 58 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
167167

168168
private lateinit var methodType: CgTestMethodType
169169

170-
private val cachedFieldsFromAllExecutions = mutableMapOf<Pair<FieldId, Int>, MutableList<UtModel>>()
170+
private val fieldsOfExecutionResults = mutableMapOf<Pair<FieldId, Int>, MutableList<UtModel>>()
171171

172172
private fun setupInstrumentation() {
173173
if (currentExecution is UtSymbolicExecution) {
@@ -831,27 +831,24 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
831831
depth: Int,
832832
visitedModels: MutableSet<UtModel>
833833
) {
834+
// if field is static, it is represents itself in "before" and
835+
// "after" state: no need to assert its equality to itself.
836+
if (fieldId.isStatic) {
837+
return
838+
}
839+
840+
// if model is already processed, so we don't want to add new statements
841+
if (fieldModel in visitedModels) {
842+
return
843+
}
844+
834845
when (parametrizedTestSource) {
835846
ParametrizedTestSource.DO_NOT_PARAMETRIZE -> {
836-
traverseField(
837-
fieldId,
838-
fieldModel,
839-
expected,
840-
actual,
841-
depth,
842-
visitedModels
843-
)
847+
traverseField(fieldId, fieldModel, expected, actual, depth, visitedModels)
844848
}
845849

846850
ParametrizedTestSource.PARAMETRIZE -> {
847-
traverseFieldForParametrizedTest(
848-
fieldId,
849-
fieldModel,
850-
expected,
851-
actual,
852-
depth,
853-
visitedModels
854-
)
851+
traverseFieldForParametrizedTest(fieldId, fieldModel, expected, actual, depth, visitedModels)
855852
}
856853
}
857854
}
@@ -864,17 +861,6 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
864861
depth: Int,
865862
visitedModels: MutableSet<UtModel>
866863
) {
867-
// if field is static, it is represents itself in "before" and
868-
// "after" state: no need to assert its equality to itself.
869-
if (fieldId.isStatic) {
870-
return
871-
}
872-
873-
// if model is already processed, so we don't want to add new statements
874-
if (fieldModel in visitedModels) {
875-
return
876-
}
877-
878864
// fieldModel is not visited and will be marked in assertDeepEquals call
879865
val fieldName = fieldId.name
880866
var expectedVariable: CgVariable? = null
@@ -907,31 +893,25 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
907893
depth: Int,
908894
visitedModels: MutableSet<UtModel>
909895
) {
910-
// if field is static, it is represents itself in "before" and
911-
// "after" state: no need to assert its equality to itself.
912-
if (fieldId.isStatic) {
913-
return
914-
}
896+
val fieldResultModels = fieldsOfExecutionResults[fieldId to depth]
897+
val nullResultModelInExecutions = fieldResultModels?.find { it.isNull() }
898+
val notNullResultModelInExecutions = fieldResultModels?.find { it.isNotNull() }
915899

916-
val hasNullModelFieldInOtherExecutions = cachedFieldsFromAllExecutions[fieldId to depth]?.any { it.isNull() } ?: false
917-
val notNullModelFieldFromOtherExecution = cachedFieldsFromAllExecutions[fieldId to depth]?.find { it.isNotNull() }
900+
val hasNullResultModel = nullResultModelInExecutions != null
901+
val hasNotNullResultModel = notNullResultModelInExecutions != null
918902

919-
val needToSubstituteFieldModel = notNullModelFieldFromOtherExecution != null && fieldModel is UtNullModel
903+
val needToSubstituteFieldModel = fieldModel is UtNullModel && hasNotNullResultModel
920904

921-
val fieldModel = if (needToSubstituteFieldModel) notNullModelFieldFromOtherExecution!! else fieldModel
905+
val fieldModelForAssert = if (needToSubstituteFieldModel) notNullResultModelInExecutions!! else fieldModel
922906

923-
val needIfStatement = needToSubstituteFieldModel || hasNullModelFieldInOtherExecutions
924-
925-
// if model is already processed, so we don't want to add new statements
926-
if (fieldModel in visitedModels) {
927-
return
928-
}
907+
// val needIfStatement = needToSubstituteFieldModel || hasNullResultModel
929908

930909
// fieldModel is not visited and will be marked in assertDeepEquals call
931910
val fieldName = fieldId.name
932911
var expectedVariable: CgVariable? = null
933912

934-
if (needExpectedDeclaration(fieldModel)) {
913+
val needExpectedDeclaration = needExpectedDeclaration(fieldModelForAssert)
914+
if (needExpectedDeclaration) {
935915
val expectedFieldDeclaration = createDeclarationForFieldFromVariable(fieldId, expected, fieldName)
936916

937917
currentBlock += expectedFieldDeclaration
@@ -941,13 +921,13 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
941921
val actualFieldDeclaration = createDeclarationForFieldFromVariable(fieldId, actual, fieldName)
942922
currentBlock += actualFieldDeclaration
943923

944-
if (needExpectedDeclaration(fieldModel) && needIfStatement) {
924+
if (needExpectedDeclaration && hasNullResultModel) {
945925
ifStatement(
946926
CgEqualTo(expectedVariable!!, nullLiteral()),
947927
trueBranch = { +testFrameworkManager.assertions[testFramework.assertNull](actualFieldDeclaration.variable).toStatement() },
948928
falseBranch = {
949929
assertDeepEquals(
950-
fieldModel,
930+
fieldModelForAssert,
951931
expectedVariable,
952932
actualFieldDeclaration.variable,
953933
depth + 1,
@@ -957,7 +937,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
957937
)
958938
} else {
959939
assertDeepEquals(
960-
fieldModel,
940+
fieldModelForAssert,
961941
expectedVariable,
962942
actualFieldDeclaration.variable,
963943
depth + 1,
@@ -967,7 +947,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
967947
emptyLineIfNeeded()
968948
}
969949

970-
private fun traverseExecutionsFieldsForParametrizedTest() {
950+
private fun collectExecutionsResultFields() {
971951
val successfulExecutionsModels = allExecutions
972952
.filter {
973953
it.result is UtExecutionSuccess
@@ -979,22 +959,31 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
979959
when (model) {
980960
is UtCompositeModel -> {
981961
for ((fieldId, fieldModel) in model.fields) {
982-
traverseExecutionsFieldsForParametrizedTestRecursively(fieldId, fieldModel, 0)
962+
collectExecutionsResultFieldsRecursively(fieldId, fieldModel, 0)
983963
}
984964
}
985965

986966
is UtAssembleModel -> {
987-
for ((fieldId, fieldModel) in model.origin!!.fields) {
988-
traverseExecutionsFieldsForParametrizedTestRecursively(fieldId, fieldModel, 0)
967+
model.origin?.let {
968+
for ((fieldId, fieldModel) in it.fields) {
969+
collectExecutionsResultFieldsRecursively(fieldId, fieldModel, 0)
970+
}
989971
}
990972
}
991973

992-
else -> {} // TODO: check this specific case
974+
is UtNullModel,
975+
is UtPrimitiveModel,
976+
is UtArrayModel,
977+
is UtClassRefModel,
978+
is UtEnumConstantModel,
979+
is UtVoidModel -> {
980+
// only [UtCompositeModel] and [UtAssembleModel] have fields to traverse
981+
}
993982
}
994983
}
995984
}
996985

997-
private fun traverseExecutionsFieldsForParametrizedTestRecursively(
986+
private fun collectExecutionsResultFieldsRecursively(
998987
fieldId: FieldId,
999988
fieldModel: UtModel,
1000989
depth: Int,
@@ -1003,43 +992,32 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
1003992
return
1004993
}
1005994

995+
val fieldKey = fieldId to depth
996+
fieldsOfExecutionResults.getOrPut(fieldKey) { mutableListOf() } += fieldModel
997+
1006998
when (fieldModel) {
1007999
is UtCompositeModel -> {
10081000
for ((id, model) in fieldModel.fields) {
1009-
if (id.isInnerClassEnclosingClassReference) continue
1010-
1011-
if (id.isStatic) {
1012-
return
1013-
}
1014-
1015-
if (cachedFieldsFromAllExecutions[id to depth] != null) {
1016-
cachedFieldsFromAllExecutions[id to depth]!!.add(model)
1017-
} else {
1018-
cachedFieldsFromAllExecutions[id to depth] = listOf(model).toMutableList()
1019-
}
1020-
1021-
traverseExecutionsFieldsForParametrizedTestRecursively(
1022-
id,
1023-
model,
1024-
depth + 1,
1025-
)
1001+
collectExecutionsResultFieldsRecursively(id, model, depth + 1)
10261002
}
10271003
}
10281004

10291005
is UtAssembleModel -> {
10301006
fieldModel.origin?.let {
1031-
traverseExecutionsFieldsForParametrizedTestRecursively(fieldId, it, depth)
1007+
for ((id, model) in it.fields) {
1008+
collectExecutionsResultFieldsRecursively(id, model, depth + 1)
1009+
}
10321010
}
1033-
return
10341011
}
10351012

1036-
else -> {} // TODO: check this specific case
1037-
}
1038-
1039-
if (cachedFieldsFromAllExecutions[fieldId to depth] != null) {
1040-
cachedFieldsFromAllExecutions[fieldId to depth]!!.add(fieldModel)
1041-
} else {
1042-
cachedFieldsFromAllExecutions[fieldId to depth] = listOf(fieldModel).toMutableList()
1013+
is UtNullModel,
1014+
is UtPrimitiveModel,
1015+
is UtArrayModel,
1016+
is UtClassRefModel,
1017+
is UtEnumConstantModel,
1018+
is UtVoidModel -> {
1019+
// only [UtCompositeModel] and [UtAssembleModel] have fields to traverse
1020+
}
10431021
}
10441022
}
10451023

@@ -1234,7 +1212,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
12341212
when (parametrizedTestSource) {
12351213
ParametrizedTestSource.DO_NOT_PARAMETRIZE -> generateDeepEqualsAssertion(expected, actual)
12361214
ParametrizedTestSource.PARAMETRIZE -> {
1237-
traverseExecutionsFieldsForParametrizedTest()
1215+
collectExecutionsResultFields()
12381216

12391217
when {
12401218
actual.type.isPrimitive -> generateDeepEqualsAssertion(expected, actual)

0 commit comments

Comments
 (0)