Skip to content

Rework how we refine information from callable() #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 67 additions & 14 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2809,7 +2809,7 @@ def function_type(self, func: FuncBase) -> FunctionLike:
return function_type(func, self.named_type('builtins.function'))

def find_isinstance_check(self, n: Expression) -> 'Tuple[TypeMap, TypeMap]':
return find_isinstance_check(n, self.type_map)
return find_isinstance_check(n, self.type_map, self)

def push_type_map(self, type_map: 'TypeMap') -> None:
if type_map is None:
Expand Down Expand Up @@ -2883,7 +2883,45 @@ def conditional_type_map(expr: Expression,
return {}, {}


def partition_by_callable(type: Type) -> Tuple[List[Type], List[Type]]:
def intersect_instance_callable(type: Instance, callable_type: CallableType) -> Type:
"""Creates a fake type that represents the intersection of an
Instance and a CallableType.

It operates by creating a bare-minimum dummy TypeInfo that
subclasses type and adds a __call__ method matching callable_type."""

# Build the fake ClassDef and TypeInfo together.
# The ClassDef is full of lies and doesn't actually contain a body.
cdef = ClassDef("<callable subtype of CHEDDAR CHEESE IS GOOD {}>".format(type.type.name()), Block([]))
info = TypeInfo(SymbolTable(), cdef, '<dummy>')
cdef.info = info
info.bases = [type]
info.calculate_mro()

# Build up a fake FuncDef so we can populate the symbol table.
func_def = FuncDef('__call__', [], Block([]), callable_type)
func_def.info = info
info.names['__call__'] = SymbolTableNode(MDEF, func_def, callable_type)

return Instance(info, [])


def make_fake_callable(type: Instance, typechecker: TypeChecker) -> Type:
"""Produce a new type that makes type Callable with a generic callable type."""

fallback = typechecker.named_type('builtins.function')
callable_type = CallableType([AnyType(TypeOfAny.explicit),
AnyType(TypeOfAny.explicit)],
[nodes.ARG_STAR, nodes.ARG_STAR2],
[None, None],
ret_type=AnyType(TypeOfAny.explicit),
fallback=fallback,
is_ellipsis_args=True)

return intersect_instance_callable(type, callable_type)


def partition_by_callable(type: Type, typechecker: TypeChecker) -> Tuple[List[Type], List[Type]]:
"""Takes in a type and partitions that type into callable subtypes and
uncallable subtypes.

Expand All @@ -2904,29 +2942,43 @@ def partition_by_callable(type: Type) -> Tuple[List[Type], List[Type]]:
callables = []
uncallables = []
for subtype in type.relevant_items():
subcallables, subuncallables = partition_by_callable(subtype)
subcallables, subuncallables = partition_by_callable(subtype, typechecker)
callables.extend(subcallables)
uncallables.extend(subuncallables)
return callables, uncallables

if isinstance(type, TypeVarType):
return partition_by_callable(type.erase_to_union_or_bound())
# We could do better probably?
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So loooong

# Refine the the type variable's bound as our type in the case that
# callable() is true. This unfortuantely loses the information that
# the type is a type variable in that branch.
# This matches what is done for isinstance, but it may be possible to
# do better.
# If it is possible for the false branch to execute, return the original
# type to avoid losing type information.
callables, uncallables = partition_by_callable(type.erase_to_union_or_bound(), typechecker)
uncallables = [type] if len(uncallables) else []
return callables, uncallables

if isinstance(type, Instance):
method = type.type.get_method('__call__')
if method and method.type:
callables, uncallables = partition_by_callable(method.type)
callables, uncallables = partition_by_callable(method.type, typechecker)
if len(callables) and not len(uncallables):
# Only consider the type callable if its __call__ method is
# definitely callable.
return [type], []
return [], [type]

return [], [type]
ret = make_fake_callable(type, typechecker)
return [ret], [type]

# We don't know how properly make the type callable.
return [type], [type]


def conditional_callable_type_map(expr: Expression,
current_type: Optional[Type],
typechecker: TypeChecker,
) -> Tuple[TypeMap, TypeMap]:
"""Takes in an expression and the current type of the expression.

Expand All @@ -2940,7 +2992,7 @@ def conditional_callable_type_map(expr: Expression,
if isinstance(current_type, AnyType):
return {}, {}

callables, uncallables = partition_by_callable(current_type)
callables, uncallables = partition_by_callable(current_type, typechecker)

if len(callables) and len(uncallables):
callable_map = {expr: UnionType.make_union(callables)} if len(callables) else None
Expand Down Expand Up @@ -3069,6 +3121,7 @@ def convert_to_typetype(type_map: TypeMap) -> TypeMap:

def find_isinstance_check(node: Expression,
type_map: Dict[Expression, Type],
typechecker: TypeChecker,
) -> Tuple[TypeMap, TypeMap]:
"""Find any isinstance checks (within a chain of ands). Includes
implicit and explicit checks for None and calls to callable.
Expand Down Expand Up @@ -3122,7 +3175,7 @@ def find_isinstance_check(node: Expression,
expr = node.args[0]
if literal(expr) == LITERAL_TYPE:
vartype = type_map[expr]
return conditional_callable_type_map(expr, vartype)
return conditional_callable_type_map(expr, vartype, typechecker)
elif isinstance(node, ComparisonExpr) and experiments.STRICT_OPTIONAL:
# Check for `x is None` and `x is not None`.
is_not = node.operators == ['is not']
Expand Down Expand Up @@ -3182,23 +3235,23 @@ def find_isinstance_check(node: Expression,
else_map = {ref: else_type} if not isinstance(else_type, UninhabitedType) else None
return if_map, else_map
elif isinstance(node, OpExpr) and node.op == 'and':
left_if_vars, left_else_vars = find_isinstance_check(node.left, type_map)
right_if_vars, right_else_vars = find_isinstance_check(node.right, type_map)
left_if_vars, left_else_vars = find_isinstance_check(node.left, type_map, typechecker)
right_if_vars, right_else_vars = find_isinstance_check(node.right, type_map, typechecker)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Silly things 🍌


# (e1 and e2) is true if both e1 and e2 are true,
# and false if at least one of e1 and e2 is false.
return (and_conditional_maps(left_if_vars, right_if_vars),
or_conditional_maps(left_else_vars, right_else_vars))
elif isinstance(node, OpExpr) and node.op == 'or':
left_if_vars, left_else_vars = find_isinstance_check(node.left, type_map)
right_if_vars, right_else_vars = find_isinstance_check(node.right, type_map)
left_if_vars, left_else_vars = find_isinstance_check(node.left, type_map, typechecker)
right_if_vars, right_else_vars = find_isinstance_check(node.right, type_map, typechecker)

# (e1 or e2) is true if at least one of e1 or e2 is true,
# and false if both e1 and e2 are false.
return (or_conditional_maps(left_if_vars, right_if_vars),
and_conditional_maps(left_else_vars, right_else_vars))
elif isinstance(node, UnaryExpr) and node.op == 'not':
left, right = find_isinstance_check(node.expr, type_map)
left, right = find_isinstance_check(node.expr, type_map, typechecker)
return right, left

# Not a supported isinstance check
Expand Down
3 changes: 2 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2284,7 +2284,8 @@ def check_for_comp(self, e: Union[GeneratorExpr, DictionaryComprehension]) -> No
self.accept(condition)

# values are only part of the comprehension when all conditions are true
true_map, _ = mypy.checker.find_isinstance_check(condition, self.chk.type_map)
true_map, _ = mypy.checker.find_isinstance_check(condition, self.chk.type_map,
self.chk)

if true_map:
for var, type in true_map.items():
Expand Down
36 changes: 31 additions & 5 deletions test-data/unit/check-callable.test
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,13 @@ b = B() # type: B
c = A() # type: Union[A, B]

if callable(a):
5 + 'test'
5 + 'test' # E: Unsupported operand types for + ("int" and "str")

if not callable(b):
5 + 'test'

if callable(c):
reveal_type(c) # E: Revealed type is '__main__.B'
reveal_type(c) # E: Revealed type is 'Union[<callable subtype of A>, __main__.B]'
else:
reveal_type(c) # E: Revealed type is '__main__.A'

Expand All @@ -227,7 +227,7 @@ T = Union[Union[int, Callable[[], int]], Union[str, Callable[[], str]]]

def f(t: T) -> None:
if callable(t):
reveal_type(t()) # E: Revealed type is 'Union[builtins.int, builtins.str]'
reveal_type(t()) # E: Revealed type is 'Union[Any, builtins.int, builtins.str]'
else:
reveal_type(t) # E: Revealed type is 'Union[builtins.int, builtins.str]'

Expand All @@ -240,7 +240,7 @@ T = TypeVar('T')

def f(t: T) -> T:
if callable(t):
return 5
return 5 # E: Incompatible return value type (got "int", expected "T")
else:
return t

Expand All @@ -253,7 +253,7 @@ T = TypeVar('T', int, Callable[[], int], Union[str, Callable[[], str]])

def f(t: T) -> None:
if callable(t):
reveal_type(t()) # E: Revealed type is 'builtins.int' # E: Revealed type is 'builtins.str'
reveal_type(t()) # E: Revealed type is 'Any' # E: Revealed type is 'builtins.int' # E: Revealed type is 'Union[Any, builtins.str]'
else:
reveal_type(t) # E: Revealed type is 'builtins.int*' # E: Revealed type is 'builtins.str'

Expand Down Expand Up @@ -343,3 +343,29 @@ else:
'test' + 5

[builtins fixtures/callable.pyi]

[case testCallableObject]

def f(o: object) -> None:
if callable(o):
o()
1 + 'boom' # E: Unsupported operand types for + ("int" and "str")
o()

[builtins fixtures/callable.pyi]

[case testCallableObject2]

class Foo(object):
def bar(self) -> None:
pass

def g(o: Foo) -> None:
o.bar()
if callable(o):
o.bar()
o()
else:
o.bar()

[builtins fixtures/callable.pyi]