Skip to content

Commit 2bbc42f

Browse files
authored
stubgen: generate valid dataclass stubs (#15625)
Fixes #12441 Fixes #9986 Fixes #15966
1 parent 402c8ff commit 2bbc42f

File tree

3 files changed

+244
-6
lines changed

3 files changed

+244
-6
lines changed

mypy/stubgen.py

+51-6
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,7 @@ def __init__(
657657
self.defined_names: set[str] = set()
658658
# Short names of methods defined in the body of the current class
659659
self.method_names: set[str] = set()
660+
self.processing_dataclass = False
660661

661662
def visit_mypy_file(self, o: MypyFile) -> None:
662663
self.module = o.fullname # Current module being processed
@@ -706,6 +707,12 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
706707
self.clear_decorators()
707708

708709
def visit_func_def(self, o: FuncDef) -> None:
710+
is_dataclass_generated = (
711+
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
712+
)
713+
if is_dataclass_generated and o.name != "__init__":
714+
# Skip methods generated by the @dataclass decorator (except for __init__)
715+
return
709716
if (
710717
self.is_private_name(o.name, o.fullname)
711718
or self.is_not_in_all(o.name)
@@ -771,6 +778,12 @@ def visit_func_def(self, o: FuncDef) -> None:
771778
else:
772779
arg = name + annotation
773780
args.append(arg)
781+
if o.name == "__init__" and is_dataclass_generated and "**" in args:
782+
# The dataclass plugin generates invalid nameless "*" and "**" arguments
783+
new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "")
784+
args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique
785+
args[args.index("**")] = f"**{new_name}__" # same here
786+
774787
retname = None
775788
if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType):
776789
if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType):
@@ -899,6 +912,9 @@ def visit_class_def(self, o: ClassDef) -> None:
899912
if not self._indent and self._state != EMPTY:
900913
sep = len(self._output)
901914
self.add("\n")
915+
decorators = self.get_class_decorators(o)
916+
for d in decorators:
917+
self.add(f"{self._indent}@{d}\n")
902918
self.add(f"{self._indent}class {o.name}")
903919
self.record_name(o.name)
904920
base_types = self.get_base_types(o)
@@ -934,6 +950,7 @@ def visit_class_def(self, o: ClassDef) -> None:
934950
else:
935951
self._state = CLASS
936952
self.method_names = set()
953+
self.processing_dataclass = False
937954
self._current_class = None
938955

939956
def get_base_types(self, cdef: ClassDef) -> list[str]:
@@ -979,6 +996,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
979996
base_types.append(f"{name}={value.accept(p)}")
980997
return base_types
981998

999+
def get_class_decorators(self, cdef: ClassDef) -> list[str]:
1000+
decorators: list[str] = []
1001+
p = AliasPrinter(self)
1002+
for d in cdef.decorators:
1003+
if self.is_dataclass(d):
1004+
decorators.append(d.accept(p))
1005+
self.import_tracker.require_name(get_qualified_name(d))
1006+
self.processing_dataclass = True
1007+
return decorators
1008+
1009+
def is_dataclass(self, expr: Expression) -> bool:
1010+
if isinstance(expr, CallExpr):
1011+
expr = expr.callee
1012+
return self.get_fullname(expr) == "dataclasses.dataclass"
1013+
9821014
def visit_block(self, o: Block) -> None:
9831015
# Unreachable statements may be partially uninitialized and that may
9841016
# cause trouble.
@@ -1336,19 +1368,30 @@ def get_init(
13361368
# Final without type argument is invalid in stubs.
13371369
final_arg = self.get_str_type_of_node(rvalue)
13381370
typename += f"[{final_arg}]"
1371+
elif self.processing_dataclass:
1372+
# An attribute without an annotation is not a dataclass field, so don't add an annotation.
1373+
return f"{self._indent}{lvalue} = ...\n"
13391374
else:
13401375
typename = self.get_str_type_of_node(rvalue)
13411376
initializer = self.get_assign_initializer(rvalue)
13421377
return f"{self._indent}{lvalue}: {typename}{initializer}\n"
13431378

13441379
def get_assign_initializer(self, rvalue: Expression) -> str:
13451380
"""Does this rvalue need some special initializer value?"""
1346-
if self._current_class and self._current_class.info:
1347-
# Current rules
1348-
# 1. Return `...` if we are dealing with `NamedTuple` and it has an existing default value
1349-
if self._current_class.info.is_named_tuple and not isinstance(rvalue, TempNode):
1350-
return " = ..."
1351-
# TODO: support other possible cases, where initializer is important
1381+
if not self._current_class:
1382+
return ""
1383+
# Current rules
1384+
# 1. Return `...` if we are dealing with `NamedTuple` or `dataclass` field and
1385+
# it has an existing default value
1386+
if (
1387+
self._current_class.info
1388+
and self._current_class.info.is_named_tuple
1389+
and not isinstance(rvalue, TempNode)
1390+
):
1391+
return " = ..."
1392+
if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
1393+
return " = ..."
1394+
# TODO: support other possible cases, where initializer is important
13521395

13531396
# By default, no initializer is required:
13541397
return ""
@@ -1410,6 +1453,8 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool:
14101453
return False
14111454
if fullname in EXTRA_EXPORTED:
14121455
return False
1456+
if name == "_":
1457+
return False
14131458
return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS)
14141459

