@@ -657,6 +657,7 @@ def __init__(
657
657
self .defined_names : set [str ] = set ()
658
658
# Short names of methods defined in the body of the current class
659
659
self .method_names : set [str ] = set ()
660
+ self .processing_dataclass = False
660
661
661
662
def visit_mypy_file (self , o : MypyFile ) -> None :
662
663
self .module = o .fullname # Current module being processed
@@ -706,6 +707,12 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
706
707
self .clear_decorators ()
707
708
708
709
def visit_func_def (self , o : FuncDef ) -> None :
710
+ is_dataclass_generated = (
711
+ self .analyzed and self .processing_dataclass and o .info .names [o .name ].plugin_generated
712
+ )
713
+ if is_dataclass_generated and o .name != "__init__" :
714
+ # Skip methods generated by the @dataclass decorator (except for __init__)
715
+ return
709
716
if (
710
717
self .is_private_name (o .name , o .fullname )
711
718
or self .is_not_in_all (o .name )
@@ -771,6 +778,12 @@ def visit_func_def(self, o: FuncDef) -> None:
771
778
else :
772
779
arg = name + annotation
773
780
args .append (arg )
781
+ if o .name == "__init__" and is_dataclass_generated and "**" in args :
782
+ # The dataclass plugin generates invalid nameless "*" and "**" arguments
783
+ new_name = "" .join (a .split (":" , 1 )[0 ] for a in args ).replace ("*" , "" )
784
+ args [args .index ("*" )] = f"*{ new_name } _" # this name is guaranteed to be unique
785
+ args [args .index ("**" )] = f"**{ new_name } __" # same here
786
+
774
787
retname = None
775
788
if o .name != "__init__" and isinstance (o .unanalyzed_type , CallableType ):
776
789
if isinstance (get_proper_type (o .unanalyzed_type .ret_type ), AnyType ):
@@ -899,6 +912,9 @@ def visit_class_def(self, o: ClassDef) -> None:
899
912
if not self ._indent and self ._state != EMPTY :
900
913
sep = len (self ._output )
901
914
self .add ("\n " )
915
+ decorators = self .get_class_decorators (o )
916
+ for d in decorators :
917
+ self .add (f"{ self ._indent } @{ d } \n " )
902
918
self .add (f"{ self ._indent } class { o .name } " )
903
919
self .record_name (o .name )
904
920
base_types = self .get_base_types (o )
@@ -934,6 +950,7 @@ def visit_class_def(self, o: ClassDef) -> None:
934
950
else :
935
951
self ._state = CLASS
936
952
self .method_names = set ()
953
+ self .processing_dataclass = False
937
954
self ._current_class = None
938
955
939
956
def get_base_types (self , cdef : ClassDef ) -> list [str ]:
@@ -979,6 +996,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
979
996
base_types .append (f"{ name } ={ value .accept (p )} " )
980
997
return base_types
981
998
999
+ def get_class_decorators (self , cdef : ClassDef ) -> list [str ]:
1000
+ decorators : list [str ] = []
1001
+ p = AliasPrinter (self )
1002
+ for d in cdef .decorators :
1003
+ if self .is_dataclass (d ):
1004
+ decorators .append (d .accept (p ))
1005
+ self .import_tracker .require_name (get_qualified_name (d ))
1006
+ self .processing_dataclass = True
1007
+ return decorators
1008
+
1009
+ def is_dataclass (self , expr : Expression ) -> bool :
1010
+ if isinstance (expr , CallExpr ):
1011
+ expr = expr .callee
1012
+ return self .get_fullname (expr ) == "dataclasses.dataclass"
1013
+
982
1014
def visit_block (self , o : Block ) -> None :
983
1015
# Unreachable statements may be partially uninitialized and that may
984
1016
# cause trouble.
@@ -1336,19 +1368,30 @@ def get_init(
1336
1368
# Final without type argument is invalid in stubs.
1337
1369
final_arg = self .get_str_type_of_node (rvalue )
1338
1370
typename += f"[{ final_arg } ]"
1371
+ elif self .processing_dataclass :
1372
+ # attribute without annotation is not a dataclass field, don't add annotation.
1373
+ return f"{ self ._indent } { lvalue } = ...\n "
1339
1374
else :
1340
1375
typename = self .get_str_type_of_node (rvalue )
1341
1376
initializer = self .get_assign_initializer (rvalue )
1342
1377
return f"{ self ._indent } { lvalue } : { typename } { initializer } \n "
1343
1378
1344
1379
def get_assign_initializer (self , rvalue : Expression ) -> str :
1345
1380
"""Does this rvalue need some special initializer value?"""
1346
- if self ._current_class and self ._current_class .info :
1347
- # Current rules
1348
- # 1. Return `...` if we are dealing with `NamedTuple` and it has an existing default value
1349
- if self ._current_class .info .is_named_tuple and not isinstance (rvalue , TempNode ):
1350
- return " = ..."
1351
- # TODO: support other possible cases, where initializer is important
1381
+ if not self ._current_class :
1382
+ return ""
1383
+ # Current rules
1384
+ # 1. Return `...` if we are dealing with `NamedTuple` or `dataclass` field and
1385
+ # it has an existing default value
1386
+ if (
1387
+ self ._current_class .info
1388
+ and self ._current_class .info .is_named_tuple
1389
+ and not isinstance (rvalue , TempNode )
1390
+ ):
1391
+ return " = ..."
1392
+ if self .processing_dataclass and not (isinstance (rvalue , TempNode ) and rvalue .no_rhs ):
1393
+ return " = ..."
1394
+ # TODO: support other possible cases, where initializer is important
1352
1395
1353
1396
# By default, no initializer is required:
1354
1397
return ""
@@ -1410,6 +1453,8 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool:
1410
1453
return False
1411
1454
if fullname in EXTRA_EXPORTED :
1412
1455
return False
1456
+ if name == "_" :
1457
+ return False
1413
1458
return name .startswith ("_" ) and (not name .endswith ("__" ) or name in IGNORED_DUNDERS )
1414
1459
1415
1460
def is_private_member (self , fullname : str ) -> bool :
0 commit comments