Skip to content

Commit 68cffa7

Browse files
authored
[stubgen] Improve dataclass init signatures (#18430)
Remove generated incomplete `__init__` signatures for dataclasses. Keep the field specifiers instead.
1 parent c4e2eb7 commit 68cffa7

File tree

3 files changed

+81
-34
lines changed

3 files changed

+81
-34
lines changed

mypy/plugins/dataclasses.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979

8080
# The set of decorators that generate dataclasses.
8181
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}
82+
# Default field specifiers for dataclasses
83+
DATACLASS_FIELD_SPECIFIERS: Final = ("dataclasses.Field", "dataclasses.field")
8284

8385

8486
SELF_TVAR_NAME: Final = "_DT"
@@ -87,7 +89,7 @@
8789
order_default=False,
8890
kw_only_default=False,
8991
frozen_default=False,
90-
field_specifiers=("dataclasses.Field", "dataclasses.field"),
92+
field_specifiers=DATACLASS_FIELD_SPECIFIERS,
9193
)
9294
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
9395
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init"

mypy/stubgen.py

+25-8
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
ImportFrom,
9696
IndexExpr,
9797
IntExpr,
98+
LambdaExpr,
9899
ListExpr,
99100
MemberExpr,
100101
MypyFile,
@@ -113,6 +114,7 @@
113114
Var,
114115
)
115116
from mypy.options import Options as MypyOptions
117+
from mypy.plugins.dataclasses import DATACLASS_FIELD_SPECIFIERS
116118
from mypy.semanal_shared import find_dataclass_transform_spec
117119
from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY
118120
from mypy.stubdoc import ArgSig, FunctionSig
@@ -342,11 +344,12 @@ def visit_index_expr(self, node: IndexExpr) -> str:
342344
base = node.base.accept(self)
343345
index = node.index.accept(self)
344346
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
345-
index = index[1:-1]
347+
index = index[1:-1].rstrip(",")
346348
return f"{base}[{index}]"
347349

348350
def visit_tuple_expr(self, node: TupleExpr) -> str:
349-
return f"({', '.join(n.accept(self) for n in node.items)})"
351+
suffix = "," if len(node.items) == 1 else ""
352+
return f"({', '.join(n.accept(self) for n in node.items)}{suffix})"
350353

351354
def visit_list_expr(self, node: ListExpr) -> str:
352355
return f"[{', '.join(n.accept(self) for n in node.items)}]"
@@ -368,6 +371,10 @@ def visit_op_expr(self, o: OpExpr) -> str:
368371
def visit_star_expr(self, o: StarExpr) -> str:
369372
return f"*{o.expr.accept(self)}"
370373

374+
def visit_lambda_expr(self, o: LambdaExpr) -> str:
375+
# TODO: Render lambdas properly instead of falling back to Incomplete; this is needed for, among other things, the default_factory argument of dataclasses.field
376+
return self.stubgen.add_name("_typeshed.Incomplete")
377+
371378

