Skip to content

Commit b4de245

Browse files
authored
[red-knot] Dataclasses: support order=True (#17406)
## Summary Support dataclasses with `order=True`: ```py @DataClass(order=True) class WithOrder: x: int WithOrder(1) < WithOrder(2) # no error ``` Also adds some additional tests to `dataclasses.md`. ticket: #16651 ## Test Plan New Markdown tests
1 parent 914095d commit b4de245

File tree

4 files changed

+219
-30
lines changed

4 files changed

+219
-30
lines changed

.github/workflows/mypy_primer.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
--type-checker knot \
6969
--old base_commit \
7070
--new "$GITHUB_SHA" \
71-
--project-selector '/(mypy_primer|black|pyp|git-revise|zipp|arrow|isort|itsdangerous|rich|packaging|pybind11|pyinstrument|typeshed-stats|scrapy|werkzeug|bidict|async-utils)$' \
71+
--project-selector '/(mypy_primer|black|pyp|git-revise|zipp|arrow|isort|itsdangerous|rich|packaging|pybind11|pyinstrument|typeshed-stats|scrapy|werkzeug|bidict|async-utils|python-chess|dacite|python-htmlgen|paroxython|porcupine|psycopg)$' \
7272
--output concise \
7373
--debug > mypy_primer.diff || [ $? -eq 1 ]
7474

crates/red_knot_python_semantic/resources/mdtest/dataclasses.md

Lines changed: 187 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,125 @@ repr(C())
9191
C() == C()
9292
```
9393

94+
## Other dataclass parameters
95+
96+
### `repr`
97+
98+
A custom `__repr__` method is generated by default. It can be disabled by passing `repr=False`, but
99+
in that case `__repr__` is still available via `object.__repr__`:
100+
101+
```py
102+
from dataclasses import dataclass
103+
104+
@dataclass(repr=False)
105+
class WithoutRepr:
106+
x: int
107+
108+
reveal_type(WithoutRepr(1).__repr__) # revealed: bound method WithoutRepr.__repr__() -> str
109+
```
110+
111+
### `eq`
112+
113+
The same is true for `__eq__`. Setting `eq=False` disables the generated `__eq__` method, but
114+
`__eq__` is still available via `object.__eq__`:
115+
116+
```py
117+
from dataclasses import dataclass
118+
119+
@dataclass(eq=False)
120+
class WithoutEq:
121+
x: int
122+
123+
reveal_type(WithoutEq(1) == WithoutEq(2)) # revealed: bool
124+
```
125+
126+
### `order`
127+
128+
`order` is set to `False` by default. If `order=True`, `__lt__`, `__le__`, `__gt__`, and `__ge__`
129+
methods will be generated:
130+
131+
```py
132+
from dataclasses import dataclass
133+
134+
@dataclass
135+
class WithoutOrder:
136+
x: int
137+
138+
WithoutOrder(1) < WithoutOrder(2) # error: [unsupported-operator]
139+
WithoutOrder(1) <= WithoutOrder(2) # error: [unsupported-operator]
140+
WithoutOrder(1) > WithoutOrder(2) # error: [unsupported-operator]
141+
WithoutOrder(1) >= WithoutOrder(2) # error: [unsupported-operator]
142+
143+
@dataclass(order=True)
144+
class WithOrder:
145+
x: int
146+
147+
WithOrder(1) < WithOrder(2)
148+
WithOrder(1) <= WithOrder(2)
149+
WithOrder(1) > WithOrder(2)
150+
WithOrder(1) >= WithOrder(2)
151+
```
152+
153+
Comparisons are only allowed for `WithOrder` instances:
154+
155+
```py
156+
WithOrder(1) < 2 # error: [unsupported-operator]
157+
WithOrder(1) <= 2 # error: [unsupported-operator]
158+
WithOrder(1) > 2 # error: [unsupported-operator]
159+
WithOrder(1) >= 2 # error: [unsupported-operator]
160+
```
161+
162+
This also works for generic dataclasses:
163+
164+
```py
165+
from dataclasses import dataclass
166+
167+
@dataclass(order=True)
168+
class GenericWithOrder[T]:
169+
x: T
170+
171+
GenericWithOrder[int](1) < GenericWithOrder[int](1)
172+
173+
GenericWithOrder[int](1) < GenericWithOrder[str]("a") # error: [unsupported-operator]
174+
```
175+
176+
If a class already defines one of the comparison methods, a `TypeError` is raised at runtime.
177+
Ideally, we would emit a diagnostic in that case:
178+
179+
```py
180+
@dataclass(order=True)
181+
class AlreadyHasCustomDunderLt:
182+
x: int
183+
184+
# TODO: Ideally, we would emit a diagnostic here
185+
def __lt__(self, other: object) -> bool:
186+
return False
187+
```
188+
189+
### `unsafe_hash`
190+
191+
To do
192+
193+
### `frozen`
194+
195+
To do
196+
197+
### `match_args`
198+
199+
To do
200+
201+
### `kw_only`
202+
203+
To do
204+
205+
### `slots`
206+
207+
To do
208+
209+
### `weakref_slot`
210+
211+
To do
212+
94213
## Inheritance
95214

96215
### Normal class inheriting from a dataclass
@@ -168,13 +287,30 @@ reveal_type(d_int.description) # revealed: str
168287
DataWithDescription[int](None, "description")
169288
```
170289

171-
## Frozen instances
290+
## Descriptor-typed fields
172291

173-
To do
292+
```py
293+
from dataclasses import dataclass
174294

175-
## Descriptor-typed fields
295+
class Descriptor:
296+
_value: int = 0
176297

177-
To do
298+
def __get__(self, instance, owner) -> str:
299+
return str(self._value)
300+
301+
def __set__(self, instance, value: int) -> None:
302+
self._value = value
303+
304+
@dataclass
305+
class C:
306+
d: Descriptor = Descriptor()
307+
308+
c = C(1)
309+
reveal_type(c.d) # revealed: str
310+
311+
# TODO: should be an error
312+
C("a")
313+
```
178314

179315
## `dataclasses.field`
180316

@@ -197,18 +333,61 @@ class C:
197333
reveal_type(C.__init__) # revealed: (*args: Any, **kwargs: Any) -> None
198334
```
199335

200-
### Dataclass with `init=False`
336+
### Dataclass with custom `__init__` method
201337

202-
To do
338+
If a class already defines `__init__`, it is not replaced by the `dataclass` decorator.
203339

204-
### Dataclass with custom `__init__` method
340+
```py
341+
from dataclasses import dataclass
205342

206-
To do
343+
@dataclass(init=True)
344+
class C:
345+
x: str
346+
347+
def __init__(self, x: int) -> None:
348+
self.x = str(x)
349+
350+
C(1) # OK
351+
352+
# TODO: should be an error
353+
C("a")
354+
```
355+
356+
Similarly, if we set `init=False`, we still recognize the custom `__init__` method:
357+
358+
```py
359+
@dataclass(init=False)
360+
class D:
361+
def __init__(self, x: int) -> None:
362+
self.x = str(x)
363+
364+
D(1) # OK
365+
D() # error: [missing-argument]
366+
```
207367

208368
### Dataclass with `ClassVar`s
209369

210370
To do
211371

372+
### Return type of `dataclass(...)`
373+
374+
A call like `dataclass(order=True)` returns a callable itself, which is then used as the decorator.
375+
We can store the callable in a variable and later use it as a decorator:
376+
377+
```py
378+
from dataclasses import dataclass
379+
380+
dataclass_with_order = dataclass(order=True)
381+
382+
reveal_type(dataclass_with_order) # revealed: <decorator produced by dataclasses.dataclass>
383+
384+
@dataclass_with_order
385+
class C:
386+
x: int
387+
388+
C(1) < C(2) # ok
389+
```
390+
212391
### Using `dataclass` as a function
213392

214393
To do

crates/red_knot_python_semantic/src/types/class.rs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -823,21 +823,31 @@ impl<'db> ClassLiteralType<'db> {
823823
name: &str,
824824
) -> SymbolAndQualifiers<'db> {
825825
if let Some(metadata) = self.dataclass_metadata(db) {
826-
if name == "__init__" {
827-
if metadata.contains(DataclassMetadata::INIT) {
828-
// TODO: Generate the signature from the attributes on the class
829-
let init_signature = Signature::new(
830-
Parameters::new([
831-
Parameter::variadic(Name::new_static("args"))
832-
.with_annotated_type(Type::any()),
833-
Parameter::keyword_variadic(Name::new_static("kwargs"))
834-
.with_annotated_type(Type::any()),
835-
]),
836-
Some(Type::none(db)),
826+
if name == "__init__" && metadata.contains(DataclassMetadata::INIT) {
827+
// TODO: Generate the signature from the attributes on the class
828+
let init_signature = Signature::new(
829+
Parameters::new([
830+
Parameter::variadic(Name::new_static("args"))
831+
.with_annotated_type(Type::any()),
832+
Parameter::keyword_variadic(Name::new_static("kwargs"))
833+
.with_annotated_type(Type::any()),
834+
]),
835+
Some(Type::none(db)),
836+
);
837+
838+
return Symbol::bound(Type::Callable(CallableType::new(db, init_signature))).into();
839+
} else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
840+
if metadata.contains(DataclassMetadata::ORDER) {
841+
let signature = Signature::new(
842+
Parameters::new([Parameter::positional_or_keyword(Name::new_static(
843+
"other",
844+
))
845+
.with_annotated_type(Type::instance(
846+
self.apply_optional_specialization(db, specialization),
847+
))]),
848+
Some(KnownClass::Bool.to_instance(db)),
837849
);
838-
839-
return Symbol::bound(Type::Callable(CallableType::new(db, init_signature)))
840-
.into();
850+
return Symbol::bound(Type::Callable(CallableType::new(db, signature))).into();
841851
}
842852
}
843853
}

crates/red_knot_python_semantic/src/types/generics.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::Db;
1313
///
1414
/// TODO: Handle nested generic contexts better, with actual parent links to the lexically
1515
/// containing context.
16-
#[salsa::tracked(debug)]
16+
#[salsa::interned(debug)]
1717
pub struct GenericContext<'db> {
1818
#[return_ref]
1919
pub(crate) variables: Box<[TypeVarInstance<'db>]>,
@@ -25,7 +25,7 @@ impl<'db> GenericContext<'db> {
2525
index: &'db SemanticIndex<'db>,
2626
type_params_node: &ast::TypeParams,
2727
) -> Self {
28-
let variables = type_params_node
28+
let variables: Box<[_]> = type_params_node
2929
.iter()
3030
.filter_map(|type_param| Self::variable_from_type_param(db, index, type_param))
3131
.collect();
@@ -116,7 +116,7 @@ impl<'db> GenericContext<'db> {
116116
///
117117
/// TODO: Handle nested specializations better, with actual parent links to the specialization of
118118
/// the lexically containing context.
119-
#[salsa::tracked(debug)]
119+
#[salsa::interned(debug)]
120120
pub struct Specialization<'db> {
121121
pub(crate) generic_context: GenericContext<'db>,
122122
#[return_ref]
@@ -138,7 +138,7 @@ impl<'db> Specialization<'db> {
138138
/// That lets us produce the generic alias `A[int]`, which is the corresponding entry in the
139139
/// MRO of `B[int]`.
140140
pub(crate) fn apply_specialization(self, db: &'db dyn Db, other: Specialization<'db>) -> Self {
141-
let types = self
141+
let types: Box<[_]> = self
142142
.types(db)
143143
.into_iter()
144144
.map(|ty| ty.apply_specialization(db, other))
@@ -154,7 +154,7 @@ impl<'db> Specialization<'db> {
154154
pub(crate) fn combine(self, db: &'db dyn Db, other: Self) -> Self {
155155
let generic_context = self.generic_context(db);
156156
assert!(other.generic_context(db) == generic_context);
157-
let types = self
157+
let types: Box<[_]> = self
158158
.types(db)
159159
.into_iter()
160160
.zip(other.types(db))
@@ -167,7 +167,7 @@ impl<'db> Specialization<'db> {
167167
}
168168

169169
pub(crate) fn normalized(self, db: &'db dyn Db) -> Self {
170-
let types = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
170+
let types: Box<[_]> = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
171171
Self::new(db, self.generic_context(db), types)
172172
}
173173

@@ -201,7 +201,7 @@ impl<'db> SpecializationBuilder<'db> {
201201
}
202202

203203
pub(crate) fn build(mut self) -> Specialization<'db> {
204-
let types = self
204+
let types: Box<[_]> = self
205205
.generic_context
206206
.variables(self.db)
207207
.iter()

0 commit comments

Comments
 (0)