1919
2020
2121import org .greencodeinitiative .creedengo .python .utils .UtilsAST ;
22- import org .sonar .check .Priority ;
2322import org .sonar .check .Rule ;
2423import org .sonar .plugins .python .api .PythonSubscriptionCheck ;
2524import org .sonar .plugins .python .api .SubscriptionContext ;
5554
5655public class AvoidConvBiasBeforeBatchNorm extends PythonSubscriptionCheck {
5756
58- private static final String nnModuleFullyQualifiedName = "torch.nn.Module" ;
59- private static final String convFullyQualifiedName = "torch.nn.Conv2d" ;
60- private static final String forwardMethodName = "forward" ;
61- private static final String batchNormFullyQualifiedName = "torch.nn.BatchNorm2d" ;
62- private static final String sequentialModuleFullyQualifiedName = "torch.nn.Sequential" ;
57+ private static final String NN_MODULE_FULLY_QUALIFIED_NAME = "torch.nn.Module" ;
58+ private static final String CONV_FULLY_QUALIFIED_NAME = "torch.nn.Conv2d" ;
59+ private static final String FORWARD_METHOD_NAME = "forward" ;
60+ private static final String BATCH_NORM_FULLY_QUALIFIED_NAME = "torch.nn.BatchNorm2d" ;
61+ private static final String SEQUENTIAL_MODULE_FULLY_QUALIFIED_NAME = "torch.nn.Sequential" ;
6362 protected static final String MESSAGE = "Remove bias for convolutions before batch norm layers to save time and memory." ;
6463
6564 @ Override
@@ -76,15 +75,15 @@ private boolean isConvWithBias(CallExpression convDefinition) {
7675 return true ;
7776 else {
7877 Expression expression = biasArgument .expression ();
79- return expression .is (NAME ) && (( Name ) expression ).name (). equals ( "True" );
78+ return expression .is (NAME ) && "True" . equals ((( Name ) expression ).name ());
8079 }
8180 }
8281
8382 private boolean isModelClass (ClassDef classDef ) {
8483 ClassSymbol classSymbol = (ClassSymbol ) classDef .name ().symbol ();
8584 if (classSymbol != null ) {
86- return classSymbol .superClasses ().stream ().anyMatch (e -> Objects .equals (e .fullyQualifiedName (), nnModuleFullyQualifiedName ))
87- && classSymbol .declaredMembers ().stream ().anyMatch (e -> e .name (). equals ( forwardMethodName ));
85+ return classSymbol .superClasses ().stream ().anyMatch (e -> Objects .equals (e .fullyQualifiedName (), NN_MODULE_FULLY_QUALIFIED_NAME ))
86+ && classSymbol .declaredMembers ().stream ().anyMatch (e -> FORWARD_METHOD_NAME . equals ( e .name ()));
8887 } else
8988 return false ;
9089 }
@@ -142,12 +141,12 @@ private void reportForSequentialModules(SubscriptionContext context, CallExpress
142141 Argument moduleInSequential = UtilsAST .getArgumentsFromCall (sequentialCall ).get (moduleIndex );
143142 if (moduleInSequential .is (REGULAR_ARGUMENT ) && ((RegularArgument ) moduleInSequential ).expression ().is (CALL_EXPR )) {
144143 CallExpression module = (CallExpression ) ((RegularArgument ) moduleInSequential ).expression ();
145- if (UtilsAST .getQualifiedName (module ). equals ( convFullyQualifiedName ) && isConvWithBias (module )) {
144+ if (CONV_FULLY_QUALIFIED_NAME . equals ( UtilsAST .getQualifiedName (module )) && isConvWithBias (module )) {
146145 if (moduleIndex == nModulesInSequential - 1 )
147146 break ;
148147 Argument nextModuleInSequential = UtilsAST .getArgumentsFromCall (sequentialCall ).get (moduleIndex + 1 );
149148 CallExpression nextModule = (CallExpression ) ((RegularArgument ) nextModuleInSequential ).expression ();
150- if (UtilsAST .getQualifiedName (nextModule ). equals ( batchNormFullyQualifiedName ))
149+ if (BATCH_NORM_FULLY_QUALIFIED_NAME . equals ( UtilsAST .getQualifiedName (nextModule )))
151150 context .addIssue (module , MESSAGE );
152151 }
153152 }
@@ -160,7 +159,7 @@ private void visitModelClass(SubscriptionContext context, ClassDef classDef) {
160159 Map <String , CallExpression > batchNormsInInit = new HashMap <>();
161160
162161 for (Statement s : classDef .body ().statements ()) {
163- if (s .is (FUNCDEF ) && (( FunctionDef ) s ).name ().name (). equals ( "__init__" )) {
162+ if (s .is (FUNCDEF ) && "__init__" . equals ((( FunctionDef ) s ).name ().name ())) {
164163 for (Statement ss : ((FunctionDef ) s ).body ().statements ()) {
165164 if (ss .is (ASSIGNMENT_STMT )) {
166165 Expression lhs = ((AssignmentStatement ) ss ).lhsExpressions ().get (0 ).expressions ().get (0 );
@@ -170,19 +169,19 @@ private void visitModelClass(SubscriptionContext context, ClassDef classDef) {
170169 CallExpression callExpression = (CallExpression ) ((AssignmentStatement ) ss ).assignedValue ();
171170 String variableName = ((QualifiedExpression ) lhs ).name ().name ();
172171 String variableClass = UtilsAST .getQualifiedName (callExpression );
173- if (variableClass .equals (sequentialModuleFullyQualifiedName )) {
172+ if (SEQUENTIAL_MODULE_FULLY_QUALIFIED_NAME .equals (variableClass )) {
174173 reportForSequentialModules (context , callExpression );
175- } else if (convFullyQualifiedName .equals (variableClass ) && isConvWithBias (callExpression )) {
174+ } else if (variableClass .equals (CONV_FULLY_QUALIFIED_NAME ) && isConvWithBias (callExpression )) {
176175 dirtyConvInInit .put (variableName , callExpression );
177- } else if (batchNormFullyQualifiedName .equals (variableClass )) {
176+ } else if (BATCH_NORM_FULLY_QUALIFIED_NAME .equals (variableClass )) {
178177 batchNormsInInit .put (variableName , callExpression );
179178 }
180179 }
181180 }
182181 }
183182 }
184183 for (Statement s : classDef .body ().statements ()) {
185- if (s .is (FUNCDEF ) && (( FunctionDef ) s ).name ().name (). equals ( forwardMethodName )) {
184+ if (s .is (FUNCDEF ) && FORWARD_METHOD_NAME . equals ((( FunctionDef ) s ).name ().name ())) {
186185 FunctionDef forwardDef = (FunctionDef ) s ;
187186 reportIfBatchNormIsCalledAfterDirtyConv (context , forwardDef , dirtyConvInInit , batchNormsInInit );
188187 }
0 commit comments