-
Notifications
You must be signed in to change notification settings - Fork 963
/
Copy pathoverride_utils.py
148 lines (112 loc) · 4.52 KB
/
override_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright (c) 2022 The Brave Authors. All rights reserved.
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/. */
import contextlib
import inspect
import types
from typing import Any
def override_function(scope, name=None, condition=True):
    """Replaces an existing function in the scope.

    Decorator factory. The decorated ``new_function`` is called as
    ``new_function(original_function, *args, **kwargs)`` whenever the
    replaced name is invoked.

    Args:
        scope: a dict (e.g. ``globals()``) or any attribute-bearing object
            (module, class, ...) that currently holds the function.
        name: name of the function inside ``scope``; defaults to the
            decorated function's ``__name__``.
        condition: when falsy, the original function is kept in place. It is
            still looked up and validated, and the decorator returns it.

    Returns:
        The decorator. It returns whatever now lives under the name: the
        wrapper, or the original when ``condition`` is falsy.

    Raises:
        NameError: if the name is missing from ``scope`` or is not callable.
    """
    def decorator(new_function):
        target = name or new_function.__name__
        uses_mapping = isinstance(scope, dict)
        original = (scope.get(target, None)
                    if uses_mapping else getattr(scope, target, None))

        if not callable(original):
            raise NameError(f'Failed to override function: '
                            f'{target} not found or not callable')

        if condition:
            def wrapped_function(*args, **kwargs):
                return new_function(original, *args, **kwargs)
        else:
            # Override disabled: re-store the original unchanged.
            wrapped_function = original

        if uses_mapping:
            scope[target] = wrapped_function
        else:
            setattr(scope, target, wrapped_function)
        return wrapped_function

    return decorator
def override_method(scope, name=None, condition=True):
    """Replaces an existing method in the class scope.

    Decorator factory. The decorated ``new_method`` is called as
    ``new_method(self, original_method, *args, **kwargs)``.

    If the attribute looked up on ``scope`` is already a bound method (for
    example a classmethod accessed on its class), the wrapper is bound to
    ``scope`` as well. In that case ``self`` is ``scope`` and
    ``original_method`` is the already-bound original.

    Args:
        scope: class (or other attribute-bearing object, but not a dict)
            that owns the method.
        name: attribute name to replace; defaults to ``new_method.__name__``.
        condition: when falsy, nothing is replaced. The decorator returns the
            current attribute value, or ``None`` if it is missing.

    Returns:
        The decorator. It returns the unbound wrapper, or the original
        attribute when ``condition`` is falsy.

    Raises:
        TypeError: if ``scope`` is a dict, or if the attribute is callable
            but is neither a plain function nor a bound method.
        NameError: if the attribute is missing or not callable.
    """
    def decorator(new_method):
        # Explicit raises instead of asserts, so the checks survive `python -O`.
        if isinstance(scope, dict):
            raise TypeError('override_method does not support dict scopes; '
                            'use override_function instead')
        method_name = name or new_method.__name__
        original_method: Any = getattr(scope, method_name, None)

        if not condition:
            return original_method

        # Same contract as the other override helpers in this module.
        if not callable(original_method):
            raise NameError(f'Failed to override method: '
                            f'{method_name} not found or not callable')

        def wrapped_method(self, *args, **kwargs):
            return new_method(self, original_method, *args, **kwargs)

        if inspect.ismethod(original_method):
            # The original was already bound (e.g. classmethod), so bind the
            # wrapper to scope to keep the same calling convention.
            setattr(scope, method_name,
                    types.MethodType(wrapped_method, scope))
        elif inspect.isfunction(original_method):
            setattr(scope, method_name, wrapped_method)
        else:
            raise TypeError(f'Failed to override method: {method_name} is '
                            f'neither a function nor a bound method')
        return wrapped_method

    return decorator
