Skip to content

[ty] Recursive protocols #17929

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 103 additions & 8 deletions crates/ty_python_semantic/resources/mdtest/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -1569,11 +1569,11 @@ from typing import Protocol, Any
from ty_extensions import is_fully_static, static_assert, is_assignable_to, is_subtype_of, is_equivalent_to

class RecursiveFullyStatic(Protocol):
parent: RecursiveFullyStatic | None
parent: RecursiveFullyStatic
x: int

class RecursiveNonFullyStatic(Protocol):
parent: RecursiveNonFullyStatic | None
parent: RecursiveNonFullyStatic
x: Any

static_assert(is_fully_static(RecursiveFullyStatic))
Expand All @@ -1582,16 +1582,111 @@ static_assert(not is_fully_static(RecursiveNonFullyStatic))
static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic))
static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic))

# TODO: currently leads to a stack overflow
# static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
# static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic))
static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveNonFullyStatic))
static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic))

class AlsoRecursiveFullyStatic(Protocol):
parent: AlsoRecursiveFullyStatic | None
parent: AlsoRecursiveFullyStatic
x: int

# TODO: currently leads to a stack overflow
# static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic))
static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic))

class RecursiveOptionalParent(Protocol):
parent: RecursiveOptionalParent | None

static_assert(is_fully_static(RecursiveOptionalParent))

static_assert(is_assignable_to(RecursiveOptionalParent, RecursiveOptionalParent))

static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveOptionalParent))
static_assert(not is_assignable_to(RecursiveOptionalParent, RecursiveNonFullyStatic))

class Other(Protocol):
z: str

def _(rec: RecursiveFullyStatic, other: Other):
reveal_type(rec.parent.parent.parent) # revealed: RecursiveFullyStatic

rec.parent.parent.parent = rec
rec = rec.parent.parent.parent

rec.parent.parent.parent = other # error: [invalid-assignment]
other = rec.parent.parent.parent # error: [invalid-assignment]

class Foo(Protocol):
@property
def x(self) -> "Foo": ...

class Bar(Protocol):
@property
def x(self) -> "Bar": ...

# TODO: this should pass
# error: [static-assert-error]
static_assert(is_equivalent_to(Foo, Bar))
```

### Nested occurrences of self-reference

Make sure that we handle self-reference correctly, even if the self-reference appears deeply nested
within the type of a protocol member:

```toml
[environment]
python-version = "3.12"
```

```py
from __future__ import annotations

from typing import Protocol, Callable
from ty_extensions import Intersection, Not, is_fully_static, is_assignable_to, is_equivalent_to, static_assert

class C: ...

class GenericC[T](Protocol):
pass

class Recursive(Protocol):
direct: Recursive

union: None | Recursive

intersection1: Intersection[C, Recursive]
intersection2: Intersection[C, Not[Recursive]]

t: tuple[int, tuple[str, Recursive]]

callable1: Callable[[int], Recursive]
callable2: Callable[[Recursive], int]

subtype_of: type[Recursive]

generic: GenericC[Recursive]

def method(self, x: Recursive) -> Recursive: ...

nested: Recursive | Callable[[Recursive | Recursive, tuple[Recursive, Recursive]], Recursive | Recursive]

static_assert(is_fully_static(Recursive))
static_assert(is_equivalent_to(Recursive, Recursive))
static_assert(is_assignable_to(Recursive, Recursive))

def _(r: Recursive):
reveal_type(r.direct) # revealed: Recursive
reveal_type(r.union) # revealed: None | Recursive
reveal_type(r.intersection1) # revealed: C & Recursive
reveal_type(r.intersection2) # revealed: C & ~Recursive
reveal_type(r.t) # revealed: tuple[int, tuple[str, Recursive]]
reveal_type(r.callable1) # revealed: (int, /) -> Recursive
reveal_type(r.callable2) # revealed: (Recursive, /) -> int
reveal_type(r.subtype_of) # revealed: type[Recursive]
reveal_type(r.generic) # revealed: GenericC[Recursive]
reveal_type(r.method(r)) # revealed: Recursive
reveal_type(r.nested) # revealed: Recursive | ((Recursive, tuple[Recursive, Recursive], /) -> Recursive)

