From faae3c0f82f87d1a217ac6c1135f037b898d7cac Mon Sep 17 00:00:00 2001
From: Ali Hamdan <ali.hamdan.dev@gmail.com>
Date: Sat, 8 Jul 2023 12:04:03 +0200
Subject: [PATCH 1/3] stubgen: generate valid dataclass stubs

Fixes #12441
---
 mypy/stubgen.py             | 30 ++++++++++++++++
 test-data/unit/stubgen.test | 70 +++++++++++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+)

diff --git a/mypy/stubgen.py b/mypy/stubgen.py
index 229559ac8120..9758e4c14973 100755
--- a/mypy/stubgen.py
+++ b/mypy/stubgen.py
@@ -101,6 +101,7 @@
     OverloadedFuncDef,
     Statement,
     StrExpr,
+    TempNode,
     TupleExpr,
     TypeInfo,
     UnaryExpr,
@@ -650,6 +651,7 @@ def __init__(
         self.defined_names: set[str] = set()
         # Short names of methods defined in the body of the current class
         self.method_names: set[str] = set()
+        self.processing_dataclass = False
 
     def visit_mypy_file(self, o: MypyFile) -> None:
         self.module = o.fullname  # Current module being processed
@@ -699,6 +701,9 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
                 self.clear_decorators()
 
     def visit_func_def(self, o: FuncDef) -> None:
+        if self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated:
+            # Skip methods generated by the @dataclass decorator
+            return
         if (
             self.is_private_name(o.name, o.fullname)
             or self.is_not_in_all(o.name)
@@ -890,6 +895,9 @@ def visit_class_def(self, o: ClassDef) -> None:
         if not self._indent and self._state != EMPTY:
             sep = len(self._output)
             self.add("\n")
+        decorators = self.get_class_decorators(o)
+        for d in decorators:
+            self.add(f"{self._indent}@{d}\n")
         self.add(f"{self._indent}class {o.name}")
         self.record_name(o.name)
         base_types = self.get_base_types(o)
@@ -921,6 +929,7 @@ def visit_class_def(self, o: ClassDef) -> None:
         else:
             self._state = CLASS
         self.method_names = set()
+        self.processing_dataclass = False
 
     def get_base_types(self, cdef: ClassDef) -> list[str]:
         """Get list of base classes for a class."""
@@ -967,6 +976,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
             base_types.append(f"{name}={value.accept(p)}")
         return base_types
 
+    def get_class_decorators(self, cdef: ClassDef) -> list[str]:
+        decorators: list[str] = []
+        p = AliasPrinter(self)
+        for d in cdef.decorators:
+            if self.is_dataclass(d):
+                decorators.append(d.accept(p))
+                self.import_tracker.require_name(get_qualified_name(d))
+                self.processing_dataclass = True
+        return decorators
+
+    def is_dataclass(self, expr: Expression) -> bool:
+        if isinstance(expr, CallExpr):
+            expr = expr.callee
+        return self.get_fullname(expr) == "dataclasses.dataclass"
+
     def visit_block(self, o: Block) -> None:
         # Unreachable statements may be partially uninitialized and that may
         # cause trouble.
@@ -1323,8 +1347,14 @@ def get_init(
                 # Final without type argument is invalid in stubs.
                 final_arg = self.get_str_type_of_node(rvalue)
                 typename += f"[{final_arg}]"
+        elif self.processing_dataclass:
+            # attribute without annotation is not a dataclass field, don't add annotation.
+            return f"{self._indent}{lvalue} = ...\n"
         else:
             typename = self.get_str_type_of_node(rvalue)
+        if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
+            # dataclass field with default value, keep the initializer.
+            return f"{self._indent}{lvalue}: {typename} = ...\n"
         return f"{self._indent}{lvalue}: {typename}\n"
 
     def add(self, string: str) -> None:
diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test
index e1818dc4c4bc..1e73df3ebb60 100644
--- a/test-data/unit/stubgen.test
+++ b/test-data/unit/stubgen.test
@@ -3317,3 +3317,73 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
 
 class X(_Incomplete): ...
 class Y(_Incomplete): ...
+
+
+[case testDataclass]
+import dataclasses
+import dataclasses as dcs
+from dataclasses import dataclass
+from dataclasses import dataclass as dc
+
+@dataclasses.dataclass
+class X:
+    a: int
+    b: str = "hello"
+    non_field = None
+
+@dcs.dataclass
+class Y: ...
+
+@dataclass
+class Z: ...
+
+@dc
+class W: ...
+
+[out]
+import dataclasses
+import dataclasses as dcs
+from dataclasses import dataclass, dataclass as dc
+
+@dataclasses.dataclass
+class X:
+    a: int
+    b: str = ...
+    non_field = ...
+
+@dcs.dataclass
+class Y: ...
+@dataclass
+class Z: ...
+@dc
+class W: ...
+
+[case testDataclassWithKeywords]
+from dataclasses import dataclass
+
+@dataclass(init=False)
+class X: ...
+
+[out]
+from dataclasses import dataclass
+
+@dataclass(init=False)
+class X: ...
+
+[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal]
+from dataclasses import dataclass
+
+@dataclass
+class X:
+    a: int
+    def __init__(self, a: int, b: str = ...) -> None: ...
+    def __post_init__(self) -> None: ...
+
+[out]
+from dataclasses import dataclass
+
+@dataclass
+class X:
+    a: int
+    def __init__(self, a: int, b: str = ...) -> None: ...
+    def __post_init__(self) -> None: ...

From 691b8e9550b52fb096c548eca0ac65bc7dc891f9 Mon Sep 17 00:00:00 2001
From: Ali Hamdan <ali.hamdan.dev@gmail.com>
Date: Fri, 14 Jul 2023 11:16:38 +0200
Subject: [PATCH 2/3] Keep `__init__` in dataclasses and add more tests

We cannot safely remove `__init__` and depend on the plugin
because its signature depends on dataclass field assignments to
`dataclasses.field` and these assignments are not included in the stub
---
 mypy/stubgen.py             | 15 ++++++-
 test-data/unit/stubgen.test | 90 ++++++++++++++++++++++++++++++++-----
 2 files changed, 93 insertions(+), 12 deletions(-)

diff --git a/mypy/stubgen.py b/mypy/stubgen.py
index 9758e4c14973..6d5134ec9ec4 100755
--- a/mypy/stubgen.py
+++ b/mypy/stubgen.py
@@ -701,8 +701,11 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
                 self.clear_decorators()
 
     def visit_func_def(self, o: FuncDef) -> None:
-        if self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated:
-            # Skip methods generated by the @dataclass decorator
+        is_dataclass_generated = (
+            self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
+        )
+        if is_dataclass_generated and o.name != "__init__":
+            # Skip methods generated by the @dataclass decorator (except for __init__)
             return
         if (
             self.is_private_name(o.name, o.fullname)
@@ -769,6 +772,12 @@ def visit_func_def(self, o: FuncDef) -> None:
             else:
                 arg = name + annotation
             args.append(arg)
+        if o.name == "__init__" and is_dataclass_generated and "**" in args:
+            # The dataclass plugin generates invalid nameless "*" and "**" arguments
+            new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "")
+            args[args.index("*")] = f"*{new_name}_"  # this name is guaranteed to be unique
+            args[args.index("**")] = f"**{new_name}__"  # same here
+
         retname = None
         if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType):
             if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType):
@@ -1413,6 +1422,8 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool:
             return False
         if fullname in EXTRA_EXPORTED:
             return False
+        if name == "_":
+            return False
         return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS)
 
     def is_private_member(self, fullname: str) -> bool:
diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test
index 1e73df3ebb60..5fed8cd2e61f 100644
--- a/test-data/unit/stubgen.test
+++ b/test-data/unit/stubgen.test
@@ -3318,17 +3318,25 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
 class X(_Incomplete): ...
 class Y(_Incomplete): ...
 
-
 [case testDataclass]
 import dataclasses
 import dataclasses as dcs
-from dataclasses import dataclass
+from dataclasses import dataclass, InitVar, KW_ONLY
 from dataclasses import dataclass as dc
+from typing import ClassVar
 
 @dataclasses.dataclass
 class X:
     a: int
     b: str = "hello"
+    c: ClassVar
+    d: ClassVar = 200
+    f: list[int] = field(init=False, default_factory=list)
+    g: int = field(default=2, kw_only=True)
+    _: KW_ONLY
+    h: int = 1
+    i: InitVar[str]
+    j: InitVar = 100
     non_field = None
 
 @dcs.dataclass
@@ -3340,15 +3348,27 @@ class Z: ...
 @dc
 class W: ...
 
+@dataclass(init=False, repr=False)
+class V: ...
+
 [out]
 import dataclasses
 import dataclasses as dcs
-from dataclasses import dataclass, dataclass as dc
+from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc
+from typing import ClassVar
 
 @dataclasses.dataclass
 class X:
     a: int
     b: str = ...
+    c: ClassVar
+    d: ClassVar = ...
+    f: list[int] = ...
+    g: int = ...
+    _: KW_ONLY
+    h: int = ...
+    i: InitVar[str]
+    j: InitVar = ...
     non_field = ...
 
 @dcs.dataclass
@@ -3357,18 +3377,51 @@ class Y: ...
 class Z: ...
 @dc
 class W: ...
+@dataclass(init=False, repr=False)
+class V: ...
 
-[case testDataclassWithKeywords]
-from dataclasses import dataclass
+[case testDataclass_semanal]
+from dataclasses import dataclass, InitVar, KW_ONLY
+from typing import ClassVar
 
