Skip to content

Commit 0d708cb

Browse files
New type inference: complete transitive closure (#15754)
This is a first follow-up for #15287 (I like how my PR titles sound like research paper titles, LOL) This PR completes the new type inference foundations by switching to a complete and well founded algorithm [1] for transitive closure (that replaces more ad hoc initial algorithm that covered 80% of cases and was good for experimenting with new inference scheme). In particular the algorithm in this PR covers two important edge cases (see tests). Some comments: * I don't intend to switch the default for `--new-type-inference`, I just want to see the effect of the switch on `mypy_primer`, I will switch back to false before merging * This flag is still not ready to be publicly announced, I am going to make another 2-3 PRs from the list in #15287 before making this public. * I am not adding yet the unit tests as discussed in previous PR. This PR is already quite big, and the next one (support for upper bounds and values) should be much smaller. I am going to add unit tests only for `transitive_closure()` which is the core of new logic. * While working on this I fixed couple bugs exposed in `TypeVarTuple` support: one is rare technical corner case, another one is serious, template and actual where swapped during constraint inference, effectively causing outer/return context to be completely ignored for instances. * It is better to review the PR with "ignore whitespace" option turned on (there is big chunk in solve.py that is just change of indentation). * There is one questionable design choice I am making in this PR, I am adding `extra_tvars` as an attribute of `Constraint` class, while it logically should not be attributed to any individual constraint, but rather to the full list of constrains. However, doing this properly would require changing the return type of `infer_constrains()` and all related functions, which would be a really big refactoring. [1] Definition 7.1 in https://inria.hal.science/inria-00073205/document --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2b613e5 commit 0d708cb

13 files changed

+356
-319
lines changed

mypy/checker.py

+26-46
Original file line numberDiff line numberDiff line change
@@ -734,8 +734,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
734734
# def foo(x: str) -> str: ...
735735
#
736736
# See Python 2's map function for a concrete example of this kind of overload.
737+
current_class = self.scope.active_class()
738+
type_vars = current_class.defn.type_vars if current_class else []
737739
with state.strict_optional_set(True):
738-
if is_unsafe_overlapping_overload_signatures(sig1, sig2):
740+
if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars):
739741
self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func)
740742