reveal_type(r.method(r).callable1(1).direct.t[1][1]) # revealed: Recursive
```

### Regression test: narrowing with self-referential protocols
Expand Down
84 changes: 84 additions & 0 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,79 @@ impl<'db> Type<'db> {
matches!(self, Type::Dynamic(DynamicType::Todo(_)))
}

/// Replace references to the class `class` with a self-reference marker. This is currently
/// used for recursive protocols, but could probably be extended to self-referential type-
/// aliases and similar.
#[must_use]
pub fn replace_self_reference(&self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Type<'db> {
match self {
Self::ProtocolInstance(protocol) => {
Self::ProtocolInstance(protocol.replace_self_reference(db, class))
}

Self::Union(union) => UnionType::from_elements(
db,
union
.elements(db)
.iter()
.map(|ty| ty.replace_self_reference(db, class)),
),

Self::Intersection(intersection) => IntersectionBuilder::new(db)
.positive_elements(
intersection
.positive(db)
.iter()
.map(|ty| ty.replace_self_reference(db, class)),
)
.negative_elements(
intersection
.negative(db)
.iter()
.map(|ty| ty.replace_self_reference(db, class)),
)
.build(),

Self::Tuple(tuple) => TupleType::from_elements(
db,
tuple
.elements(db)
.iter()
.map(|ty| ty.replace_self_reference(db, class)),
),

Self::Callable(callable) => Self::Callable(callable.replace_self_reference(db, class)),

Self::GenericAlias(_) | Self::TypeVar(_) => {
// TODO: replace self-references in generic aliases and typevars
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was just about to ask if we could re-use this technique for generics 😄

👍 for it being a TODO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The doc-comment of this function is probably more relevant? This is just the place where we would also have to look for references to the protocol class.

Instead of ClassLiteral, I would imagine that this function would maybe have to take Type as an argument if we want to handle self-referential type aliases like type Tree = dict[str, Tree]. Or maybe an enum (different things that we would want to replace).

*self
}

Self::Dynamic(_)
| Self::AlwaysFalsy
| Self::AlwaysTruthy
| Self::Never
| Self::BooleanLiteral(_)
| Self::BytesLiteral(_)
| Self::StringLiteral(_)
| Self::IntLiteral(_)
| Self::LiteralString
| Self::FunctionLiteral(_)
| Self::ModuleLiteral(_)
| Self::ClassLiteral(_)
| Self::NominalInstance(_)
| Self::KnownInstance(_)
| Self::PropertyInstance(_)
| Self::BoundMethod(_)
| Self::WrapperDescriptor(_)
| Self::MethodWrapper(_)
| Self::DataclassDecorator(_)
| Self::DataclassTransformer(_)
| Self::SubclassOf(_)
| Self::BoundSuper(_) => *self,
}
}

pub fn contains_todo(&self, db: &'db dyn Db) -> bool {
match self {
Self::Dynamic(DynamicType::Todo(_) | DynamicType::SubscriptedProtocol) => true,
Expand Down Expand Up @@ -7272,6 +7345,17 @@ impl<'db> CallableType<'db> {
}
}
}

/// See [`Type::replace_self_reference`].
fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self {
CallableType::from_overloads(
db,
self.signatures(db)
.iter()
.cloned()
.map(|signature| signature.replace_self_reference(db, class)),
)
}
}

/// Represents a specific instance of `types.MethodWrapperType`
Expand Down
22 changes: 22 additions & 0 deletions crates/ty_python_semantic/src/types/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,28 @@ impl<'db> IntersectionBuilder<'db> {
}
}

pub(crate) fn positive_elements<I, T>(mut self, elements: I) -> Self
where
I: IntoIterator<Item = T>,
T: Into<Type<'db>>,
{
for element in elements {
self = self.add_positive(element.into());
}
self
}

pub(crate) fn negative_elements<I, T>(mut self, elements: I) -> Self
where
I: IntoIterator<Item = T>,
T: Into<Type<'db>>,
{
for element in elements {
self = self.add_negative(element.into());
}
self
}

pub(crate) fn build(mut self) -> Type<'db> {
// Avoid allocating the UnionBuilder unnecessarily if we have just one intersection:
if self.intersections.len() == 1 {
Expand Down
14 changes: 14 additions & 0 deletions crates/ty_python_semantic/src/types/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::protocol_class::ProtocolInterface;
use super::{ClassType, KnownClass, SubclassOfType, Type};
use crate::symbol::{Symbol, SymbolAndQualifiers};
use crate::types::generics::TypeMapping;
use crate::types::ClassLiteral;
use crate::Db;

pub(super) use synthesized_protocol::SynthesizedProtocolType;
Expand Down Expand Up @@ -183,6 +184,19 @@ impl<'db> ProtocolInstanceType<'db> {
}
}

/// Replace references to `class` with a self-reference marker
pub(super) fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self {
match self.0 {
Protocol::FromClass(class_type) if class_type.class_literal(db).0 == class => {
ProtocolInstanceType(Protocol::Synthesized(SynthesizedProtocolType::new(
db,
ProtocolInterface::SelfReference,
)))
}
_ => self,
}
}

/// Return `true` if any of the members of this protocol type contain any `Todo` types.
pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
self.0.interface(db).contains_todo(db)
Expand Down
Loading
Loading