@@ -1209,7 +1209,9 @@ def __init__(self) -> None:
12091209 self .yield_from_names : Dict [str , Set [Offset ]]
12101210 self .yield_from_names = collections .defaultdict (set )
12111211
1212- def __init__ (self ) -> None :
1212+ def __init__ (self , no_mock : bool ) -> None :
1213+ self ._find_mock = not no_mock
1214+
12131215 self .bases_to_remove : Set [Offset ] = set ()
12141216
12151217 self .encode_calls : Dict [Offset , ast .Call ] = {}
@@ -1330,7 +1332,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
13301332 for name in node .names :
13311333 if not name .asname :
13321334 self ._from_imports [node .module ].add (name .name )
1333- elif node .module in self .MOCK_MODULES :
1335+ elif self . _find_mock and node .module in self .MOCK_MODULES :
13341336 self .mock_relative_imports .add (_ast_to_offset (node ))
13351337 elif node .module == 'sys' and any (
13361338 name .name == 'version_info' and not name .asname
@@ -1341,6 +1343,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
13411343
13421344 def visit_Import (self , node : ast .Import ) -> None :
13431345 if (
1346+ self ._find_mock and
13441347 len (node .names ) == 1 and
13451348 node .names [0 ].name in self .MOCK_MODULES
13461349 ):
@@ -1437,7 +1440,7 @@ def _visit_comp(self, node: ast.expr) -> None:
14371440 def visit_Attribute (self , node : ast .Attribute ) -> None :
14381441 if self ._is_six (node , SIX_SIMPLE_ATTRS ):
14391442 self .six_simple [_ast_to_offset (node )] = node
1440- elif self ._is_mock_mock (node ):
1443+ elif self ._find_mock and self . _is_mock_mock (node ):
14411444 self .mock_mock .add (_ast_to_offset (node ))
14421445 self .generic_visit (node )
14431446
@@ -1994,13 +1997,17 @@ def _replace_yield(tokens: List[Token], i: int) -> None:
19941997 tokens [i :block .end ] = [Token ('CODE' , f'yield from { container } \n ' )]
19951998
19961999
1997- def _fix_py3_plus (contents_text : str , min_version : MinVersion ) -> str :
2000+ def _fix_py3_plus (
2001+ contents_text : str ,
2002+ min_version : MinVersion ,
2003+ no_mock : bool = False ,
2004+ ) -> str :
19982005 try :
19992006 ast_obj = ast_parse (contents_text )
20002007 except SyntaxError :
20012008 return contents_text
20022009
2003- visitor = FindPy3Plus ()
2010+ visitor = FindPy3Plus (no_mock )
20042011 visitor .visit (ast_obj )
20052012
20062013 if not any ((
@@ -2637,7 +2644,9 @@ def _fix_file(filename: str, args: argparse.Namespace) -> int:
26372644 if not args .keep_percent_format :
26382645 contents_text = _fix_percent_format (contents_text )
26392646 if args .min_version >= (3 ,):
2640- contents_text = _fix_py3_plus (contents_text , args .min_version )
2647+ contents_text = _fix_py3_plus (
2648+ contents_text , args .min_version , args .no_mock ,
2649+ )
26412650 if args .min_version >= (3 , 6 ):
26422651 contents_text = _fix_py36_plus (contents_text )
26432652
@@ -2659,6 +2668,7 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
26592668 parser .add_argument ('filenames' , nargs = '*' )
26602669 parser .add_argument ('--exit-zero-even-if-changed' , action = 'store_true' )
26612670 parser .add_argument ('--keep-percent-format' , action = 'store_true' )
2671+ parser .add_argument ('--no-mock' , action = 'store_true' )
26622672 parser .add_argument (
26632673 '--py3-plus' , '--py3-only' ,
26642674 action = 'store_const' , dest = 'min_version' , default = (2 , 7 ), const = (3 ,),
0 commit comments