Skip to content

Commit 642eac4

Browse files
authored
[ty] Recursive protocols (#17929)
## Summary Use a self-reference "marker" ~~and fixpoint iteration~~ to solve the stack overflow problems with recursive protocols. This is not pretty and somewhat tedious, but seems to work fine. Much better than all my fixpoint-iteration attempts anyway. closes astral-sh/ty#93 ## Test Plan New Markdown tests.
1 parent c1b8757 commit 642eac4

File tree

6 files changed

+315
-64
lines changed

6 files changed

+315
-64
lines changed

crates/ty_python_semantic/resources/mdtest/protocols.md

+103-8
Original file line numberDiff line numberDiff line change
@@ -1569,11 +1569,11 @@ from typing import Protocol, Any
15691569
from ty_extensions import is_fully_static, static_assert, is_assignable_to, is_subtype_of, is_equivalent_to
15701570

15711571
class RecursiveFullyStatic(Protocol):
1572-
parent: RecursiveFullyStatic | None
1572+
parent: RecursiveFullyStatic
15731573
x: int
15741574

15751575
class RecursiveNonFullyStatic(Protocol):
1576-
parent: RecursiveNonFullyStatic | None
1576+
parent: RecursiveNonFullyStatic
15771577
x: Any
15781578

15791579
static_assert(is_fully_static(RecursiveFullyStatic))
@@ -1582,16 +1582,111 @@ static_assert(not is_fully_static(RecursiveNonFullyStatic))
15821582
static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic))
15831583
static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic))
15841584

1585-
# TODO: currently leads to a stack overflow
1586-
# static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
1587-
# static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic))
1585+
static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveNonFullyStatic))
1586+
static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
1587+
static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic))
15881588

15891589
class AlsoRecursiveFullyStatic(Protocol):
1590-
parent: AlsoRecursiveFullyStatic | None
1590+
parent: AlsoRecursiveFullyStatic
15911591
x: int
15921592

