Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a dict to keep track of TypedDict fields in semanal #18369

Merged
merged 6 commits into from
Jan 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 58 additions & 55 deletions mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Collection
from typing import Final

from mypy import errorcodes as codes, message_registry
Expand Down Expand Up @@ -97,21 +98,23 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
existing_info = None
if isinstance(defn.analyzed, TypedDictExpr):
existing_info = defn.analyzed.info

field_types: dict[str, Type] | None
if (
len(defn.base_type_exprs) == 1
and isinstance(defn.base_type_exprs[0], RefExpr)
and defn.base_type_exprs[0].fullname in TPDICT_NAMES
):
# Building a new TypedDict
fields, types, statements, required_keys, readonly_keys = (
field_types, statements, required_keys, readonly_keys = (
self.analyze_typeddict_classdef_fields(defn)
)
if fields is None:
if field_types is None:
return True, None # Defer
if self.api.is_func_scope() and "@" not in defn.name:
defn.name += "@" + str(defn.line)
info = self.build_typeddict_typeinfo(
defn.name, fields, types, required_keys, readonly_keys, defn.line, existing_info
defn.name, field_types, required_keys, readonly_keys, defn.line, existing_info
)
defn.analyzed = TypedDictExpr(info)
defn.analyzed.line = defn.line
Expand Down Expand Up @@ -154,26 +157,24 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
else:
self.fail("All bases of a new TypedDict must be TypedDict types", defn)

keys: list[str] = []
types = []
field_types = {}
required_keys = set()
readonly_keys = set()
# Iterate over bases in reverse order so that leftmost base class' keys take precedence
for base in reversed(typeddict_bases):
self.add_keys_and_types_from_base(
base, keys, types, required_keys, readonly_keys, defn
base, field_types, required_keys, readonly_keys, defn
)
(new_keys, new_types, new_statements, new_required_keys, new_readonly_keys) = (
self.analyze_typeddict_classdef_fields(defn, keys)
(new_field_types, new_statements, new_required_keys, new_readonly_keys) = (
self.analyze_typeddict_classdef_fields(defn, oldfields=field_types)
)
if new_keys is None:
if new_field_types is None:
return True, None # Defer
keys.extend(new_keys)
types.extend(new_types)
field_types.update(new_field_types)
required_keys.update(new_required_keys)
readonly_keys.update(new_readonly_keys)
info = self.build_typeddict_typeinfo(
defn.name, keys, types, required_keys, readonly_keys, defn.line, existing_info
defn.name, field_types, required_keys, readonly_keys, defn.line, existing_info
)
defn.analyzed = TypedDictExpr(info)
defn.analyzed.line = defn.line
Expand All @@ -184,8 +185,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
def add_keys_and_types_from_base(
self,
base: Expression,
keys: list[str],
types: list[Type],
field_types: dict[str, Type],
required_keys: set[str],
readonly_keys: set[str],
ctx: Context,
Expand Down Expand Up @@ -224,10 +224,10 @@ def add_keys_and_types_from_base(
with state.strict_optional_set(self.options.strict_optional):
valid_items = self.map_items_to_base(valid_items, tvars, base_args)
for key in base_items:
if key in keys:
if key in field_types:
self.fail(TYPEDDICT_OVERRIDE_MERGE.format(key), ctx)
keys.extend(valid_items.keys())
types.extend(valid_items.values())

field_types.update(valid_items)
required_keys.update(base_typed_dict.required_keys)
readonly_keys.update(base_typed_dict.readonly_keys)

Expand Down Expand Up @@ -280,23 +280,34 @@ def map_items_to_base(
return mapped_items

def analyze_typeddict_classdef_fields(
self, defn: ClassDef, oldfields: list[str] | None = None
) -> tuple[list[str] | None, list[Type], list[Statement], set[str], set[str]]:
self, defn: ClassDef, oldfields: Collection[str] | None = None
) -> tuple[dict[str, Type] | None, list[Statement], set[str], set[str]]:
Comment on lines 282 to +284
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring here should be updated.

"""Analyze fields defined in a TypedDict class definition.

This doesn't consider inherited fields (if any). Also consider totality,
if given.

Return tuple with these items:
* List of keys (or None if found an incomplete reference --> deferral)
* List of types for each key
* Dict of key -> type (or None if found an incomplete reference -> deferral)
* List of statements from defn.defs.body that are legally allowed to be a
part of a TypedDict definition
* Set of required keys
"""
fields: list[str] = []
types: list[Type] = []
fields: dict[str, Type] = {}
readonly_keys = set[str]()
required_keys = set[str]()
statements: list[Statement] = []

total: bool | None = True
for key in defn.keywords:
if key == "total":
total = require_bool_literal_argument(
self.api, defn.keywords["total"], "total", True
)
continue
for_function = ' for "__init_subclass__" of "TypedDict"'
self.msg.unexpected_keyword_argument_for_function(for_function, key, defn)

for stmt in defn.defs.body:
if not isinstance(stmt, AssignmentStmt):
# Still allow pass or ... (for empty TypedDict's) and docstrings
Expand All @@ -320,10 +331,11 @@ def analyze_typeddict_classdef_fields(
self.fail(f'Duplicate TypedDict key "{name}"', stmt)
continue
# Append stmt, name, and type in this case...
fields.append(name)
statements.append(stmt)

field_type: Type
if stmt.unanalyzed_type is None:
types.append(AnyType(TypeOfAny.unannotated))
field_type = AnyType(TypeOfAny.unannotated)
else:
analyzed = self.api.anal_type(
stmt.unanalyzed_type,
Expand All @@ -333,38 +345,27 @@ def analyze_typeddict_classdef_fields(
prohibit_special_class_field_types="TypedDict",
)
if analyzed is None:
return None, [], [], set(), set() # Need to defer
types.append(analyzed)
return None, [], set(), set() # Need to defer
field_type = analyzed
if not has_placeholder(analyzed):
stmt.type = self.extract_meta_info(analyzed, stmt)[0]

field_type, required, readonly = self.extract_meta_info(field_type)
fields[name] = field_type

if (total or required is True) and required is not False:
required_keys.add(name)
if readonly:
readonly_keys.add(name)

# ...despite possible minor failures that allow further analysis.
if stmt.type is None or hasattr(stmt, "new_syntax") and not stmt.new_syntax:
self.fail(TPDICT_CLASS_ERROR, stmt)
elif not isinstance(stmt.rvalue, TempNode):
# x: int assigns rvalue to TempNode(AnyType())
self.fail("Right hand side values are not supported in TypedDict", stmt)
total: bool | None = True
if "total" in defn.keywords:
total = require_bool_literal_argument(self.api, defn.keywords["total"], "total", True)
if defn.keywords and defn.keywords.keys() != {"total"}:
for_function = ' for "__init_subclass__" of "TypedDict"'
for key in defn.keywords:
if key == "total":
continue
self.msg.unexpected_keyword_argument_for_function(for_function, key, defn)

res_types = []
readonly_keys = set()
required_keys = set()
for field, t in zip(fields, types):
typ, required, readonly = self.extract_meta_info(t)
res_types.append(typ)
if (total or required is True) and required is not False:
required_keys.add(field)
if readonly:
readonly_keys.add(field)

return fields, res_types, statements, required_keys, readonly_keys
return fields, statements, required_keys, readonly_keys

def extract_meta_info(
self, typ: Type, context: Context | None = None
Expand Down Expand Up @@ -433,7 +434,7 @@ def check_typeddict(
name += "@" + str(call.line)
else:
name = var_name = "TypedDict@" + str(call.line)
info = self.build_typeddict_typeinfo(name, [], [], set(), set(), call.line, None)
info = self.build_typeddict_typeinfo(name, {}, set(), set(), call.line, None)
else:
if var_name is not None and name != var_name:
self.fail(
Expand Down Expand Up @@ -473,7 +474,12 @@ def check_typeddict(
if isinstance(node.analyzed, TypedDictExpr):
existing_info = node.analyzed.info
info = self.build_typeddict_typeinfo(
name, items, types, required_keys, readonly_keys, call.line, existing_info
name,
dict(zip(items, types)),
required_keys,
readonly_keys,
call.line,
existing_info,
)
info.line = node.line
# Store generated TypeInfo under both names, see semanal_namedtuple for more details.
Expand Down Expand Up @@ -578,8 +584,7 @@ def fail_typeddict_arg(
def build_typeddict_typeinfo(
self,
name: str,
items: list[str],
types: list[Type],
item_types: dict[str, Type],
required_keys: set[str],
readonly_keys: set[str],
line: int,
Expand All @@ -593,9 +598,7 @@ def build_typeddict_typeinfo(
)
assert fallback is not None
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
typeddict_type = TypedDictType(
dict(zip(items, types)), required_keys, readonly_keys, fallback
)
typeddict_type = TypedDictType(item_types, required_keys, readonly_keys, fallback)
if info.special_alias and has_placeholder(info.special_alias.target):
self.api.process_placeholder(
None, "TypedDict item", info, force_progress=typeddict_type != info.typeddict_type
Expand Down
Loading