741743
if impl_type is not None:
@@ -1702,7 +1704,9 @@ def is_unsafe_overlapping_op(
17021704
first = forward_tweaked
17031705
second = reverse_tweaked
17041706

1705-
return is_unsafe_overlapping_overload_signatures(first, second)
1707+
current_class = self.scope.active_class()
1708+
type_vars = current_class.defn.type_vars if current_class else []
1709+
return is_unsafe_overlapping_overload_signatures(first, second, type_vars)
17061710

17071711
def check_inplace_operator_method(self, defn: FuncBase) -> None:
17081712
"""Check an inplace operator method such as __iadd__.
@@ -3918,11 +3922,12 @@ def is_valid_defaultdict_partial_value_type(self, t: ProperType) -> bool:
39183922
return True
39193923
if len(t.args) == 1:
39203924
arg = get_proper_type(t.args[0])
3921-
# TODO: This is too permissive -- we only allow TypeVarType since
3922-
# they leak in cases like defaultdict(list) due to a bug.
3923-
# This can result in incorrect types being inferred, but only
3924-
# in rare cases.
3925-
if isinstance(arg, (TypeVarType, UninhabitedType, NoneType)):
3925+
if self.options.new_type_inference:
3926+
allowed = isinstance(arg, (UninhabitedType, NoneType))
3927+
else:
3928+
# Allow leaked TypeVars for legacy inference logic.
3929+
allowed = isinstance(arg, (UninhabitedType, NoneType, TypeVarType))
3930+
if allowed:
39263931
return True
39273932
return False
39283933

@@ -7179,7 +7184,7 @@ def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool:
71797184

71807185

71817186
def is_unsafe_overlapping_overload_signatures(
7182-
signature: CallableType, other: CallableType
7187+
signature: CallableType, other: CallableType, class_type_vars: list[TypeVarLikeType]
71837188
) -> bool:
71847189
"""Check if two overloaded signatures are unsafely overlapping or partially overlapping.
71857190
@@ -7198,8 +7203,8 @@ def is_unsafe_overlapping_overload_signatures(
71987203
# This lets us identify cases where the two signatures use completely
71997204
# incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars
72007205
# test case.
7201-
signature = detach_callable(signature)
7202-
other = detach_callable(other)
7206+
signature = detach_callable(signature, class_type_vars)
7207+
other = detach_callable(other, class_type_vars)
72037208

72047209
# Note: We repeat this check twice in both directions due to a slight
72057210
# asymmetry in 'is_callable_compatible'. When checking for partial overlaps,
@@ -7230,7 +7235,7 @@ def is_unsafe_overlapping_overload_signatures(
72307235
)
72317236

72327237

7233-
def detach_callable(typ: CallableType) -> CallableType:
7238+
def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -> CallableType:
72347239
"""Ensures that the callable's type variables are 'detached' and independent of the context.
72357240
72367241
A callable normally keeps track of the type variables it uses within its 'variables' field.
@@ -7240,42 +7245,17 @@ def detach_callable(typ: CallableType) -> CallableType:
72407245
This function will traverse the callable and find all used type vars and add them to the
72417246
variables field if it isn't already present.
72427247
7243-
The caller can then unify on all type variables whether or not the callable is originally
7244-
from a class or not."""
7245-
type_list = typ.arg_types + [typ.ret_type]
7246-
7247-
appear_map: dict[str, list[int]] = {}
7248-
for i, inner_type in enumerate(type_list):
7249-
typevars_available = get_type_vars(inner_type)
7250-
for var in typevars_available:
7251-
if var.fullname not in appear_map:
7252-
appear_map[var.fullname] = []
7253-
appear_map[var.fullname].append(i)
7254-
7255-
used_type_var_names = set()
7256-
for var_name, appearances in appear_map.items():
7257-
used_type_var_names.add(var_name)
7258-
7259-
all_type_vars = get_type_vars(typ)
7260-
new_variables = []
7261-
for var in set(all_type_vars):
7262-
if var.fullname not in used_type_var_names:
7263-
continue
7264-
new_variables.append(
7265-
TypeVarType(
7266-
name=var.name,
7267-
fullname=var.fullname,
7268-
id=var.id,
7269-
values=var.values,
7270-
upper_bound=var.upper_bound,
7271-
default=var.default,
7272-
variance=var.variance,
7273-
)
7274-
)
7275-
out = typ.copy_modified(
7276-
variables=new_variables, arg_types=type_list[:-1], ret_type=type_list[-1]
7248+
The caller can then unify on all type variables whether the callable is originally from
7249+
the class or not."""
7250+
if not class_type_vars:
7251+
# Fast path, nothing to update.
7252+
return typ
7253+
seen_type_vars = set()
7254+
for t in typ.arg_types + [typ.ret_type]:
7255+
seen_type_vars |= set(get_type_vars(t))
7256+
return typ.copy_modified(
7257+
variables=list(typ.variables) + [tv for tv in class_type_vars if tv in seen_type_vars]
72777258
)
7278-
return out
72797259

72807260

72817261
def overload_can_never_match(signature: CallableType, other: CallableType) -> bool:

mypy/checkexpr.py

+19-21
Original file line numberDiff line numberDiff line change
@@ -1857,7 +1857,7 @@ def infer_function_type_arguments_using_context(
18571857
# expects_literal(identity(3)) # Should type-check
18581858
if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
18591859
return callable.copy_modified()
1860-
args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx)
1860+
args = infer_type_arguments(callable.variables, ret_type, erased_ctx)
18611861
# Only substitute non-Uninhabited and non-erased types.
18621862
new_args: list[Type | None] = []
18631863
for arg in args:
@@ -1906,7 +1906,7 @@ def infer_function_type_arguments(
19061906
else:
19071907
pass1_args.append(arg)
19081908

1909-
inferred_args = infer_function_type_arguments(
1909+
inferred_args, _ = infer_function_type_arguments(
19101910
callee_type,
19111911
pass1_args,
19121912
arg_kinds,
@@ -1948,7 +1948,7 @@ def infer_function_type_arguments(
19481948
# variables while allowing for polymorphic solutions, i.e. for solutions
19491949
# potentially involving free variables.
19501950
# TODO: support the similar inference for return type context.
1951-
poly_inferred_args = infer_function_type_arguments(
1951+
poly_inferred_args, free_vars = infer_function_type_arguments(
19521952
callee_type,
19531953
arg_types,
19541954
arg_kinds,
@@ -1957,30 +1957,28 @@ def infer_function_type_arguments(
19571957
strict=self.chk.in_checked_function(),
19581958
allow_polymorphic=True,
19591959
)
1960-
for i, pa in enumerate(get_proper_types(poly_inferred_args)):
1961-
if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa):
1962-
# Indicate that free variables should not be applied in the call below.
1963-
poly_inferred_args[i] = None
19641960
poly_callee_type = self.apply_generic_arguments(
19651961
callee_type, poly_inferred_args, context
19661962
)
1967-
yes_vars = poly_callee_type.variables
1968-
no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables}
1969-
if not set(get_type_vars(poly_callee_type)) & no_vars:
1970-
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
1971-
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
1972-
applied = apply_poly(poly_callee_type, yes_vars)
1973-
if applied is not None and poly_inferred_args != [UninhabitedType()] * len(
1974-
poly_inferred_args
1975-
):
1976-
freeze_all_type_vars(applied)
1977-
return applied
1963+
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
1964+
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
1965+
applied = apply_poly(poly_callee_type, free_vars)
1966+
if applied is not None and all(
1967+
a is not None and not isinstance(get_proper_type(a), UninhabitedType)
1968+
for a in poly_inferred_args
1969+
):
1970+
freeze_all_type_vars(applied)
1971+
return applied
19781972
# If it didn't work, erase free variables as <nothing>, to avoid confusing errors.
1973+
unknown = UninhabitedType()
1974+
unknown.ambiguous = True
19791975
inferred_args = [
1980-
expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables})
1976+
expand_type(
1977+
a, {v.id: unknown for v in list(callee_type.variables) + free_vars}
1978+
)
19811979
if a is not None
19821980
else None
1983-
for a in inferred_args
1981+
for a in poly_inferred_args
19841982
]
19851983
else:
19861984
# In dynamically typed functions use implicit 'Any' types for
@@ -2019,7 +2017,7 @@ def infer_function_type_arguments_pass2(
20192017

20202018
arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual)
20212019

2022-
inferred_args = infer_function_type_arguments(
2020+
inferred_args, _ = infer_function_type_arguments(
20232021
callee_type,
20242022
arg_types,
20252023
arg_kinds,

mypy/constraints.py

+41-22
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None:
7373
self.op = op
7474
self.target = target
7575
self.origin_type_var = type_var
76+
# These are additional type variables that should be solved for together with type_var.
77+
# TODO: A cleaner solution may be to modify the return type of infer_constraints()
78+
# to include these instead, but this is a rather big refactoring.
79+
self.extra_tvars: list[TypeVarLikeType] = []
7680

7781
def __repr__(self) -> str:
7882
op_str = "<:"
@@ -168,7 +172,9 @@ def infer_constraints_for_callable(
168172
return constraints
169173

170174

171-
def infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:
175+
def infer_constraints(
176+
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
177+
) -> list[Constraint]:
172178
"""Infer type constraints.
173179
174180
Match a template type, which may contain type variable references,
@@ -187,7 +193,9 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons
187193
((T, S), (X, Y)) --> T :> X and S :> Y
188194
(X[T], Any) --> T <: Any and T :> Any
189195
190-
The constraints are represented as Constraint objects.
196+
The constraints are represented as Constraint objects. If skip_neg_op == True,
197+
then skip adding reverse (polymorphic) constraints (since this is already a call
198+
to infer such constraints).
191199
"""
192200
if any(
193201
get_proper_type(template) == get_proper_type(t)
@@ -202,13 +210,15 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons
202210
# Return early on an empty branch.
203211
return []
204212
type_state.inferring.append((template, actual))
205-
res = _infer_constraints(template, actual, direction)
213+
res = _infer_constraints(template, actual, direction, skip_neg_op)
206214
type_state.inferring.pop()
207215
return res
208-
return _infer_constraints(template, actual, direction)
216+
return _infer_constraints(template, actual, direction, skip_neg_op)
209217

210218

211-
def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:
219+
def _infer_constraints(
220+
template: Type, actual: Type, direction: int, skip_neg_op: bool
221+
) -> list[Constraint]:
212222
orig_template = template
213223
template = get_proper_type(template)
214224
actual = get_proper_type(actual)
@@ -284,7 +294,7 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Con
284294
return []
285295

286296
# Remaining cases are handled by ConstraintBuilderVisitor.
287-
return template.accept(ConstraintBuilderVisitor(actual, direction))
297+
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op))
288298

289299

290300
def infer_constraints_if_possible(
@@ -510,10 +520,14 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):
510520
# TODO: The value may be None. Is that actually correct?
511521
actual: ProperType
512522

513-
def __init__(self, actual: ProperType, direction: int) -> None:
523+
def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None:
514524
# Direction must be SUBTYPE_OF or SUPERTYPE_OF.
515525
self.actual = actual
516526
self.direction = direction
527+
# Whether to skip polymorphic inference (involves inference in opposite direction)
528+
# this is used to prevent infinite recursion when both template and actual are
529+
# generic callables.
530+
self.skip_neg_op = skip_neg_op
517531

518532
# Trivial leaf types
519533

@@ -648,13 +662,13 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
648662
assert mapped.type.type_var_tuple_prefix is not None
649663
assert mapped.type.type_var_tuple_suffix is not None
650664

651-
unpack_constraints, mapped_args, instance_args = build_constraints_for_unpack(
652-
mapped.args,
653-
mapped.type.type_var_tuple_prefix,
654-
mapped.type.type_var_tuple_suffix,
665+
unpack_constraints, instance_args, mapped_args = build_constraints_for_unpack(
655666
instance.args,
656667
instance.type.type_var_tuple_prefix,
657668
instance.type.type_var_tuple_suffix,
669+
mapped.args,
670+
mapped.type.type_var_tuple_prefix,
671+
mapped.type.type_var_tuple_suffix,
658672
self.direction,
659673
)
660674
res.extend(unpack_constraints)
@@ -879,6 +893,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
879893
# Note that non-normalized callables can be created in annotations
880894
# using e.g. callback protocols.
881895
template = template.with_unpacked_kwargs()
896+
extra_tvars = False
882897
if isinstance(self.actual, CallableType):
883898
res: list[Constraint] = []
884899
cactual = self.actual.with_unpacked_kwargs()
@@ -890,25 +905,23 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
890905
type_state.infer_polymorphic
891906
and cactual.variables
892907
and cactual.param_spec() is None
908+
and not self.skip_neg_op
893909
# Technically, the correct inferred type for application of e.g.
894910
# Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic
895911
# like U -> U, should be Callable[..., Any], but if U is a self-type, we can
896912
# allow it to leak, to be later bound to self. A bunch of existing code
897913
# depends on this old behaviour.
898914
and not any(tv.id.raw_id == 0 for tv in cactual.variables)
899915
):
900-
# If actual is generic, unify it with template. Note: this is
901-
# not an ideal solution (which would be adding the generic variables
902-
# to the constraint inference set), but it's a good first approximation,
903-
# and this will prevent leaking these variables in the solutions.
904-
# Note: this may infer constraints like T <: S or T <: List[S]
905-
# that contain variables in the target.
906-
unified = mypy.subtypes.unify_generic_callable(
907-
cactual, template, ignore_return=True
916+
# If the actual callable is generic, infer constraints in the opposite
917+
# direction, and indicate to the solver there are extra type variables
918+
# to solve for (see more details in mypy/solve.py).
919+
res.extend(
920+
infer_constraints(
921+
cactual, template, neg_op(self.direction), skip_neg_op=True
922+
)
908923
)
909-
if unified is not None:
910-
cactual = unified
911-
res.extend(infer_constraints(cactual, template, neg_op(self.direction)))
924+
extra_tvars = True
912925

913926
# We can't infer constraints from arguments if the template is Callable[..., T]
914927
# (with literal '...').
@@ -978,6 +991,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
978991
cactual_ret_type = cactual.type_guard
979992

980993
res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
994+
if extra_tvars:
995+
for c in res:
996+
c.extra_tvars = list(cactual.variables)
981997
return res
982998
elif isinstance(self.actual, AnyType):
983999
param_spec = template.param_spec()
@@ -1205,6 +1221,9 @@ def find_and_build_constraints_for_unpack(
12051221

12061222

12071223
def build_constraints_for_unpack(
1224+
# TODO: this naming is misleading, these should be "actual", not "mapped"
1225+
# both template and actual can be mapped before, depending on direction.
1226+
# Also the convention is to put template related args first.
12081227
mapped: tuple[Type, ...],
12091228
mapped_prefix_len: int | None,
12101229
mapped_suffix_len: int | None,

mypy/expandtype.py

+5
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,11 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
272272
return repl
273273

274274
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
275+
# Sometimes solver may need to expand a type variable with (a copy of) itself
276+
# (usually together with other TypeVars, but it is hard to filter out TypeVarTuples).
277+
repl = self.variables[t.id]
278+
if isinstance(repl, TypeVarTupleType):
279+
return repl
275280
raise NotImplementedError
276281

277282
def visit_unpack_type(self, t: UnpackType) -> Type:

0 commit comments

Comments
 (0)