14151460
def is_private_member(self, fullname: str) -> bool:

mypy/test/teststubgen.py

+11
Original file line numberDiff line numberDiff line change
@@ -724,11 +724,22 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None:
724724

725725
def parse_flags(self, program_text: str, extra: list[str]) -> Options:
726726
flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE)
727+
pyversion = None
727728
if flags:
728729
flag_list = flags.group(1).split()
730+
for i, flag in enumerate(flag_list):
731+
if flag.startswith("--python-version="):
732+
pyversion = flag.split("=", 1)[1]
733+
del flag_list[i]
734+
break
729735
else:
730736
flag_list = []
731737
options = parse_options(flag_list + extra)
738+
if pyversion:
739+
# A hack to allow testing old Python versions with new language constructs.
740+
# This should rarely be used, since stubgen output should not be version-specific.
741+
major, minor = pyversion.split(".", 1)
742+
options.pyversion = (int(major), int(minor))
732743
if "--verbose" not in flag_list:
733744
options.quiet = True
734745
else:

test-data/unit/stubgen.test

+182
Original file line numberDiff line numberDiff line change
@@ -3512,3 +3512,185 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
35123512

35133513
class X(_Incomplete): ...
35143514
class Y(_Incomplete): ...
3515+
3516+
[case testDataclass]
3517+
import dataclasses
3518+
import dataclasses as dcs
3519+
from dataclasses import dataclass, InitVar, KW_ONLY
3520+
from dataclasses import dataclass as dc
3521+
from typing import ClassVar
3522+
3523+
@dataclasses.dataclass
3524+
class X:
3525+
a: int
3526+
b: str = "hello"
3527+
c: ClassVar
3528+
d: ClassVar = 200
3529+
f: list[int] = field(init=False, default_factory=list)
3530+
g: int = field(default=2, kw_only=True)
3531+
_: KW_ONLY
3532+
h: int = 1
3533+
i: InitVar[str]
3534+
j: InitVar = 100
3535+
non_field = None
3536+
3537+
@dcs.dataclass
3538+
class Y: ...
3539+
3540+
@dataclass
3541+
class Z: ...
3542+
3543+
@dc
3544+
class W: ...
3545+
3546+
@dataclass(init=False, repr=False)
3547+
class V: ...
3548+
3549+
[out]
3550+
import dataclasses
3551+
import dataclasses as dcs
3552+
from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc
3553+
from typing import ClassVar
3554+
3555+
@dataclasses.dataclass
3556+
class X:
3557+
a: int
3558+
b: str = ...
3559+
c: ClassVar
3560+
d: ClassVar = ...
3561+
f: list[int] = ...
3562+
g: int = ...
3563+
_: KW_ONLY
3564+
h: int = ...
3565+
i: InitVar[str]
3566+
j: InitVar = ...
3567+
non_field = ...
3568+
3569+
@dcs.dataclass
3570+
class Y: ...
3571+
@dataclass
3572+
class Z: ...
3573+
@dc
3574+
class W: ...
3575+
@dataclass(init=False, repr=False)
3576+
class V: ...
3577+
3578+
[case testDataclass_semanal]
3579+
from dataclasses import dataclass, InitVar
3580+
from typing import ClassVar
3581+
3582+
@dataclass
3583+
class X:
3584+
a: int
3585+
b: str = "hello"
3586+
c: ClassVar
3587+
d: ClassVar = 200
3588+
f: list[int] = field(init=False, default_factory=list)
3589+
g: int = field(default=2, kw_only=True)
3590+
h: int = 1
3591+
i: InitVar[str]
3592+
j: InitVar = 100
3593+
non_field = None
3594+
3595+
@dataclass(init=False, repr=False, frozen=True)
3596+
class Y: ...
3597+
3598+
[out]
3599+
from dataclasses import InitVar, dataclass
3600+
from typing import ClassVar
3601+
3602+
@dataclass
3603+
class X:
3604+
a: int
3605+
b: str = ...
3606+
c: ClassVar
3607+
d: ClassVar = ...
3608+
f: list[int] = ...
3609+
g: int = ...
3610+
h: int = ...
3611+
i: InitVar[str]
3612+
j: InitVar = ...
3613+
non_field = ...
3614+
def __init__(self, a, b, f, g, h, i, j) -> None: ...
3615+
3616+
@dataclass(init=False, repr=False, frozen=True)
3617+
class Y: ...
3618+
3619+
[case testDataclassWithKwOnlyField_semanal]
3620+
# flags: --python-version=3.10
3621+
from dataclasses import dataclass, InitVar, KW_ONLY
3622+
from typing import ClassVar
3623+
3624+
@dataclass
3625+
class X:
3626+
a: int
3627+
b: str = "hello"
3628+
c: ClassVar
3629+
d: ClassVar = 200
3630+
f: list[int] = field(init=False, default_factory=list)
3631+
g: int = field(default=2, kw_only=True)
3632+
_: KW_ONLY
3633+
h: int = 1
3634+
i: InitVar[str]
3635+
j: InitVar = 100
3636+
non_field = None
3637+
3638+
@dataclass(init=False, repr=False, frozen=True)
3639+
class Y: ...
3640+
3641+
[out]
3642+
from dataclasses import InitVar, KW_ONLY, dataclass
3643+
from typing import ClassVar
3644+
3645+
@dataclass
3646+
class X:
3647+
a: int
3648+
b: str = ...
3649+
c: ClassVar
3650+
d: ClassVar = ...
3651+
f: list[int] = ...
3652+
g: int = ...
3653+
_: KW_ONLY
3654+
h: int = ...
3655+
i: InitVar[str]
3656+
j: InitVar = ...
3657+
non_field = ...
3658+
def __init__(self, a, b, f, g, *, h, i, j) -> None: ...
3659+
3660+
@dataclass(init=False, repr=False, frozen=True)
3661+
class Y: ...
3662+
3663+
[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal]
3664+
from dataclasses import dataclass
3665+
3666+
@dataclass
3667+
class X:
3668+
a: int
3669+
def __init__(self, a: int, b: str = ...) -> None: ...
3670+
def __post_init__(self) -> None: ...
3671+
3672+
[out]
3673+
from dataclasses import dataclass
3674+
3675+
@dataclass
3676+
class X:
3677+
a: int
3678+
def __init__(self, a: int, b: str = ...) -> None: ...
3679+
def __post_init__(self) -> None: ...
3680+
3681+
[case testDataclassInheritsFromAny_semanal]
3682+
from dataclasses import dataclass
3683+
import missing
3684+
3685+
@dataclass
3686+
class X(missing.Base):
3687+
a: int
3688+
3689+
[out]
3690+
import missing
3691+
from dataclasses import dataclass
3692+
3693+
@dataclass
3694+
class X(missing.Base):
3695+
a: int
3696+
def __init__(self, *selfa_, a, **selfa__) -> None: ...

0 commit comments

Comments
 (0)