@contextlib.contextmanager
def override_scope_function(scope, new_function, name=None, condition=True):
    """Scoped function override helper. Can override a scope function or a class
    method.

    Inside the ``with`` block, ``scope.<name>`` is replaced by a wrapper that
    calls ``new_function(original_function, *args, **kwargs)``. If the
    original is a bound method, the wrapper is bound to ``scope`` and calls
    ``new_function(self, original_function, *args, **kwargs)`` with ``self``
    being ``scope``. On exit, ``scope`` is restored to its prior state.

    Args:
        scope: object whose attribute is overridden (module, class, instance).
        new_function: the replacement; it receives the original as an argument.
        name: attribute name; defaults to ``new_function.__name__``.
        condition: when falsy, the context manager does nothing.

    Raises:
        NameError: if the attribute is missing or not callable. This is raised
            on entry, before ``scope`` is modified.
    """
    if not condition:
        yield
        return

    function_name = name or new_function.__name__
    original_function = getattr(scope, function_name, None)
    # Validate before touching scope, so a failure leaves nothing to undo.
    if not callable(original_function):
        raise NameError(f'Failed to override scope function: '
                        f'{function_name} not found or not callable')

    # On a class, getattr() returns the *resolved* attribute: a staticmethod's
    # plain function, a classmethod's bound method, or an inherited member.
    # Writing that value back on exit would turn staticmethods into instance
    # methods, pin classmethods to this class, and shadow inherited members.
    # So remember the raw entry of the class's own __dict__, or its absence.
    if isinstance(scope, type):
        class_dict = vars(scope)
        had_own_attr = function_name in class_dict
        saved_attr = class_dict.get(function_name)
    else:
        had_own_attr = True
        saved_attr = original_function

    if inspect.ismethod(original_function):
        def wrapped_method(self, *args, **kwargs):
            return new_function(self, original_function, *args, **kwargs)

        replacement = types.MethodType(wrapped_method, scope)
    else:
        def wrapped_function(*args, **kwargs):
            return new_function(original_function, *args, **kwargs)

        replacement = wrapped_function

    setattr(scope, function_name, replacement)
    try:
        yield
    finally:
        if had_own_attr:
            setattr(scope, function_name, saved_attr)
        else:
            # The attribute came from a base class or metaclass; dropping our
            # override makes it visible again.
            delattr(scope, function_name)
@contextlib.contextmanager
def override_scope_variable(scope,
                            name,
                            value,
                            fail_if_not_found=True,
                            condition=True):
    """Scoped variable override helper.

    Inside the ``with`` block, ``scope[name]`` (dict scope) or
    ``scope.<name>`` (any other scope) is set to ``value``. On exit, the
    prior state is restored: the previous value is put back, or the entry is
    removed if it was not defined on ``scope`` before.

    Args:
        scope: dict or attribute-bearing object (module, class, instance).
        name: key or attribute name to override.
        value: temporary value to install.
        fail_if_not_found: raise if ``name`` does not already exist.
        condition: when falsy, the context manager does nothing.

    Raises:
        NameError: if ``fail_if_not_found`` is set and ``name`` is missing.
    """
    if not condition:
        yield
        return

    is_dict_scope = isinstance(scope, dict)

    def _set(new_value):
        if is_dict_scope:
            scope[name] = new_value
        else:
            setattr(scope, name, new_value)

    def _del():
        if is_dict_scope:
            del scope[name]
        else:
            delattr(scope, name)

    var_exist = name in scope if is_dict_scope else hasattr(scope, name)
    if fail_if_not_found and not var_exist:
        raise NameError(f'Failed to override scope variable: {name} not found')

    if isinstance(scope, type):
        # getattr() on a class resolves descriptors (staticmethod/classmethod)
        # and inheritance. Restore the raw entry of the class's own __dict__,
        # or delete ours if there was none, so nothing stays shadowed.
        class_dict = vars(scope)
        restore_needed = name in class_dict
        original_value = class_dict.get(name)
    elif var_exist:
        restore_needed = True
        original_value = scope[name] if is_dict_scope else getattr(scope, name)
    else:
        restore_needed = False
        original_value = None

    # Assign before entering try: if the assignment itself fails there is
    # nothing to undo, and cleanup must not mask the original error.
    _set(value)
    try:
        yield
    finally:
        if restore_needed:
            _set(original_value)
        else:
            _del()