1593-
# TODO: currently leads to a stack overflow
1594-
# static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic))
1593+
static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic))
1594+
1595+
class RecursiveOptionalParent(Protocol):
1596+
parent: RecursiveOptionalParent | None
1597+
1598+
static_assert(is_fully_static(RecursiveOptionalParent))
1599+
1600+
static_assert(is_assignable_to(RecursiveOptionalParent, RecursiveOptionalParent))
1601+
1602+
static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveOptionalParent))
1603+
static_assert(not is_assignable_to(RecursiveOptionalParent, RecursiveNonFullyStatic))
1604+
1605+
class Other(Protocol):
1606+
z: str
1607+
1608+
def _(rec: RecursiveFullyStatic, other: Other):
1609+
reveal_type(rec.parent.parent.parent) # revealed: RecursiveFullyStatic
1610+
1611+
rec.parent.parent.parent = rec
1612+
rec = rec.parent.parent.parent
1613+
1614+
rec.parent.parent.parent = other # error: [invalid-assignment]
1615+
other = rec.parent.parent.parent # error: [invalid-assignment]
1616+
1617+
class Foo(Protocol):
1618+
@property
1619+
def x(self) -> "Foo": ...
1620+
1621+
class Bar(Protocol):
1622+
@property
1623+
def x(self) -> "Bar": ...
1624+
1625+
# TODO: this should pass
1626+
# error: [static-assert-error]
1627+
static_assert(is_equivalent_to(Foo, Bar))
1628+
```
1629+
1630+
### Nested occurrences of self-reference
1631+
1632+
Make sure that we handle self-reference correctly, even if the self-reference appears deeply nested
1633+
within the type of a protocol member:
1634+
1635+
```toml
1636+
[environment]
1637+
python-version = "3.12"
1638+
```
1639+
1640+
```py
1641+
from __future__ import annotations
1642+
1643+
from typing import Protocol, Callable
1644+
from ty_extensions import Intersection, Not, is_fully_static, is_assignable_to, is_equivalent_to, static_assert
1645+
1646+
class C: ...
1647+
1648+
class GenericC[T](Protocol):
1649+
pass
1650+
1651+
class Recursive(Protocol):
1652+
direct: Recursive
1653+
1654+
union: None | Recursive
1655+
1656+
intersection1: Intersection[C, Recursive]
1657+
intersection2: Intersection[C, Not[Recursive]]
1658+
1659+
t: tuple[int, tuple[str, Recursive]]
1660+
1661+
callable1: Callable[[int], Recursive]
1662+
callable2: Callable[[Recursive], int]
1663+
1664+
subtype_of: type[Recursive]
1665+
1666+
generic: GenericC[Recursive]
1667+
1668+
def method(self, x: Recursive) -> Recursive: ...
1669+
1670+
nested: Recursive | Callable[[Recursive | Recursive, tuple[Recursive, Recursive]], Recursive | Recursive]
1671+
1672+
static_assert(is_fully_static(Recursive))
1673+
static_assert(is_equivalent_to(Recursive, Recursive))
1674+
static_assert(is_assignable_to(Recursive, Recursive))
1675+
1676+
def _(r: Recursive):
1677+
reveal_type(r.direct) # revealed: Recursive
1678+
reveal_type(r.union) # revealed: None | Recursive
1679+
reveal_type(r.intersection1) # revealed: C & Recursive
1680+
reveal_type(r.intersection2) # revealed: C & ~Recursive
1681+
reveal_type(r.t) # revealed: tuple[int, tuple[str, Recursive]]
1682+
reveal_type(r.callable1) # revealed: (int, /) -> Recursive
1683+
reveal_type(r.callable2) # revealed: (Recursive, /) -> int
1684+
reveal_type(r.subtype_of) # revealed: type[Recursive]
1685+
reveal_type(r.generic) # revealed: GenericC[Recursive]
1686+
reveal_type(r.method(r)) # revealed: Recursive
1687+
reveal_type(r.nested) # revealed: Recursive | ((Recursive, tuple[Recursive, Recursive], /) -> Recursive)
1688+
1689+
reveal_type(r.method(r).callable1(1).direct.t[1][1]) # revealed: Recursive
15951690
```
15961691

15971692
### Regression test: narrowing with self-referential protocols

crates/ty_python_semantic/src/types.rs

+84
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,79 @@ impl<'db> Type<'db> {
587587
matches!(self, Type::Dynamic(DynamicType::Todo(_)))
588588
}
589589

590+
/// Replace references to the class `class` with a self-reference marker. This is currently
591+
/// used for recursive protocols, but could probably be extended to self-referential type-
592+
/// aliases and similar.
593+
#[must_use]
594+
pub fn replace_self_reference(&self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Type<'db> {
595+
match self {
596+
Self::ProtocolInstance(protocol) => {
597+
Self::ProtocolInstance(protocol.replace_self_reference(db, class))
598+
}
599+
600+
Self::Union(union) => UnionType::from_elements(
601+
db,
602+
union
603+
.elements(db)
604+
.iter()
605+
.map(|ty| ty.replace_self_reference(db, class)),
606+
),
607+
608+
Self::Intersection(intersection) => IntersectionBuilder::new(db)
609+
.positive_elements(
610+
intersection
611+
.positive(db)
612+
.iter()
613+
.map(|ty| ty.replace_self_reference(db, class)),
614+
)
615+
.negative_elements(
616+
intersection
617+
.negative(db)
618+
.iter()
619+
.map(|ty| ty.replace_self_reference(db, class)),
620+
)
621+
.build(),
622+
623+
Self::Tuple(tuple) => TupleType::from_elements(
624+
db,
625+
tuple
626+
.elements(db)
627+
.iter()
628+
.map(|ty| ty.replace_self_reference(db, class)),
629+
),
630+
631+
Self::Callable(callable) => Self::Callable(callable.replace_self_reference(db, class)),
632+
633+
Self::GenericAlias(_) | Self::TypeVar(_) => {
634+
// TODO: replace self-references in generic aliases and typevars
635+
*self
636+
}
637+
638+
Self::Dynamic(_)
639+
| Self::AlwaysFalsy
640+
| Self::AlwaysTruthy
641+
| Self::Never
642+
| Self::BooleanLiteral(_)
643+
| Self::BytesLiteral(_)
644+
| Self::StringLiteral(_)
645+
| Self::IntLiteral(_)
646+
| Self::LiteralString
647+
| Self::FunctionLiteral(_)
648+
| Self::ModuleLiteral(_)
649+
| Self::ClassLiteral(_)
650+
| Self::NominalInstance(_)
651+
| Self::KnownInstance(_)
652+
| Self::PropertyInstance(_)
653+
| Self::BoundMethod(_)
654+
| Self::WrapperDescriptor(_)
655+
| Self::MethodWrapper(_)
656+
| Self::DataclassDecorator(_)
657+
| Self::DataclassTransformer(_)
658+
| Self::SubclassOf(_)
659+
| Self::BoundSuper(_) => *self,
660+
}
661+
}
662+
590663
pub fn contains_todo(&self, db: &'db dyn Db) -> bool {
591664
match self {
592665
Self::Dynamic(DynamicType::Todo(_) | DynamicType::SubscriptedProtocol) => true,
@@ -7272,6 +7345,17 @@ impl<'db> CallableType<'db> {
72727345
}
72737346
}
72747347
}
7348+
7349+
/// See [`Type::replace_self_reference`].
7350+
fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self {
7351+
CallableType::from_overloads(
7352+
db,
7353+
self.signatures(db)
7354+
.iter()
7355+
.cloned()
7356+
.map(|signature| signature.replace_self_reference(db, class)),
7357+
)
7358+
}
72757359
}
72767360

72777361
/// Represents a specific instance of `types.MethodWrapperType`

crates/ty_python_semantic/src/types/builder.rs

+22
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,28 @@ impl<'db> IntersectionBuilder<'db> {
529529
}
530530
}
531531

532+
pub(crate) fn positive_elements<I, T>(mut self, elements: I) -> Self
533+
where
534+
I: IntoIterator<Item = T>,
535+
T: Into<Type<'db>>,
536+
{
537+
for element in elements {
538+
self = self.add_positive(element.into());
539+
}
540+
self
541+
}
542+
543+
pub(crate) fn negative_elements<I, T>(mut self, elements: I) -> Self
544+
where
545+
I: IntoIterator<Item = T>,
546+
T: Into<Type<'db>>,
547+
{
548+
for element in elements {
549+
self = self.add_negative(element.into());
550+
}
551+
self
552+
}
553+
532554
pub(crate) fn build(mut self) -> Type<'db> {
533555
// Avoid allocating the UnionBuilder unnecessarily if we have just one intersection:
534556
if self.intersections.len() == 1 {

crates/ty_python_semantic/src/types/instance.rs

+14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use super::protocol_class::ProtocolInterface;
44
use super::{ClassType, KnownClass, SubclassOfType, Type};
55
use crate::symbol::{Symbol, SymbolAndQualifiers};
66
use crate::types::generics::TypeMapping;
7+
use crate::types::ClassLiteral;
78
use crate::Db;
89

910
pub(super) use synthesized_protocol::SynthesizedProtocolType;
@@ -183,6 +184,19 @@ impl<'db> ProtocolInstanceType<'db> {
183184
}
184185
}
185186

187+
/// Replace references to `class` with a self-reference marker
188+
pub(super) fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self {
189+
match self.0 {
190+
Protocol::FromClass(class_type) if class_type.class_literal(db).0 == class => {
191+
ProtocolInstanceType(Protocol::Synthesized(SynthesizedProtocolType::new(
192+
db,
193+
ProtocolInterface::SelfReference,
194+
)))
195+
}
196+
_ => self,
197+
}
198+
}
199+
186200
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
187201
pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
188202
self.0.interface(db).contains_todo(db)

0 commit comments

Comments
 (0)