@@ -167,7 +167,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
167
167
168
168
private lateinit var methodType: CgTestMethodType
169
169
170
- private val cachedFieldsFromAllExecutions = mutableMapOf<Pair <FieldId , Int >, MutableList <UtModel >>()
170
+ private val fieldsOfExecutionResults = mutableMapOf<Pair <FieldId , Int >, MutableList <UtModel >>()
171
171
172
172
private fun setupInstrumentation () {
173
173
if (currentExecution is UtSymbolicExecution ) {
@@ -831,27 +831,24 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
831
831
depth : Int ,
832
832
visitedModels : MutableSet <UtModel >
833
833
) {
834
+ // if field is static, it is represents itself in "before" and
835
+ // "after" state: no need to assert its equality to itself.
836
+ if (fieldId.isStatic) {
837
+ return
838
+ }
839
+
840
+ // if model is already processed, so we don't want to add new statements
841
+ if (fieldModel in visitedModels) {
842
+ return
843
+ }
844
+
834
845
when (parametrizedTestSource) {
835
846
ParametrizedTestSource .DO_NOT_PARAMETRIZE -> {
836
- traverseField(
837
- fieldId,
838
- fieldModel,
839
- expected,
840
- actual,
841
- depth,
842
- visitedModels
843
- )
847
+ traverseField(fieldId, fieldModel, expected, actual, depth, visitedModels)
844
848
}
845
849
846
850
ParametrizedTestSource .PARAMETRIZE -> {
847
- traverseFieldForParametrizedTest(
848
- fieldId,
849
- fieldModel,
850
- expected,
851
- actual,
852
- depth,
853
- visitedModels
854
- )
851
+ traverseFieldForParametrizedTest(fieldId, fieldModel, expected, actual, depth, visitedModels)
855
852
}
856
853
}
857
854
}
@@ -864,17 +861,6 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
864
861
depth : Int ,
865
862
visitedModels : MutableSet <UtModel >
866
863
) {
867
- // if field is static, it is represents itself in "before" and
868
- // "after" state: no need to assert its equality to itself.
869
- if (fieldId.isStatic) {
870
- return
871
- }
872
-
873
- // if model is already processed, so we don't want to add new statements
874
- if (fieldModel in visitedModels) {
875
- return
876
- }
877
-
878
864
// fieldModel is not visited and will be marked in assertDeepEquals call
879
865
val fieldName = fieldId.name
880
866
var expectedVariable: CgVariable ? = null
@@ -907,31 +893,23 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
907
893
depth : Int ,
908
894
visitedModels : MutableSet <UtModel >
909
895
) {
910
- // if field is static, it is represents itself in "before" and
911
- // "after" state: no need to assert its equality to itself.
912
- if (fieldId.isStatic) {
913
- return
914
- }
896
+ val fieldResultModels = fieldsOfExecutionResults[fieldId to depth]
897
+ val nullResultModelInExecutions = fieldResultModels?.find { it.isNull() }
898
+ val notNullResultModelInExecutions = fieldResultModels?.find { it.isNotNull() }
915
899
916
- val hasNullModelFieldInOtherExecutions = cachedFieldsFromAllExecutions[fieldId to depth]?.any { it.isNull() } ? : false
917
- val notNullModelFieldFromOtherExecution = cachedFieldsFromAllExecutions[fieldId to depth]?.find { it.isNotNull() }
900
+ val hasNullResultModel = nullResultModelInExecutions != null
901
+ val hasNotNullResultModel = notNullResultModelInExecutions != null
918
902
919
- val needToSubstituteFieldModel = notNullModelFieldFromOtherExecution != null && fieldModel is UtNullModel
903
+ val needToSubstituteFieldModel = fieldModel is UtNullModel && hasNotNullResultModel
920
904
921
- val fieldModel = if (needToSubstituteFieldModel) notNullModelFieldFromOtherExecution!! else fieldModel
922
-
923
- val needIfStatement = needToSubstituteFieldModel || hasNullModelFieldInOtherExecutions
924
-
925
- // if model is already processed, so we don't want to add new statements
926
- if (fieldModel in visitedModels) {
927
- return
928
- }
905
+ val fieldModelForAssert = if (needToSubstituteFieldModel) notNullResultModelInExecutions!! else fieldModel
929
906
930
907
// fieldModel is not visited and will be marked in assertDeepEquals call
931
908
val fieldName = fieldId.name
932
909
var expectedVariable: CgVariable ? = null
933
910
934
- if (needExpectedDeclaration(fieldModel)) {
911
+ val needExpectedDeclaration = needExpectedDeclaration(fieldModelForAssert)
912
+ if (needExpectedDeclaration) {
935
913
val expectedFieldDeclaration = createDeclarationForFieldFromVariable(fieldId, expected, fieldName)
936
914
937
915
currentBlock + = expectedFieldDeclaration
@@ -941,13 +919,13 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
941
919
val actualFieldDeclaration = createDeclarationForFieldFromVariable(fieldId, actual, fieldName)
942
920
currentBlock + = actualFieldDeclaration
943
921
944
- if (needExpectedDeclaration(fieldModel) && needIfStatement ) {
922
+ if (needExpectedDeclaration && hasNullResultModel ) {
945
923
ifStatement(
946
924
CgEqualTo (expectedVariable!! , nullLiteral()),
947
925
trueBranch = { + testFrameworkManager.assertions[testFramework.assertNull](actualFieldDeclaration.variable).toStatement() },
948
926
falseBranch = {
949
927
assertDeepEquals(
950
- fieldModel ,
928
+ fieldModelForAssert ,
951
929
expectedVariable,
952
930
actualFieldDeclaration.variable,
953
931
depth + 1 ,
@@ -957,7 +935,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
957
935
)
958
936
} else {
959
937
assertDeepEquals(
960
- fieldModel ,
938
+ fieldModelForAssert ,
961
939
expectedVariable,
962
940
actualFieldDeclaration.variable,
963
941
depth + 1 ,
@@ -967,7 +945,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
967
945
emptyLineIfNeeded()
968
946
}
969
947
970
- private fun traverseExecutionsFieldsForParametrizedTest () {
948
+ private fun collectExecutionsResultFields () {
971
949
val successfulExecutionsModels = allExecutions
972
950
.filter {
973
951
it.result is UtExecutionSuccess
@@ -979,22 +957,31 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
979
957
when (model) {
980
958
is UtCompositeModel -> {
981
959
for ((fieldId, fieldModel) in model.fields) {
982
- traverseExecutionsFieldsForParametrizedTestRecursively (fieldId, fieldModel, 0 )
960
+ collectExecutionsResultFieldsRecursively (fieldId, fieldModel, 0 )
983
961
}
984
962
}
985
963
986
964
is UtAssembleModel -> {
987
- for ((fieldId, fieldModel) in model.origin!! .fields) {
988
- traverseExecutionsFieldsForParametrizedTestRecursively(fieldId, fieldModel, 0 )
965
+ model.origin?.let {
966
+ for ((fieldId, fieldModel) in it.fields) {
967
+ collectExecutionsResultFieldsRecursively(fieldId, fieldModel, 0 )
968
+ }
989
969
}
990
970
}
991
971
992
- else -> {} // TODO: check this specific case
972
+ is UtNullModel ,
973
+ is UtPrimitiveModel ,
974
+ is UtArrayModel ,
975
+ is UtClassRefModel ,
976
+ is UtEnumConstantModel ,
977
+ is UtVoidModel -> {
978
+ // only [UtCompositeModel] and [UtAssembleModel] have fields to traverse
979
+ }
993
980
}
994
981
}
995
982
}
996
983
997
- private fun traverseExecutionsFieldsForParametrizedTestRecursively (
984
+ private fun collectExecutionsResultFieldsRecursively (
998
985
fieldId : FieldId ,
999
986
fieldModel : UtModel ,
1000
987
depth : Int ,
@@ -1003,43 +990,32 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
1003
990
return
1004
991
}
1005
992
993
+ val fieldKey = fieldId to depth
994
+ fieldsOfExecutionResults.getOrPut(fieldKey) { mutableListOf () } + = fieldModel
995
+
1006
996
when (fieldModel) {
1007
997
is UtCompositeModel -> {
1008
998
for ((id, model) in fieldModel.fields) {
1009
- if (id.isInnerClassEnclosingClassReference) continue
1010
-
1011
- if (id.isStatic) {
1012
- return
1013
- }
1014
-
1015
- if (cachedFieldsFromAllExecutions[id to depth] != null ) {
1016
- cachedFieldsFromAllExecutions[id to depth]!! .add(model)
1017
- } else {
1018
- cachedFieldsFromAllExecutions[id to depth] = listOf (model).toMutableList()
1019
- }
1020
-
1021
- traverseExecutionsFieldsForParametrizedTestRecursively(
1022
- id,
1023
- model,
1024
- depth + 1 ,
1025
- )
999
+ collectExecutionsResultFieldsRecursively(id, model, depth + 1 )
1026
1000
}
1027
1001
}
1028
1002
1029
1003
is UtAssembleModel -> {
1030
1004
fieldModel.origin?.let {
1031
- traverseExecutionsFieldsForParametrizedTestRecursively(fieldId, it, depth)
1005
+ for ((id, model) in it.fields) {
1006
+ collectExecutionsResultFieldsRecursively(id, model, depth + 1 )
1007
+ }
1032
1008
}
1033
- return
1034
1009
}
1035
1010
1036
- else -> {} // TODO: check this specific case
1037
- }
1038
-
1039
- if (cachedFieldsFromAllExecutions[fieldId to depth] != null ) {
1040
- cachedFieldsFromAllExecutions[fieldId to depth]!! .add(fieldModel)
1041
- } else {
1042
- cachedFieldsFromAllExecutions[fieldId to depth] = listOf (fieldModel).toMutableList()
1011
+ is UtNullModel ,
1012
+ is UtPrimitiveModel ,
1013
+ is UtArrayModel ,
1014
+ is UtClassRefModel ,
1015
+ is UtEnumConstantModel ,
1016
+ is UtVoidModel -> {
1017
+ // only [UtCompositeModel] and [UtAssembleModel] have fields to traverse
1018
+ }
1043
1019
}
1044
1020
}
1045
1021
@@ -1234,7 +1210,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
1234
1210
when (parametrizedTestSource) {
1235
1211
ParametrizedTestSource .DO_NOT_PARAMETRIZE -> generateDeepEqualsAssertion(expected, actual)
1236
1212
ParametrizedTestSource .PARAMETRIZE -> {
1237
- traverseExecutionsFieldsForParametrizedTest ()
1213
+ collectExecutionsResultFields ()
1238
1214
1239
1215
when {
1240
1216
actual.type.isPrimitive -> generateDeepEqualsAssertion(expected, actual)
0 commit comments