@@ -1209,7 +1209,9 @@ def __init__(self) -> None:
1209
1209
self .yield_from_names : Dict [str , Set [Offset ]]
1210
1210
self .yield_from_names = collections .defaultdict (set )
1211
1211
1212
- def __init__ (self ) -> None :
1212
+ def __init__ (self , keep_mock : bool ) -> None :
1213
+ self ._find_mock = not keep_mock
1214
+
1213
1215
self .bases_to_remove : Set [Offset ] = set ()
1214
1216
1215
1217
self .encode_calls : Dict [Offset , ast .Call ] = {}
@@ -1330,7 +1332,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
1330
1332
for name in node .names :
1331
1333
if not name .asname :
1332
1334
self ._from_imports [node .module ].add (name .name )
1333
- elif node .module in self .MOCK_MODULES :
1335
+ elif self . _find_mock and node .module in self .MOCK_MODULES :
1334
1336
self .mock_relative_imports .add (_ast_to_offset (node ))
1335
1337
elif node .module == 'sys' and any (
1336
1338
name .name == 'version_info' and not name .asname
@@ -1341,6 +1343,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
1341
1343
1342
1344
def visit_Import (self , node : ast .Import ) -> None :
1343
1345
if (
1346
+ self ._find_mock and
1344
1347
len (node .names ) == 1 and
1345
1348
node .names [0 ].name in self .MOCK_MODULES
1346
1349
):
@@ -1437,7 +1440,7 @@ def _visit_comp(self, node: ast.expr) -> None:
1437
1440
def visit_Attribute (self , node : ast .Attribute ) -> None :
1438
1441
if self ._is_six (node , SIX_SIMPLE_ATTRS ):
1439
1442
self .six_simple [_ast_to_offset (node )] = node
1440
- elif self ._is_mock_mock (node ):
1443
+ elif self ._find_mock and self . _is_mock_mock (node ):
1441
1444
self .mock_mock .add (_ast_to_offset (node ))
1442
1445
self .generic_visit (node )
1443
1446
@@ -1994,13 +1997,17 @@ def _replace_yield(tokens: List[Token], i: int) -> None:
1994
1997
tokens [i :block .end ] = [Token ('CODE' , f'yield from { container } \n ' )]
1995
1998
1996
1999
1997
- def _fix_py3_plus (contents_text : str , min_version : MinVersion ) -> str :
2000
+ def _fix_py3_plus (
2001
+ contents_text : str ,
2002
+ min_version : MinVersion ,
2003
+ keep_mock : bool = False ,
2004
+ ) -> str :
1998
2005
try :
1999
2006
ast_obj = ast_parse (contents_text )
2000
2007
except SyntaxError :
2001
2008
return contents_text
2002
2009
2003
- visitor = FindPy3Plus ()
2010
+ visitor = FindPy3Plus (keep_mock )
2004
2011
visitor .visit (ast_obj )
2005
2012
2006
2013
if not any ((
@@ -2637,7 +2644,9 @@ def _fix_file(filename: str, args: argparse.Namespace) -> int:
2637
2644
if not args .keep_percent_format :
2638
2645
contents_text = _fix_percent_format (contents_text )
2639
2646
if args .min_version >= (3 ,):
2640
- contents_text = _fix_py3_plus (contents_text , args .min_version )
2647
+ contents_text = _fix_py3_plus (
2648
+ contents_text , args .min_version , args .keep_mock ,
2649
+ )
2641
2650
if args .min_version >= (3 , 6 ):
2642
2651
contents_text = _fix_py36_plus (contents_text )
2643
2652
@@ -2659,6 +2668,7 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
2659
2668
parser .add_argument ('filenames' , nargs = '*' )
2660
2669
parser .add_argument ('--exit-zero-even-if-changed' , action = 'store_true' )
2661
2670
parser .add_argument ('--keep-percent-format' , action = 'store_true' )
2671
+ parser .add_argument ('--keep-mock' , action = 'store_true' )
2662
2672
parser .add_argument (
2663
2673
'--py3-plus' , '--py3-only' ,
2664
2674
action = 'store_const' , dest = 'min_version' , default = (2 , 7 ), const = (3 ,),
0 commit comments