-@dataclass(init=False)
-class X: ...
+@dataclass
+class X:
+    a: int
+    b: str = "hello"
+    c: ClassVar
+    d: ClassVar = 200
+    f: list[int] = field(init=False, default_factory=list)
+    g: int = field(default=2, kw_only=True)
+    _: KW_ONLY
+    h: int = 1
+    i: InitVar[str]
+    j: InitVar = 100
+    non_field = None
+
+@dataclass(init=False, repr=False, frozen=True)
+class Y: ...
 
 [out]
-from dataclasses import dataclass
+from dataclasses import InitVar, KW_ONLY, dataclass
+from typing import ClassVar
 
-@dataclass(init=False)
-class X: ...
+@dataclass
+class X:
+    a: int
+    b: str = ...
+    c: ClassVar
+    d: ClassVar = ...
+    f: list[int] = ...
+    g: int = ...
+    _: KW_ONLY
+    h: int = ...
+    i: InitVar[str]
+    j: InitVar = ...
+    non_field = ...
+    def __init__(self, a, b, f, g, *, h, i, j) -> None: ...
+
+@dataclass(init=False, repr=False, frozen=True)
+class Y: ...
 
 [case testDataclassWithExplicitGeneratedMethodsOverrides_semanal]
 from dataclasses import dataclass
@@ -3387,3 +3440,20 @@ class X:
     a: int
     def __init__(self, a: int, b: str = ...) -> None: ...
     def __post_init__(self) -> None: ...
+
+[case testDataclassInheritsFromAny_semanal]
+from dataclasses import dataclass
+import missing
+
+@dataclass
+class X(missing.Base):
+    a: int
+
+[out]
+import missing
+from dataclasses import dataclass
+
+@dataclass
+class X(missing.Base):
+    a: int
+    def __init__(self, *selfa_, a, **selfa__) -> None: ...

From be3028ea5558572fd644613efd1bf890d6f18def Mon Sep 17 00:00:00 2001
From: Ali Hamdan <ali.hamdan.dev@gmail.com>
Date: Sat, 26 Aug 2023 22:07:24 +0200
Subject: [PATCH 3/3] Fix tests running on older python versions

---
 mypy/test/teststubgen.py    | 11 ++++++++++
 test-data/unit/stubgen.test | 42 +++++++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py
index 79d380785a39..7e30515ac892 100644
--- a/mypy/test/teststubgen.py
+++ b/mypy/test/teststubgen.py
@@ -724,11 +724,22 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None:
 
     def parse_flags(self, program_text: str, extra: list[str]) -> Options:
         flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE)
+        pyversion = None
         if flags:
             flag_list = flags.group(1).split()
+            for i, flag in enumerate(flag_list):
+                if flag.startswith("--python-version="):
+                    pyversion = flag.split("=", 1)[1]
+                    del flag_list[i]
+                    break
         else:
             flag_list = []
         options = parse_options(flag_list + extra)
+        if pyversion:
+            # A hack to allow testing old python versions with new language constructs
+            # This should be rarely used in general as stubgen output should not be version-specific
+            major, minor = pyversion.split(".", 1)
+            options.pyversion = (int(major), int(minor))
         if "--verbose" not in flag_list:
             options.quiet = True
         else:
diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test
index 64b759d0c960..828680fadcf2 100644
--- a/test-data/unit/stubgen.test
+++ b/test-data/unit/stubgen.test
@@ -3576,6 +3576,48 @@ class W: ...
 class V: ...
 
 [case testDataclass_semanal]
+from dataclasses import dataclass, InitVar
+from typing import ClassVar
+
+@dataclass
+class X:
+    a: int
+    b: str = "hello"
+    c: ClassVar
+    d: ClassVar = 200
+    f: list[int] = field(init=False, default_factory=list)
+    g: int = field(default=2, kw_only=True)
+    h: int = 1
+    i: InitVar[str]
+    j: InitVar = 100
+    non_field = None
+
+@dataclass(init=False, repr=False, frozen=True)
+class Y: ...
+
+[out]
+from dataclasses import InitVar, dataclass
+from typing import ClassVar
+
+@dataclass
+class X:
+    a: int
+    b: str = ...
+    c: ClassVar
+    d: ClassVar = ...
+    f: list[int] = ...
+    g: int = ...
+    h: int = ...
+    i: InitVar[str]
+    j: InitVar = ...
+    non_field = ...
+    def __init__(self, a, b, f, g, h, i, j) -> None: ...
+
+@dataclass(init=False, repr=False, frozen=True)
+class Y: ...
+
+[case testDataclassWithKwOnlyField_semanal]
+# flags: --python-version=3.10
 from dataclasses import dataclass, InitVar, KW_ONLY
 from typing import ClassVar