Skip to content

Commit c677b1c

Browse files
committed
Use a dict to keep track of TypedDict fields in semanal
Useful for python#7435
1 parent ac6151a commit c677b1c

File tree

1 file changed

+50
-52
lines changed

1 file changed

+50
-52
lines changed

mypy/semanal_typeddict.py

+50-52
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Final
5+
from typing import Collection, Final
66

77
from mypy import errorcodes as codes, message_registry
88
from mypy.errorcodes import ErrorCode
@@ -97,21 +97,23 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
9797
existing_info = None
9898
if isinstance(defn.analyzed, TypedDictExpr):
9999
existing_info = defn.analyzed.info
100+
101+
field_types: dict[str, Type] | None
100102
if (
101103
len(defn.base_type_exprs) == 1
102104
and isinstance(defn.base_type_exprs[0], RefExpr)
103105
and defn.base_type_exprs[0].fullname in TPDICT_NAMES
104106
):
105107
# Building a new TypedDict
106-
fields, types, statements, required_keys, readonly_keys = (
108+
field_types, statements, required_keys, readonly_keys = (
107109
self.analyze_typeddict_classdef_fields(defn)
108110
)
109-
if fields is None:
111+
if field_types is None:
110112
return True, None # Defer
111113
if self.api.is_func_scope() and "@" not in defn.name:
112114
defn.name += "@" + str(defn.line)
113115
info = self.build_typeddict_typeinfo(
114-
defn.name, fields, types, required_keys, readonly_keys, defn.line, existing_info
116+
defn.name, field_types, required_keys, readonly_keys, defn.line, existing_info
115117
)
116118
defn.analyzed = TypedDictExpr(info)
117119
defn.analyzed.line = defn.line
@@ -154,26 +156,24 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
154156
else:
155157
self.fail("All bases of a new TypedDict must be TypedDict types", defn)
156158

157-
keys: list[str] = []
158-
types = []
159+
field_types = {}
159160
required_keys = set()
160161
readonly_keys = set()
161162
# Iterate over bases in reverse order so that leftmost base class' keys take precedence
162163
for base in reversed(typeddict_bases):
163164
self.add_keys_and_types_from_base(
164-
base, keys, types, required_keys, readonly_keys, defn
165+
base, field_types, required_keys, readonly_keys, defn
165166
)
166-
(new_keys, new_types, new_statements, new_required_keys, new_readonly_keys) = (
167-
self.analyze_typeddict_classdef_fields(defn, keys)
167+
(new_field_types, new_statements, new_required_keys, new_readonly_keys) = (
168+
self.analyze_typeddict_classdef_fields(defn, oldfields=field_types)
168169
)
169-
if new_keys is None:
170+
if new_field_types is None:
170171
return True, None # Defer
171-
keys.extend(new_keys)
172-
types.extend(new_types)
172+
field_types.update(new_field_types)
173173
required_keys.update(new_required_keys)
174174
readonly_keys.update(new_readonly_keys)
175175
info = self.build_typeddict_typeinfo(
176-
defn.name, keys, types, required_keys, readonly_keys, defn.line, existing_info
176+
defn.name, field_types, required_keys, readonly_keys, defn.line, existing_info
177177
)
178178
defn.analyzed = TypedDictExpr(info)
179179
defn.analyzed.line = defn.line
@@ -184,8 +184,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
184184
def add_keys_and_types_from_base(
185185
self,
186186
base: Expression,
187-
keys: list[str],
188-
types: list[Type],
187+
field_types: dict[str, Type],
189188
required_keys: set[str],
190189
readonly_keys: set[str],
191190
ctx: Context,
@@ -224,10 +223,10 @@ def add_keys_and_types_from_base(
224223
with state.strict_optional_set(self.options.strict_optional):
225224
valid_items = self.map_items_to_base(valid_items, tvars, base_args)
226225
for key in base_items:
227-
if key in keys:
226+
if key in field_types:
228227
self.fail(TYPEDDICT_OVERRIDE_MERGE.format(key), ctx)
229-
keys.extend(valid_items.keys())
230-
types.extend(valid_items.values())
228+
229+
field_types.update(valid_items)
231230
required_keys.update(base_typed_dict.required_keys)
232231
readonly_keys.update(base_typed_dict.readonly_keys)
233232

@@ -280,8 +279,8 @@ def map_items_to_base(
280279
return mapped_items
281280

282281
def analyze_typeddict_classdef_fields(
283-
self, defn: ClassDef, oldfields: list[str] | None = None
284-
) -> tuple[list[str] | None, list[Type], list[Statement], set[str], set[str]]:
282+
self, defn: ClassDef, oldfields: Collection[str] | None = None
283+
) -> tuple[dict[str, Type] | None, list[Statement], set[str], set[str]]:
285284
"""Analyze fields defined in a TypedDict class definition.
286285
287286
This doesn't consider inherited fields (if any). Also consider totality,
@@ -294,9 +293,19 @@ def analyze_typeddict_classdef_fields(
294293
part of a TypedDict definition
295294
* Set of required keys
296295
"""
297-
fields: list[str] = []
298-
types: list[Type] = []
296+
fields: dict[str, Type] = {}
297+
readonly_keys = set()
298+
required_keys = set()
299299
statements: list[Statement] = []
300+
301+
total: bool | None = True
302+
for key in defn.keywords:
303+
if key == "total":
304+
total = require_bool_literal_argument(self.api, defn.keywords["total"], "total", True)
305+
continue
306+
for_function = ' for "__init_subclass__" of "TypedDict"'
307+
self.msg.unexpected_keyword_argument_for_function(for_function, key, defn)
308+
300309
for stmt in defn.defs.body:
301310
if not isinstance(stmt, AssignmentStmt):
302311
# Still allow pass or ... (for empty TypedDict's) and docstrings
@@ -320,10 +329,11 @@ def analyze_typeddict_classdef_fields(
320329
self.fail(f'Duplicate TypedDict key "{name}"', stmt)
321330
continue
322331
# Append stmt, name, and type in this case...
323-
fields.append(name)
324332
statements.append(stmt)
333+
334+
field_type: Type
325335
if stmt.unanalyzed_type is None:
326-
types.append(AnyType(TypeOfAny.unannotated))
336+
field_type = AnyType(TypeOfAny.unannotated)
327337
else:
328338
analyzed = self.api.anal_type(
329339
stmt.unanalyzed_type,
@@ -333,38 +343,27 @@ def analyze_typeddict_classdef_fields(
333343
prohibit_special_class_field_types="TypedDict",
334344
)
335345
if analyzed is None:
336-
return None, [], [], set(), set() # Need to defer
337-
types.append(analyzed)
346+
return None, [], set(), set() # Need to defer
347+
field_type = analyzed
338348
if not has_placeholder(analyzed):
339349
stmt.type = self.extract_meta_info(analyzed, stmt)[0]
350+
351+
field_type, required, readonly = self.extract_meta_info(field_type)
352+
fields[name] = field_type
353+
354+
if (total or required is True) and required is not False:
355+
required_keys.add(name)
356+
if readonly:
357+
readonly_keys.add(name)
358+
340359
# ...despite possible minor failures that allow further analysis.
341360
if stmt.type is None or hasattr(stmt, "new_syntax") and not stmt.new_syntax:
342361
self.fail(TPDICT_CLASS_ERROR, stmt)
343362
elif not isinstance(stmt.rvalue, TempNode):
344363
# x: int assigns rvalue to TempNode(AnyType())
345364
self.fail("Right hand side values are not supported in TypedDict", stmt)
346-
total: bool | None = True
347-
if "total" in defn.keywords:
348-
total = require_bool_literal_argument(self.api, defn.keywords["total"], "total", True)
349-
if defn.keywords and defn.keywords.keys() != {"total"}:
350-
for_function = ' for "__init_subclass__" of "TypedDict"'
351-
for key in defn.keywords:
352-
if key == "total":
353-
continue
354-
self.msg.unexpected_keyword_argument_for_function(for_function, key, defn)
355-
356-
res_types = []
357-
readonly_keys = set()
358-
required_keys = set()
359-
for field, t in zip(fields, types):
360-
typ, required, readonly = self.extract_meta_info(t)
361-
res_types.append(typ)
362-
if (total or required is True) and required is not False:
363-
required_keys.add(field)
364-
if readonly:
365-
readonly_keys.add(field)
366365

367-
return fields, res_types, statements, required_keys, readonly_keys
366+
return fields, statements, required_keys, readonly_keys
368367

369368
def extract_meta_info(
370369
self, typ: Type, context: Context | None = None
@@ -433,7 +432,7 @@ def check_typeddict(
433432
name += "@" + str(call.line)
434433
else:
435434
name = var_name = "TypedDict@" + str(call.line)
436-
info = self.build_typeddict_typeinfo(name, [], [], set(), set(), call.line, None)
435+
info = self.build_typeddict_typeinfo(name, {}, set(), set(), call.line, None)
437436
else:
438437
if var_name is not None and name != var_name:
439438
self.fail(
@@ -473,7 +472,7 @@ def check_typeddict(
473472
if isinstance(node.analyzed, TypedDictExpr):
474473
existing_info = node.analyzed.info
475474
info = self.build_typeddict_typeinfo(
476-
name, items, types, required_keys, readonly_keys, call.line, existing_info
475+
name, dict(zip(items, types)), required_keys, readonly_keys, call.line, existing_info
477476
)
478477
info.line = node.line
479478
# Store generated TypeInfo under both names, see semanal_namedtuple for more details.
@@ -578,8 +577,7 @@ def fail_typeddict_arg(
578577
def build_typeddict_typeinfo(
579578
self,
580579
name: str,
581-
items: list[str],
582-
types: list[Type],
580+
item_types: dict[str, Type],
583581
required_keys: set[str],
584582
readonly_keys: set[str],
585583
line: int,
@@ -594,7 +592,7 @@ def build_typeddict_typeinfo(
594592
assert fallback is not None
595593
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
596594
typeddict_type = TypedDictType(
597-
dict(zip(items, types)), required_keys, readonly_keys, fallback
595+
item_types, required_keys, readonly_keys, fallback
598596
)
599597
if info.special_alias and has_placeholder(info.special_alias.target):
600598
self.api.process_placeholder(

0 commit comments

Comments
 (0)