@@ -167,7 +167,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
167
167
168
168
private lateinit var methodType: CgTestMethodType
169
169
170
- private val cachedFieldsFromAllExecutions = mutableMapOf<Pair <FieldId , Int >, MutableList <UtModel >>()
170
+ private val fieldsOfExecutionResults = mutableMapOf<Pair <FieldId , Int >, MutableList <UtModel >>()
171
171
172
172
private fun setupInstrumentation () {
173
173
if (currentExecution is UtSymbolicExecution ) {
@@ -831,27 +831,24 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
831
831
depth : Int ,
832
832
visitedModels : MutableSet <UtModel >
833
833
) {
834
+ // if field is static, it is represents itself in "before" and
835
+ // "after" state: no need to assert its equality to itself.
836
+ if (fieldId.isStatic) {
837
+ return
838
+ }
839
+
840
+ // if model is already processed, so we don't want to add new statements
841
+ if (fieldModel in visitedModels) {
842
+ return
843
+ }
844
+
834
845
when (parametrizedTestSource) {
835
846
ParametrizedTestSource .DO_NOT_PARAMETRIZE -> {
836
- traverseField(
837
- fieldId,
838
- fieldModel,
839
- expected,
840
- actual,
841
- depth,
842
- visitedModels
843
- )
847
+ traverseField(fieldId, fieldModel, expected, actual, depth, visitedModels)
844
848
}
845
849
846
850
ParametrizedTestSource .PARAMETRIZE -> {
847
- traverseFieldForParametrizedTest(
848
- fieldId,
849
- fieldModel,
850
- expected,
851
- actual,
852
- depth,
853
- visitedModels
854
- )
851
+ traverseFieldForParametrizedTest(fieldId, fieldModel, expected, actual, depth, visitedModels)
855
852
}
856
853
}
857
854
}
@@ -864,17 +861,6 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
864
861
depth : Int ,
865
862
visitedModels : MutableSet <UtModel >
866
863
) {
867
- // if field is static, it is represents itself in "before" and
868
- // "after" state: no need to assert its equality to itself.
869
- if (fieldId.isStatic) {
870
- return
871
- }
872
-
873
- // if model is already processed, so we don't want to add new statements
874
- if (fieldModel in visitedModels) {
875
- return
876
- }
877
-
878
864
// fieldModel is not visited and will be marked in assertDeepEquals call
879
865
val fieldName = fieldId.name
880
866
var expectedVariable: CgVariable ? = null
@@ -907,31 +893,25 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
907
893
depth : Int ,
908
894
visitedModels : MutableSet <UtModel >
909
895
) {
910
- // if field is static, it is represents itself in "before" and
911
- // "after" state: no need to assert its equality to itself.
912
- if (fieldId.isStatic) {
913
- return
914
- }
896
+ val fieldResultModels = fieldsOfExecutionResults[fieldId to depth]
897
+ val nullResultModelInExecutions = fieldResultModels?.find { it.isNull() }
898
+ val notNullResultModelInExecutions = fieldResultModels?.find { it.isNotNull() }
915
899
916
- val hasNullModelFieldInOtherExecutions = cachedFieldsFromAllExecutions[fieldId to depth]?.any { it.isNull() } ? : false
917
- val notNullModelFieldFromOtherExecution = cachedFieldsFromAllExecutions[fieldId to depth]?.find { it.isNotNull() }
900
+ val hasNullResultModel = nullResultModelInExecutions != null
901
+ val hasNotNullResultModel = notNullResultModelInExecutions != null
918
902
919
- val needToSubstituteFieldModel = notNullModelFieldFromOtherExecution != null && fieldModel is UtNullModel
903
+ val needToSubstituteFieldModel = fieldModel is UtNullModel && hasNotNullResultModel
920
904
921
- val fieldModel = if (needToSubstituteFieldModel) notNullModelFieldFromOtherExecution !! else fieldModel
905
+ val fieldModelForAssert = if (needToSubstituteFieldModel) notNullResultModelInExecutions !! else fieldModel
922
906
923
- val needIfStatement = needToSubstituteFieldModel || hasNullModelFieldInOtherExecutions
924
-
925
- // if model is already processed, so we don't want to add new statements
926
- if (fieldModel in visitedModels) {
927
- return
928
- }
907
+ // val needIfStatement = needToSubstituteFieldModel || hasNullResultModel
929
908
930
909
// fieldModel is not visited and will be marked in assertDeepEquals call
931
910
val fieldName = fieldId.name
932
911
var expectedVariable: CgVariable ? = null
933
912
934
- if (needExpectedDeclaration(fieldModel)) {
913
+ val needExpectedDeclaration = needExpectedDeclaration(fieldModelForAssert)
914
+ if (needExpectedDeclaration) {
935
915
val expectedFieldDeclaration = createDeclarationForFieldFromVariable(fieldId, expected, fieldName)
936
916
937
917
currentBlock + = expectedFieldDeclaration
@@ -941,13 +921,13 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
941
921
val actualFieldDeclaration = createDeclarationForFieldFromVariable(fieldId, actual, fieldName)
942
922
currentBlock + = actualFieldDeclaration
943
923
944
- if (needExpectedDeclaration(fieldModel) && needIfStatement ) {
924
+ if (needExpectedDeclaration && hasNullResultModel ) {
945
925
ifStatement(
946
926
CgEqualTo (expectedVariable!! , nullLiteral()),
947
927
trueBranch = { + testFrameworkManager.assertions[testFramework.assertNull](actualFieldDeclaration.variable).toStatement() },
948
928
falseBranch = {
949
929
assertDeepEquals(
950
- fieldModel ,
930
+ fieldModelForAssert ,
951
931
expectedVariable,
952
932
actualFieldDeclaration.variable,
953
933
depth + 1 ,
@@ -957,7 +937,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
957
937
)
958
938
} else {
959
939
assertDeepEquals(
960
- fieldModel ,
940
+ fieldModelForAssert ,
961
941
expectedVariable,
962
942
actualFieldDeclaration.variable,
963
943
depth + 1 ,
@@ -967,7 +947,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
967
947
emptyLineIfNeeded()
968
948
}
969
949
970
- private fun traverseExecutionsFieldsForParametrizedTest () {
950
+ private fun collectExecutionsResultFields () {
971
951
val successfulExecutionsModels = allExecutions
972
952
.filter {
973
953
it.result is UtExecutionSuccess
@@ -979,22 +959,31 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
979
959
when (model) {
980
960
is UtCompositeModel -> {
981
961
for ((fieldId, fieldModel) in model.fields) {
982
- traverseExecutionsFieldsForParametrizedTestRecursively (fieldId, fieldModel, 0 )
962
+ collectExecutionsResultFieldsRecursively (fieldId, fieldModel, 0 )
983
963
}
984
964
}
985
965
986
966
is UtAssembleModel -> {
987
- for ((fieldId, fieldModel) in model.origin!! .fields) {
988
- traverseExecutionsFieldsForParametrizedTestRecursively(fieldId, fieldModel, 0 )
967
+ model.origin?.let {
968
+ for ((fieldId, fieldModel) in it.fields) {
969
+ collectExecutionsResultFieldsRecursively(fieldId, fieldModel, 0 )
970
+ }
989
971
}
990
972
}
991
973
992
- else -> {} // TODO: check this specific case
974
+ is UtNullModel ,
975
+ is UtPrimitiveModel ,
976
+ is UtArrayModel ,
977
+ is UtClassRefModel ,
978
+ is UtEnumConstantModel ,
979
+ is UtVoidModel -> {
980
+ // only [UtCompositeModel] and [UtAssembleModel] have fields to traverse
981
+ }
993
982
}
994
983
}
995
984
}
996
985
997
- private fun traverseExecutionsFieldsForParametrizedTestRecursively (
986
+ private fun collectExecutionsResultFieldsRecursively (
998
987
fieldId : FieldId ,
999
988
fieldModel : UtModel ,
1000
989
depth : Int ,
@@ -1003,43 +992,32 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
1003
992
return
1004
993
}
1005
994
995
+ val fieldKey = fieldId to depth
996
+ fieldsOfExecutionResults.getOrPut(fieldKey) { mutableListOf () } + = fieldModel
997
+
1006
998
when (fieldModel) {
1007
999
is UtCompositeModel -> {
1008
1000
for ((id, model) in fieldModel.fields) {
1009
- if (id.isInnerClassEnclosingClassReference) continue
1010
-
1011
- if (id.isStatic) {
1012
- return
1013
- }
1014
-
1015
- if (cachedFieldsFromAllExecutions[id to depth] != null ) {
1016
- cachedFieldsFromAllExecutions[id to depth]!! .add(model)
1017
- } else {
1018
- cachedFieldsFromAllExecutions[id to depth] = listOf (model).toMutableList()
1019
- }
1020
-
1021
- traverseExecutionsFieldsForParametrizedTestRecursively(
1022
- id,
1023
- model,
1024
- depth + 1 ,
1025
- )
1001
+ collectExecutionsResultFieldsRecursively(id, model, depth + 1 )
1026
1002
}
1027
1003
}
1028
1004
1029
1005
is UtAssembleModel -> {
1030
1006
fieldModel.origin?.let {
1031
- traverseExecutionsFieldsForParametrizedTestRecursively(fieldId, it, depth)
1007
+ for ((id, model) in it.fields) {
1008
+ collectExecutionsResultFieldsRecursively(id, model, depth + 1 )
1009
+ }
1032
1010
}
1033
- return
1034
1011
}
1035
1012
1036
- else -> {} // TODO: check this specific case
1037
- }
1038
-
1039
- if (cachedFieldsFromAllExecutions[fieldId to depth] != null ) {
1040
- cachedFieldsFromAllExecutions[fieldId to depth]!! .add(fieldModel)
1041
- } else {
1042
- cachedFieldsFromAllExecutions[fieldId to depth] = listOf (fieldModel).toMutableList()
1013
+ is UtNullModel ,
1014
+ is UtPrimitiveModel ,
1015
+ is UtArrayModel ,
1016
+ is UtClassRefModel ,
1017
+ is UtEnumConstantModel ,
1018
+ is UtVoidModel -> {
1019
+ // only [UtCompositeModel] and [UtAssembleModel] have fields to traverse
1020
+ }
1043
1021
}
1044
1022
}
1045
1023
@@ -1234,7 +1212,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
1234
1212
when (parametrizedTestSource) {
1235
1213
ParametrizedTestSource .DO_NOT_PARAMETRIZE -> generateDeepEqualsAssertion(expected, actual)
1236
1214
ParametrizedTestSource .PARAMETRIZE -> {
1237
- traverseExecutionsFieldsForParametrizedTest ()
1215
+ collectExecutionsResultFields ()
1238
1216
1239
1217
when {
1240
1218
actual.type.isPrimitive -> generateDeepEqualsAssertion(expected, actual)
0 commit comments