Skip to content

Commit b02ddf1

Browse files
ilevkivskyiIvan Levkivskyi
and
Ivan Levkivskyi
authored
Polymorphic inference: basic support for variadic types (#15879)
This is the fifth PR in the series started by #15287, and a last one for the foreseeable future. This completes polymorphic inference sufficiently for extensive experimentation, and enabling polymorphic fallback by default. Remaining items for which I am going to open follow-up issues: * Enable `--new-type-inference` by default (should be done before everything else in this list). * Use polymorphic inference during unification. * Use polymorphic inference as primary an only mechanism, rather than a fallback if basic inference fails in some way. * Move `apply_poly()` logic from `checkexpr.py` to `applytype.py` (this one depends on everything above). * Experiment with backtracking in the new solver. * Experiment with universal quantification for types other that `Callable` (btw we already have a hacky support for capturing a generic function in an instance with `ParamSpec`). Now some comments on the PR proper. First of all I decided to do some clean-up of `TypeVarTuple` support, but added only strictly necessary parts of the cleanup here. Everything else will be in follow up PR(s). The polymorphic inference/solver/application is practically trivial here, so here is my view on how I see large-scale structure of `TypeVarTuple` implementation: * There should be no special-casing in `applytype.py`, so I deleted everything from there (as I did for `ParamSpec`) and complemented `visit_callable_type()` in `expandtype.py`. Basically, `applytype.py` should have three simple steps: validate substitutions (upper bounds, values, argument kinds etc.); call `expand_type()`; update callable type variables (currently we just reduce the number, but in future we may also add variables there, see TODO that I added). * The only valid positions for a variadic item (a.k.a. `UnpackType`) are inside `Instance`s, `TupleType`s, and `CallableType`s. I like how there is an agreement that for callables there should never be a prefix, and instead prefix should be represented with regular positional arguments. I think that ideally we should enforce this with an `assert` in `CallableType` constructor (similar to how I did this for `ParamSpec`). * Completing `expand_type()` should be a priority (since it describes basic semantics of `TypeVarLikeType`s). I think I made good progress in this direction. IIUC the only valid substitution for `*Ts` are `TupleType.items`, `*tuple[X, ...]`, `Any`, and `<nothing>`, so it was not hard. * I propose to only allow `TupleType` (mostly for `semanal.py`, see item below), plain `TypeVarTupleType`, and a homogeneous `tuple` instances inside `UnpackType`. Supporting unions of those is not specified by the PEP and support will likely be quite tricky to implement. Also I propose to even eagerly expand type aliases to tuples (since there is no point in supporting recursive types like `A = Tuple[int, *A]`). * I propose to forcefully flatten nested `TupleType`s, there should be no things like `Tuple[X1, *Tuple[X2, *Ts, Y2], Y1]` etc after semantic analysis. (Similarly to how we always flatten `Parameters` for `ParamSpec`, and how we flatten nested unions in `UnionType` _constructor_). Currently we do the flattening/normalization of tuples in `expand_type()` etc. * I suspect `build_constraints_for_unpack()` may be broken, at least when it was used for tuples and callables it did something wrong in few cases I tested (and there are other symptoms I mentioned in a TODO). I therefore re-implemented logic for callables/tuples using a separate dedicated helper. I will investigate more later. As I mentioned above I only implemented strictly minimal amount of the above plan to make my tests pass, but still wanted to write this out to see if there are any objections (or maybe I don't understand something). If there are no objections to this plan, I will continue it in separate PR(s). Btw, I like how with this plan we will have clear logical parallels between `TypeVarTuple` implementation and (recently updated) `ParamSpec` implementation. --------- Co-authored-by: Ivan Levkivskyi <[email protected]>
1 parent fa84534 commit b02ddf1

10 files changed

+440
-209
lines changed

mypy/applytype.py

+10-54
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@
33
from typing import Callable, Sequence
44

55
import mypy.subtypes
6-
from mypy.expandtype import expand_type, expand_unpack_with_variables
7-
from mypy.nodes import ARG_STAR, Context
6+
from mypy.expandtype import expand_type
7+
from mypy.nodes import Context
88
from mypy.types import (
99
AnyType,
1010
CallableType,
11-
Instance,
1211
ParamSpecType,
1312
PartialType,
14-
TupleType,
1513
Type,
1614
TypeVarId,
1715
TypeVarLikeType,
@@ -21,7 +19,6 @@
2119
UnpackType,
2220
get_proper_type,
2321
)
24-
from mypy.typevartuples import find_unpack_in_list, replace_starargs
2522

2623

2724
def get_target_type(
@@ -107,6 +104,8 @@ def apply_generic_arguments(
107104
if target_type is not None:
108105
id_to_type[tvar.id] = target_type
109106

107+
# TODO: validate arg_kinds/arg_names for ParamSpec and TypeVarTuple replacements,
108+
# not just type variable bounds above.
110109
param_spec = callable.param_spec()
111110
if param_spec is not None:
112111
nt = id_to_type.get(param_spec.id)
@@ -122,55 +121,9 @@ def apply_generic_arguments(
122121
# Apply arguments to argument types.
123122
var_arg = callable.var_arg()
124123
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
125-
star_index = callable.arg_kinds.index(ARG_STAR)
126-
callable = callable.copy_modified(
127-
arg_types=(
128-
[expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
129-
+ [callable.arg_types[star_index]]
130-
+ [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
131-
)
132-
)
133-
134-
unpacked_type = get_proper_type(var_arg.typ.type)
135-
if isinstance(unpacked_type, TupleType):
136-
# Assuming for now that because we convert prefixes to positional arguments,
137-
# the first argument is always an unpack.
138-
expanded_tuple = expand_type(unpacked_type, id_to_type)
139-
if isinstance(expanded_tuple, TupleType):
140-
# TODO: handle the case where the tuple has an unpack. This will
141-
# hit an assert below.
142-
expanded_unpack = find_unpack_in_list(expanded_tuple.items)
143-
if expanded_unpack is not None:
144-
callable = callable.copy_modified(
145-
arg_types=(
146-
callable.arg_types[:star_index]
147-
+ [expanded_tuple]
148-
+ callable.arg_types[star_index + 1 :]
149-
)
150-
)
151-
else:
152-
callable = replace_starargs(callable, expanded_tuple.items)
153-
else:
154-
# TODO: handle the case for if we get a variable length tuple.
155-
assert False, f"mypy bug: unimplemented case, {expanded_tuple}"
156-
elif isinstance(unpacked_type, TypeVarTupleType):
157-
expanded_tvt = expand_unpack_with_variables(var_arg.typ, id_to_type)
158-
if isinstance(expanded_tvt, list):
159-
for t in expanded_tvt:
160-
assert not isinstance(t, UnpackType)
161-
callable = replace_starargs(callable, expanded_tvt)
162-
else:
163-
assert isinstance(expanded_tvt, Instance)
164-
assert expanded_tvt.type.fullname == "builtins.tuple"
165-
callable = callable.copy_modified(
166-
arg_types=(
167-
callable.arg_types[:star_index]
168-
+ [expanded_tvt.args[0]]
169-
+ callable.arg_types[star_index + 1 :]
170-
)
171-
)
172-
else:
173-
assert False, "mypy bug: unhandled case applying unpack"
124+
callable = expand_type(callable, id_to_type)
125+
assert isinstance(callable, CallableType)
126+
return callable.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])
174127
else:
175128
callable = callable.copy_modified(
176129
arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
@@ -183,6 +136,9 @@ def apply_generic_arguments(
183136
type_guard = None
184137

185138
# The callable may retain some type vars if only some were applied.
139+
# TODO: move apply_poly() logic from checkexpr.py here when new inference
140+
# becomes universally used (i.e. in all passes + in unification).
141+
# With this new logic we can actually *add* some new free variables.
186142
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
187143

188144
return callable.copy_modified(

mypy/checkexpr.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -2373,11 +2373,15 @@ def check_argument_types(
23732373
]
23742374
actual_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (len(actuals) - 1)
23752375

2376-
assert isinstance(orig_callee_arg_type, TupleType)
2377-
assert orig_callee_arg_type.items
2378-
callee_arg_types = orig_callee_arg_type.items
2376+
# TODO: can we really assert this? What if formal is just plain Unpack[Ts]?
2377+
assert isinstance(orig_callee_arg_type, UnpackType)
2378+
assert isinstance(orig_callee_arg_type.type, ProperType) and isinstance(
2379+
orig_callee_arg_type.type, TupleType
2380+
)
2381+
assert orig_callee_arg_type.type.items
2382+
callee_arg_types = orig_callee_arg_type.type.items
23792383
callee_arg_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (
2380-
len(orig_callee_arg_type.items) - 1
2384+
len(orig_callee_arg_type.type.items) - 1
23812385
)
23822386
expanded_tuple = True
23832387

@@ -5853,8 +5857,9 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
58535857
return super().visit_param_spec(t)
58545858

58555859
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
5856-
# TODO: Support polymorphic apply for TypeVarTuple.
5857-
raise PolyTranslationError()
5860+
if t in self.poly_tvars and t not in self.bound_tvars:
5861+
raise PolyTranslationError()
5862+
return super().visit_type_var_tuple(t)
58585863

58595864
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
58605865
if not t.args:
@@ -5888,7 +5893,6 @@ def visit_instance(self, t: Instance) -> Type:
58885893
return t.copy_modified(args=new_args)
58895894
# There is the same problem with callback protocols as with aliases
58905895
# (callback protocols are essentially more flexible aliases to callables).
5891-
# Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].
58925896
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
58935897
if t.type in self.seen_aliases:
58945898
raise PolyTranslationError()
@@ -5923,6 +5927,12 @@ def __init__(self) -> None:
59235927
def visit_type_var(self, t: TypeVarType) -> bool:
59245928
return True
59255929

5930+
def visit_param_spec(self, t: ParamSpecType) -> bool:
5931+
return True
5932+
5933+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
5934+
return True
5935+
59265936

59275937
def has_erased_component(t: Type | None) -> bool:
59285938
return t is not None and t.accept(HasErasedComponentsQuery())

0 commit comments

Comments
 (0)