Skip to content

Commit eb3e176

Browse files
[red-knot] Add callable subtyping for callable instances and bound methods (#17105)
## Summary Trying to improve #17005 Partially fixes #16953 ## Test Plan Update is_subtype_of.md --------- Co-authored-by: Carl Meyer <[email protected]>
1 parent d38f6fc commit eb3e176

File tree

3 files changed

+79
-0
lines changed

3 files changed

+79
-0
lines changed

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,5 +1099,54 @@ static_assert(is_subtype_of(TypeOf[C.foo], object))
10991099
static_assert(not is_subtype_of(object, TypeOf[C.foo]))
11001100
```
11011101

1102+
### Classes with `__call__`
1103+
1104+
```py
1105+
from typing import Callable
1106+
from knot_extensions import TypeOf, is_subtype_of, static_assert, is_assignable_to
1107+
1108+
class A:
1109+
def __call__(self, a: int) -> int:
1110+
return a
1111+
1112+
a = A()
1113+
1114+
static_assert(is_subtype_of(A, Callable[[int], int]))
1115+
static_assert(not is_subtype_of(A, Callable[[], int]))
1116+
static_assert(not is_subtype_of(Callable[[int], int], A))
1117+
1118+
def f(fn: Callable[[int], int]) -> None: ...
1119+
1120+
f(a)
1121+
```
1122+
1123+
### Bound methods
1124+
1125+
```py
1126+
from typing import Callable
1127+
from knot_extensions import TypeOf, static_assert, is_subtype_of
1128+
1129+
class A:
1130+
def f(self, a: int) -> int:
1131+
return a
1132+
1133+
@classmethod
1134+
def g(cls, a: int) -> int:
1135+
return a
1136+
1137+
a = A()
1138+
1139+
static_assert(is_subtype_of(TypeOf[a.f], Callable[[int], int]))
1140+
static_assert(is_subtype_of(TypeOf[a.g], Callable[[int], int]))
1141+
static_assert(is_subtype_of(TypeOf[A.g], Callable[[int], int]))
1142+
1143+
static_assert(not is_subtype_of(TypeOf[a.f], Callable[[float], int]))
1144+
static_assert(not is_subtype_of(TypeOf[A.g], Callable[[], int]))
1145+
1146+
# TODO: This assertion should be true
1147+
# error: [static-assert-error] "Static assertion error: argument evaluates to `False`"
1148+
static_assert(is_subtype_of(TypeOf[A.f], Callable[[A, int], int]))
1149+
```
1150+
11021151
[special case for float and complex]: https://typing.readthedocs.io/en/latest/spec/special-types.html#special-cases-for-float-and-complex
11031152
[typing documentation]: https://typing.readthedocs.io/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence

crates/red_knot_python_semantic/src/types.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,10 @@ impl<'db> Type<'db> {
727727
.is_subtype_of(db, target)
728728
}
729729

730+
(Type::BoundMethod(self_bound_method), Type::Callable(_)) => self_bound_method
731+
.into_callable_type(db)
732+
.is_subtype_of(db, target),
733+
730734
// A `FunctionLiteral` type is a single-valued type like the other literals handled above,
731735
// so it also, for now, just delegates to its instance fallback.
732736
(Type::FunctionLiteral(_), _) => KnownClass::FunctionType
@@ -833,6 +837,16 @@ impl<'db> Type<'db> {
833837
self_instance.is_subtype_of(db, target_instance)
834838
}
835839

840+
(Type::Instance(_), Type::Callable(_)) => {
841+
let call_symbol = self.member(db, "__call__").symbol;
842+
match call_symbol {
843+
Symbol::Type(Type::BoundMethod(call_function), _) => call_function
844+
.into_callable_type(db)
845+
.is_subtype_of(db, target),
846+
_ => false,
847+
}
848+
}
849+
836850
// Other than the special cases enumerated above,
837851
// `Instance` types are never subtypes of any other variants
838852
(Type::Instance(_), _) => false,
@@ -4414,6 +4428,15 @@ pub struct BoundMethodType<'db> {
44144428
self_instance: Type<'db>,
44154429
}
44164430

4431+
impl<'db> BoundMethodType<'db> {
4432+
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
4433+
Type::Callable(CallableType::new(
4434+
db,
4435+
self.function(db).signature(db).bind_self(),
4436+
))
4437+
}
4438+
}
4439+
44174440
/// This type represents the set of all callable objects with a certain signature.
44184441
/// It can be written in type expressions using `typing.Callable`.
44194442
/// `lambda` expressions are inferred directly as `CallableType`s; all function-literal types

crates/red_knot_python_semantic/src/types/signatures.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,13 @@ impl<'db> Signature<'db> {
265265
pub(crate) fn parameters(&self) -> &Parameters<'db> {
266266
&self.parameters
267267
}
268+
269+
pub(crate) fn bind_self(&self) -> Self {
270+
Self {
271+
parameters: Parameters::new(self.parameters().iter().skip(1).cloned()),
272+
return_ty: self.return_ty,
273+
}
274+
}
268275
}
269276

270277
#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]

0 commit comments

Comments
 (0)