Skip to content

Commit 4dd1d21

Browse files
hauntsaninjax612skm
authored andcommitted
Use a dict to keep track of TypedDict fields in semanal (python#18369)
Useful for python#7435
1 parent 91b4784 commit 4dd1d21

File tree

1 file changed

+58
-55
lines changed

1 file changed

+58
-55
lines changed

mypy/semanal_typeddict.py

+58-55
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Collection
56
from typing import Final
67

78
from mypy import errorcodes as codes, message_registry
@@ -97,21 +98,23 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
9798
existing_info = None
9899
if isinstance(defn.analyzed, TypedDictExpr):
99100
existing_info = defn.analyzed.info
101+
102+
field_types: dict[str, Type] | None
100103
if (
101104
len(defn.base_type_exprs) == 1
102105
and isinstance(defn.base_type_exprs[0], RefExpr)
103106
and defn.base_type_exprs[0].fullname in TPDICT_NAMES
104107
):
105108
# Building a new TypedDict
106-
fields, types, statements, required_keys, readonly_keys = (
109+
field_types, statements, required_keys, readonly_keys = (
107110
self.analyze_typeddict_classdef_fields(defn)
108111
)
109-
if fields is None:
112+
if field_types is None:
110113
return True, None # Defer
111114
if self.api.is_func_scope() and "@" not in defn.name:
112115
defn.name += "@" + str(defn.line)
113116
info = self.build_typeddict_typeinfo(
114-
defn.name, fields, types, required_keys, readonly_keys, defn.line, existing_info
117+
defn.name, field_types, required_keys, readonly_keys, defn.line, existing_info
115118
)
116119
defn.analyzed = TypedDictExpr(info)
117120
defn.analyzed.line = defn.line
@@ -154,26 +157,24 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
154157
else:
155158
self.fail("All bases of a new TypedDict must be TypedDict types", defn)
156159

157-
keys: list[str] = []
158-
types = []
160+
field_types = {}
159161
required_keys = set()
160162
readonly_keys = set()
161163
# Iterate over bases in reverse order so that leftmost base class' keys take precedence
162164
for base in reversed(typeddict_bases):
163165
self.add_keys_and_types_from_base(
164-
base, keys, types, required_keys, readonly_keys, defn
166+
base, field_types, required_keys, readonly_keys, defn
165167
)
166-
(new_keys, new_types, new_statements, new_required_keys, new_readonly_keys) = (
167-
self.analyze_typeddict_classdef_fields(defn, keys)
168+
(new_field_types, new_statements, new_required_keys, new_readonly_keys) = (
169+
self.analyze_typeddict_classdef_fields(defn, oldfields=field_types)
168170
)
169-
if new_keys is None:
171+
if new_field_types is None:
170172
return True, None # Defer
171-
keys.extend(new_keys)
172-
types.extend(new_types)
173+
field_types.update(new_field_types)
173174
required_keys.update(new_required_keys)
174175
readonly_keys.update(new_readonly_keys)
175176
info = self.build_typeddict_typeinfo(
176-
defn.name, keys, types, required_keys, readonly_keys, defn.line, existing_info
177+
defn.name, field_types, required_keys, readonly_keys, defn.line, existing_info
177178
)
178179
defn.analyzed = TypedDictExpr(info)
179180
defn.analyzed.line = defn.line
@@ -184,8 +185,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
184185
def add_keys_and_types_from_base(
185186
self,
186187
base: Expression,
187-
keys: list[str],
188-
types: list[Type],
188+
field_types: dict[str, Type],
189189
required_keys: set[str],
190190
readonly_keys: set[str],
191191
ctx: Context,
@@ -224,10 +224,10 @@ def add_keys_and_types_from_base(
224224
with state.strict_optional_set(self.options.strict_optional):
225225
valid_items = self.map_items_to_base(valid_items, tvars, base_args)
226226
for key in base_items:
227-
if key in keys:
227+
if key in field_types:
228228
self.fail(TYPEDDICT_OVERRIDE_MERGE.format(key), ctx)
229-
keys.extend(valid_items.keys())
230-
types.extend(valid_items.values())
229+
230+
field_types.update(valid_items)
231231
required_keys.update(base_typed_dict.required_keys)
232232
readonly_keys.update(base_typed_dict.readonly_keys)
233233

@@ -280,23 +280,34 @@ def map_items_to_base(
280280
return mapped_items
281281

282282
def analyze_typeddict_classdef_fields(
283-
self, defn: ClassDef, oldfields: list[str] | None = None
284-
) -> tuple[list[str] | None, list[Type], list[Statement], set[str], set[str]]:
283+
self, defn: ClassDef, oldfields: Collection[str] | None = None
284+
) -> tuple[dict[str, Type] | None, list[Statement], set[str], set[str]]:
285285
"""Analyze fields defined in a TypedDict class definition.
286286
287287
This doesn't consider inherited fields (if any). Also consider totality,
288288
if given.
289289
290290
Return tuple with these items:
291-
* List of keys (or None if found an incomplete reference --> deferral)
292-
* List of types for each key
291+
* Dict of key -> type (or None if found an incomplete reference -> deferral)
293292
* List of statements from defn.defs.body that are legally allowed to be a
294293
part of a TypedDict definition
295294
* Set of required keys
296295
"""
297-
fields: list[str] = []
298-
types: list[Type] = []
296+
fields: dict[str, Type] = {}
297+
readonly_keys = set[str]()
298+
required_keys = set[str]()
299299
statements: list[Statement] = []
300+
301+
total: bool | None = True
302+
for key in defn.keywords:
303+
if key == "total":
304+
total = require_bool_literal_argument(
305+
self.api, defn.keywords["total"], "total", True
306+
)
307+
continue
308+
for_function = ' for "__init_subclass__" of "TypedDict"'
309+
self.msg.unexpected_keyword_argument_for_function(for_function, key, defn)
310+
300311
for stmt in defn.defs.body:
301312
if not isinstance(stmt, AssignmentStmt):
302313
# Still allow pass or ... (for empty TypedDict's) and docstrings
@@ -320,10 +331,11 @@ def analyze_typeddict_classdef_fields(
320331
self.fail(f'Duplicate TypedDict key "{name}"', stmt)
321332
continue
322333
# Append stmt, name, and type in this case...
323-
fields.append(name)
324334
statements.append(stmt)
335+
336+
field_type: Type
325337
if stmt.unanalyzed_type is None:
326-
types.append(AnyType(TypeOfAny.unannotated))
338+
field_type = AnyType(TypeOfAny.unannotated)
327339
else:
328340
analyzed = self.api.anal_type(
329341
stmt.unanalyzed_type,
@@ -333,38 +345,27 @@ def analyze_typeddict_classdef_fields(
333345
prohibit_special_class_field_types="TypedDict",
334346
)
335347
if analyzed is None:
336-
return None, [], [], set(), set() # Need to defer
337-
types.append(analyzed)
348+
return None, [], set(), set() # Need to defer
349+
field_type = analyzed
338350
if not has_placeholder(analyzed):
339351
stmt.type = self.extract_meta_info(analyzed, stmt)[0]
352+
353+
field_type, required, readonly = self.extract_meta_info(field_type)
354+
fields[name] = field_type
355+
356+
if (total or required is True) and required is not False:
357+
required_keys.add(name)
358+
if readonly:
359+
readonly_keys.add(name)
360+
340361
# ...despite possible minor failures that allow further analysis.
341362
if stmt.type is None or hasattr(stmt, "new_syntax") and not stmt.new_syntax:
342363
self.fail(TPDICT_CLASS_ERROR, stmt)
343364
elif not isinstance(stmt.rvalue, TempNode):
344365
# x: int assigns rvalue to TempNode(AnyType())
345366
self.fail("Right hand side values are not supported in TypedDict", stmt)
346-
total: bool | None = True
347-
if "total" in defn.keywords:
348-
total = require_bool_literal_argument(self.api, defn.keywords["total"], "total", True)
349-
if defn.keywords and defn.keywords.keys() != {"total"}:
350-
for_function = ' for "__init_subclass__" of "TypedDict"'
351-
for key in defn.keywords:
352-
if key == "total":
353-
continue
354-
self.msg.unexpected_keyword_argument_for_function(for_function, key, defn)
355367

356-
res_types = []
357-
readonly_keys = set()
358-
required_keys = set()
359-
for field, t in zip(fields, types):
360-
typ, required, readonly = self.extract_meta_info(t)
361-
res_types.append(typ)
362-
if (total or required is True) and required is not False:
363-
required_keys.add(field)
364-
if readonly:
365-
readonly_keys.add(field)
366-
367-
return fields, res_types, statements, required_keys, readonly_keys
368+
return fields, statements, required_keys, readonly_keys
368369

369370
def extract_meta_info(
370371
self, typ: Type, context: Context | None = None
@@ -433,7 +434,7 @@ def check_typeddict(
433434
name += "@" + str(call.line)
434435
else:
435436
name = var_name = "TypedDict@" + str(call.line)
436-
info = self.build_typeddict_typeinfo(name, [], [], set(), set(), call.line, None)
437+
info = self.build_typeddict_typeinfo(name, {}, set(), set(), call.line, None)
437438
else:
438439
if var_name is not None and name != var_name:
439440
self.fail(
@@ -473,7 +474,12 @@ def check_typeddict(
473474
if isinstance(node.analyzed, TypedDictExpr):
474475
existing_info = node.analyzed.info
475476
info = self.build_typeddict_typeinfo(
476-
name, items, types, required_keys, readonly_keys, call.line, existing_info
477+
name,
478+
dict(zip(items, types)),
479+
required_keys,
480+
readonly_keys,
481+
call.line,
482+
existing_info,
477483
)
478484
info.line = node.line
479485
# Store generated TypeInfo under both names, see semanal_namedtuple for more details.
@@ -578,8 +584,7 @@ def fail_typeddict_arg(
578584
def build_typeddict_typeinfo(
579585
self,
580586
name: str,
581-
items: list[str],
582-
types: list[Type],
587+
item_types: dict[str, Type],
583588
required_keys: set[str],
584589
readonly_keys: set[str],
585590
line: int,
@@ -593,9 +598,7 @@ def build_typeddict_typeinfo(
593598
)
594599
assert fallback is not None
595600
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
596-
typeddict_type = TypedDictType(
597-
dict(zip(items, types)), required_keys, readonly_keys, fallback
598-
)
601+
typeddict_type = TypedDictType(item_types, required_keys, readonly_keys, fallback)
599602
if info.special_alias and has_placeholder(info.special_alias.target):
600603
self.api.process_placeholder(
601604
None, "TypedDict item", info, force_progress=typeddict_type != info.typeddict_type

0 commit comments

Comments
 (0)