Skip to content

Commit 2dfc5d7

Browse files
authored
Merge pull request #316 from mxr/no-mock
Add --keep-mock option
2 parents 400b8bb + a4eb2bb commit 2dfc5d7

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ Availability:
243243

244244
Availability:
245245
- `--py3-plus` is passed on the commandline.
246+
- [Unless `--keep-mock` is passed on the commandline](https://github.com/asottile/pyupgrade/issues/314).
246247

247248
```diff
248249
-from mock import patch

pyupgrade.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,9 @@ def __init__(self) -> None:
12091209
self.yield_from_names: Dict[str, Set[Offset]]
12101210
self.yield_from_names = collections.defaultdict(set)
12111211

1212-
def __init__(self) -> None:
1212+
def __init__(self, keep_mock: bool) -> None:
1213+
self._find_mock = not keep_mock
1214+
12131215
self.bases_to_remove: Set[Offset] = set()
12141216

12151217
self.encode_calls: Dict[Offset, ast.Call] = {}
@@ -1330,7 +1332,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
13301332
for name in node.names:
13311333
if not name.asname:
13321334
self._from_imports[node.module].add(name.name)
1333-
elif node.module in self.MOCK_MODULES:
1335+
elif self._find_mock and node.module in self.MOCK_MODULES:
13341336
self.mock_relative_imports.add(_ast_to_offset(node))
13351337
elif node.module == 'sys' and any(
13361338
name.name == 'version_info' and not name.asname
@@ -1341,6 +1343,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
13411343

13421344
def visit_Import(self, node: ast.Import) -> None:
13431345
if (
1346+
self._find_mock and
13441347
len(node.names) == 1 and
13451348
node.names[0].name in self.MOCK_MODULES
13461349
):
@@ -1437,7 +1440,7 @@ def _visit_comp(self, node: ast.expr) -> None:
14371440
def visit_Attribute(self, node: ast.Attribute) -> None:
14381441
if self._is_six(node, SIX_SIMPLE_ATTRS):
14391442
self.six_simple[_ast_to_offset(node)] = node
1440-
elif self._is_mock_mock(node):
1443+
elif self._find_mock and self._is_mock_mock(node):
14411444
self.mock_mock.add(_ast_to_offset(node))
14421445
self.generic_visit(node)
14431446

@@ -1994,13 +1997,17 @@ def _replace_yield(tokens: List[Token], i: int) -> None:
19941997
tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')]
19951998

19961999

1997-
def _fix_py3_plus(contents_text: str, min_version: MinVersion) -> str:
2000+
def _fix_py3_plus(
2001+
contents_text: str,
2002+
min_version: MinVersion,
2003+
keep_mock: bool = False,
2004+
) -> str:
19982005
try:
19992006
ast_obj = ast_parse(contents_text)
20002007
except SyntaxError:
20012008
return contents_text
20022009

2003-
visitor = FindPy3Plus()
2010+
visitor = FindPy3Plus(keep_mock)
20042011
visitor.visit(ast_obj)
20052012

20062013
if not any((
@@ -2637,7 +2644,9 @@ def _fix_file(filename: str, args: argparse.Namespace) -> int:
26372644
if not args.keep_percent_format:
26382645
contents_text = _fix_percent_format(contents_text)
26392646
if args.min_version >= (3,):
2640-
contents_text = _fix_py3_plus(contents_text, args.min_version)
2647+
contents_text = _fix_py3_plus(
2648+
contents_text, args.min_version, args.keep_mock,
2649+
)
26412650
if args.min_version >= (3, 6):
26422651
contents_text = _fix_py36_plus(contents_text)
26432652

@@ -2659,6 +2668,7 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
26592668
parser.add_argument('filenames', nargs='*')
26602669
parser.add_argument('--exit-zero-even-if-changed', action='store_true')
26612670
parser.add_argument('--keep-percent-format', action='store_true')
2671+
parser.add_argument('--keep-mock', action='store_true')
26622672
parser.add_argument(
26632673
'--py3-plus', '--py3-only',
26642674
action='store_const', dest='min_version', default=(2, 7), const=(3,),

tests/mock_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@ def test_mock_noop(s):
2020
assert _fix_py3_plus(s, (3,)) == s
2121

2222

23+
def test_mock_noop_keep_mock():
24+
"""This would've been rewritten if keep_mock were False"""
25+
s = (
26+
'from mock import patch\n'
27+
'\n'
28+
'patch("func")'
29+
)
30+
assert _fix_py3_plus(s, (3,), keep_mock=True) == s
31+
32+
2333
@pytest.mark.parametrize(
2434
('s', 'expected'),
2535
(

0 commit comments

Comments
 (0)