372379
def find_defined_names(file: MypyFile) -> set[str]:
373380
finder = DefinitionFinder()
@@ -482,6 +489,7 @@ def __init__(
482489
self.method_names: set[str] = set()
483490
self.processing_enum = False
484491
self.processing_dataclass = False
492+
self.dataclass_field_specifier: tuple[str, ...] = ()
485493

486494
@property
487495
def _current_class(self) -> ClassDef | None:
@@ -636,8 +644,8 @@ def visit_func_def(self, o: FuncDef) -> None:
636644
is_dataclass_generated = (
637645
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
638646
)
639-
if is_dataclass_generated and o.name != "__init__":
640-
# Skip methods generated by the @dataclass decorator (except for __init__)
647+
if is_dataclass_generated:
648+
# Skip methods generated by the @dataclass decorator
641649
return
642650
if (
643651
self.is_private_name(o.name, o.fullname)
@@ -793,8 +801,9 @@ def visit_class_def(self, o: ClassDef) -> None:
793801
self.add(f"{self._indent}{docstring}\n")
794802
n = len(self._output)
795803
self._vars.append([])
796-
if self.analyzed and find_dataclass_transform_spec(o):
804+
if self.analyzed and (spec := find_dataclass_transform_spec(o)):
797805
self.processing_dataclass = True
806+
self.dataclass_field_specifier = spec.field_specifiers
798807
super().visit_class_def(o)
799808
self.dedent()
800809
self._vars.pop()
@@ -809,6 +818,7 @@ def visit_class_def(self, o: ClassDef) -> None:
809818
self._state = CLASS
810819
self.method_names = set()
811820
self.processing_dataclass = False
821+
self.dataclass_field_specifier = ()
812822
self._class_stack.pop(-1)
813823
self.processing_enum = False
814824

@@ -879,8 +889,9 @@ def is_dataclass_transform(self, expr: Expression) -> bool:
879889
expr = expr.callee
880890
if self.get_fullname(expr) in DATACLASS_TRANSFORM_NAMES:
881891
return True
882-
if find_dataclass_transform_spec(expr) is not None:
892+
if (spec := find_dataclass_transform_spec(expr)) is not None:
883893
self.processing_dataclass = True
894+
self.dataclass_field_specifier = spec.field_specifiers
884895
return True
885896
return False
886897

@@ -1259,8 +1270,14 @@ def get_assign_initializer(self, rvalue: Expression) -> str:
12591270
and not isinstance(rvalue, TempNode)
12601271
):
12611272
return " = ..."
1262-
if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
1263-
return " = ..."
1273+
if self.processing_dataclass:
1274+
if isinstance(rvalue, CallExpr):
1275+
fullname = self.get_fullname(rvalue.callee)
1276+
if fullname in (self.dataclass_field_specifier or DATACLASS_FIELD_SPECIFIERS):
1277+
p = AliasPrinter(self)
1278+
return f" = {rvalue.accept(p)}"
1279+
if not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
1280+
return " = ..."
12641281
# TODO: support other possible cases, where initializer is important
12651282

12661283
# By default, no initializer is required:

test-data/unit/stubgen.test

+53-25
Original file line numberDiff line numberDiff line change
@@ -3101,15 +3101,14 @@ import attrs
31013101

31023102
@attrs.define
31033103
class C:
3104-
x = attrs.field()
3104+
x: int = attrs.field()
31053105

31063106
[out]
31073107
import attrs
31083108

31093109
@attrs.define
31103110
class C:
3111-
x = ...
3112-
def __init__(self, x) -> None: ...
3111+
x: int = attrs.field()
31133112

31143113
[case testNamedTupleInClass]
31153114
from collections import namedtuple
@@ -4050,8 +4049,9 @@ def i(x=..., y=..., z=...) -> None: ...
40504049
[case testDataclass]
40514050
import dataclasses
40524051
import dataclasses as dcs
4053-
from dataclasses import dataclass, InitVar, KW_ONLY
4052+
from dataclasses import dataclass, field, Field, InitVar, KW_ONLY
40544053
from dataclasses import dataclass as dc
4054+
from datetime import datetime
40554055
from typing import ClassVar
40564056

40574057
@dataclasses.dataclass
@@ -4066,6 +4066,10 @@ class X:
40664066
h: int = 1
40674067
i: InitVar[str]
40684068
j: InitVar = 100
4069+
# Lambda not supported yet -> marked as Incomplete instead
4070+
k: str = Field(
4071+
default_factory=lambda: datetime.utcnow().isoformat(" ", timespec="seconds")
4072+
)
40694073
non_field = None
40704074

40714075
@dcs.dataclass
@@ -4083,7 +4087,8 @@ class V: ...
40834087
[out]
40844088
import dataclasses
40854089
import dataclasses as dcs
4086-
from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc
4090+
from _typeshed import Incomplete
4091+
from dataclasses import Field, InitVar, KW_ONLY, dataclass, dataclass as dc, field
40874092
from typing import ClassVar
40884093

40894094
@dataclasses.dataclass
@@ -4092,12 +4097,13 @@ class X:
40924097
b: str = ...
40934098
c: ClassVar
40944099
d: ClassVar = ...
4095-
f: list[int] = ...
4096-
g: int = ...
4100+
f: list[int] = field(init=False, default_factory=list)
4101+
g: int = field(default=2, kw_only=True)
40974102
_: KW_ONLY
40984103
h: int = ...
40994104
i: InitVar[str]
41004105
j: InitVar = ...
4106+
k: str = Field(default_factory=Incomplete)
41014107
non_field = ...
41024108

41034109
@dcs.dataclass
@@ -4110,8 +4116,9 @@ class W: ...
41104116
class V: ...
41114117

41124118
[case testDataclass_semanal]
4113-
from dataclasses import InitVar, dataclass, field
4119+
from dataclasses import Field, InitVar, dataclass, field
41144120
from typing import ClassVar
4121+
from datetime import datetime
41154122

41164123
@dataclass
41174124
class X:
@@ -4125,13 +4132,18 @@ class X:
41254132
h: int = 1
41264133
i: InitVar = 100
41274134
j: list[int] = field(default_factory=list)
4135+
# Lambda not supported yet -> marked as Incomplete instead
4136+
k: str = Field(
4137+
default_factory=lambda: datetime.utcnow().isoformat(" ", timespec="seconds")
4138+
)
41284139
non_field = None
41294140

41304141
@dataclass(init=False, repr=False, frozen=True)
41314142
class Y: ...
41324143

41334144
[out]
4134-
from dataclasses import InitVar, dataclass
4145+
from _typeshed import Incomplete
4146+
from dataclasses import Field, InitVar, dataclass, field
41354147
from typing import ClassVar
41364148

41374149
@dataclass
@@ -4141,13 +4153,13 @@ class X:
41414153
c: str = ...
41424154
d: ClassVar
41434155
e: ClassVar = ...
4144-
f: list[int] = ...
4145-
g: int = ...
4156+
f: list[int] = field(init=False, default_factory=list)
4157+
g: int = field(default=2, kw_only=True)
41464158
h: int = ...
41474159
i: InitVar = ...
4148-
j: list[int] = ...
4160+
j: list[int] = field(default_factory=list)
4161+
k: str = Field(default_factory=Incomplete)
41494162
non_field = ...
4150-
def __init__(self, a, b, c=..., *, g=..., h=..., i=..., j=...) -> None: ...
41514163

41524164
@dataclass(init=False, repr=False, frozen=True)
41534165
class Y: ...
@@ -4175,7 +4187,7 @@ class X:
41754187
class Y: ...
41764188

41774189
[out]
4178-
from dataclasses import InitVar, KW_ONLY, dataclass
4190+
from dataclasses import InitVar, KW_ONLY, dataclass, field
41794191
from typing import ClassVar
41804192

41814193
@dataclass
@@ -4184,14 +4196,13 @@ class X:
41844196
b: str = ...
41854197
c: ClassVar
41864198
d: ClassVar = ...
4187-
f: list[int] = ...
4188-
g: int = ...
4199+
f: list[int] = field(init=False, default_factory=list)
4200+
g: int = field(default=2, kw_only=True)
41894201
_: KW_ONLY
41904202
h: int = ...
41914203
i: InitVar[str]
41924204
j: InitVar = ...
41934205
non_field = ...
4194-
def __init__(self, a, b=..., *, g=..., h=..., i, j=...) -> None: ...
41954206

41964207
@dataclass(init=False, repr=False, frozen=True)
41974208
class Y: ...
@@ -4236,15 +4247,13 @@ from dataclasses import dataclass
42364247
@dataclass
42374248
class X(missing.Base):
42384249
a: int
4239-
def __init__(self, *generated_args, a, **generated_kwargs) -> None: ...
42404250

42414251
@dataclass
42424252
class Y(missing.Base):
42434253
generated_args: str
42444254
generated_args_: str
42454255
generated_kwargs: float
42464256
generated_kwargs_: float
4247-
def __init__(self, *generated_args__, generated_args, generated_args_, generated_kwargs, generated_kwargs_, **generated_kwargs__) -> None: ...
42484257

42494258
[case testDataclassTransform]
42504259
# dataclass_transform detection only works with semantic analysis.
@@ -4298,6 +4307,7 @@ class Z(metaclass=DCMeta):
42984307

42994308
[case testDataclassTransformDecorator_semanal]
43004309
import typing_extensions
4310+
from dataclasses import field
43014311

43024312
@typing_extensions.dataclass_transform(kw_only_default=True)
43034313
def create_model(cls):
@@ -4307,9 +4317,11 @@ def create_model(cls):
43074317
class X:
43084318
a: int
43094319
b: str = "hello"
4320+
c: bool = field(default=True)
43104321

43114322
[out]
43124323
import typing_extensions
4324+
from dataclasses import field
43134325

43144326
@typing_extensions.dataclass_transform(kw_only_default=True)
43154327
def create_model(cls): ...
@@ -4318,9 +4330,10 @@ def create_model(cls): ...
43184330
class X:
43194331
a: int
43204332
b: str = ...
4321-
def __init__(self, *, a, b=...) -> None: ...
4333+
c: bool = field(default=True)
43224334

43234335
[case testDataclassTransformClass_semanal]
4336+
from dataclasses import field
43244337
from typing_extensions import dataclass_transform
43254338

43264339
@dataclass_transform(kw_only_default=True)
@@ -4329,8 +4342,10 @@ class ModelBase: ...
43294342
class X(ModelBase):
43304343
a: int
43314344
b: str = "hello"
4345+
c: bool = field(default=True)
43324346

43334347
[out]
4348+
from dataclasses import field
43344349
from typing_extensions import dataclass_transform
43354350

43364351
@dataclass_transform(kw_only_default=True)
@@ -4339,28 +4354,42 @@ class ModelBase: ...
43394354
class X(ModelBase):
43404355
a: int
43414356
b: str = ...
4342-
def __init__(self, *, a, b=...) -> None: ...
4357+
c: bool = field(default=True)
43434358

43444359
[case testDataclassTransformMetaclass_semanal]
4360+
from dataclasses import field
4361+
from typing import Any
43454362
from typing_extensions import dataclass_transform
43464363

4347-
@dataclass_transform(kw_only_default=True)
4364+
def custom_field(*, default: bool, kw_only: bool) -> Any: ...
4365+
4366+
@dataclass_transform(kw_only_default=True, field_specifiers=(custom_field,))
43484367
class DCMeta(type): ...
43494368

43504369
class X(metaclass=DCMeta):
43514370
a: int
43524371
b: str = "hello"
4372+
c: bool = field(default=True) # should be ignored, not field_specifier here
4373+
4374+
class Y(X):
4375+
d: str = custom_field(default="Hello")
43534376

43544377
[out]
4378+
from typing import Any
43554379
from typing_extensions import dataclass_transform
43564380

4357-
@dataclass_transform(kw_only_default=True)
4381+
def custom_field(*, default: bool, kw_only: bool) -> Any: ...
4382+
4383+
@dataclass_transform(kw_only_default=True, field_specifiers=(custom_field,))
43584384
class DCMeta(type): ...
43594385

43604386
class X(metaclass=DCMeta):
43614387
a: int
43624388
b: str = ...
4363-
def __init__(self, *, a, b=...) -> None: ...
4389+
c: bool = ...
4390+
4391+
class Y(X):
4392+
d: str = custom_field(default='Hello')
43644393

43654394
[case testAlwaysUsePEP604Union]
43664395
import typing
@@ -4662,4 +4691,3 @@ class DCMeta(type): ...
46624691

46634692
class DC(metaclass=DCMeta):
46644693
x: str
4665-
def __init__(self, x) -> None: ...

0 commit comments

Comments
 (0)