Skip to content

Commit b327557

Browse files
authored
Subtyping and inference of user defined variadic types (#16076)
The second part of support for user defined variadic types comes as a single PR, it was hard to split into smaller parts. This part covers subtyping and inference (and relies on the first part: type analysis, normalization, and expansion, concluded by #15991). Note btw that the third (and last) part that covers actually using all the stuff in `checkexpr.py` will likely come as several smaller PRs. Some comments on this PR: * First good news: it looks like instances subtyping/inference can be handled in a really simple way, we just need to find correct type arguments mapping for each type variable, and perform procedures argument by argument (note this heavily relies on the normalization). Also callable subtyping inference for variadic items effectively defers to corresponding tuple types. This way all code paths will ultimately go through variadic tuple subtyping/inference (there is still a bunch of boilerplate to do the mapping, but it is quite simple). * Second some bad news: a lot of edge cases involving `*tuple[X, ...]` were missing everywhere (even couple cases in the code I touched before). I added all that were either simple or important. We can handle more if users will ask, since it is quite tricky. * Note that I handle variadic tuples essentially as infinite unions, the core of the logic for this (and for most of this PR FWIW) is in `variadic_tuple_subtype()`. * Previously `Foo[*tuple[int, ...]]` was considered a subtype of `Foo[int, int]`. I think this is wrong. I didn't find where this is required in the PEP (see one case below however), and mypy currently considers `tuple[int, ...]` not a subtype of `tuple[int, int]` (vice versa are subtypes), and similarly `(*args: int)` vs `(x: int, y: int)` for callables. Because of the logic I described in the first comment, the same logic now uniformly applies to instances as well. * Note however the PEP requires special casing of `Foo[*tuple[Any, ...]]` (equivalent to bare `Foo`), and I agree we should do this. I added a minimal special case for this. Note we also do this for callables as well (`*args: Any` is very different from `*args: object`). And I think we should special case `tuple[Any, ...] <: tuple[int, int]` as well. In the future we can even extend the special casing to `tuple[int, *tuple[Any, ...], int]` in the spirit of #15913 * In this PR I specifically only handle the PEP required item from above for instances. For plain tuples I left a TODO, @hauntsaninja may implement it since it is needed for other unrelated PR. * I make the default upper bound for `TypeVarTupleType` to be `tuple[object, ...]`. I think it can never be `object` (and this simplifies some subtyping corner cases). * TBH I didn't look into callables subtyping/inference very deeply (unlike instances and tuples), if needed we can improve their handling later. * Note I remove some failing unit tests because they test non-nomralized forms that should never appear now. We should probably add some more unit tests, but TBH I am quite tired now.
1 parent 66fbf5b commit b327557

19 files changed

+943
-515
lines changed

mypy/constraints.py

+109-122
Large diffs are not rendered by default.

mypy/erasetype.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,16 @@ def visit_deleted_type(self, t: DeletedType) -> ProperType:
7777
return t
7878

7979
def visit_instance(self, t: Instance) -> ProperType:
80-
return Instance(t.type, [AnyType(TypeOfAny.special_form)] * len(t.args), t.line)
80+
args: list[Type] = []
81+
for tv in t.type.defn.type_vars:
82+
# Valid erasure for *Ts is *tuple[Any, ...], not just Any.
83+
if isinstance(tv, TypeVarTupleType):
84+
args.append(
85+
tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])
86+
)
87+
else:
88+
args.append(AnyType(TypeOfAny.special_form))
89+
return Instance(t.type, args, t.line)
8190

8291
def visit_type_var(self, t: TypeVarType) -> ProperType:
8392
return AnyType(TypeOfAny.special_form)

mypy/expandtype.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
255255
variables=[*t.prefix.variables, *repl.variables],
256256
)
257257
else:
258-
# TODO: replace this with "assert False"
258+
# We could encode Any as trivial parameters etc., but it would be too verbose.
259+
# TODO: assert this is a trivial type, like Any, Never, or object.
259260
return repl
260261

261262
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:

mypy/fixup.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,17 @@ def visit_type_info(self, info: TypeInfo) -> None:
8181
info.update_tuple_type(info.tuple_type)
8282
if info.special_alias:
8383
info.special_alias.alias_tvars = list(info.defn.type_vars)
84+
for i, t in enumerate(info.defn.type_vars):
85+
if isinstance(t, TypeVarTupleType):
86+
info.special_alias.tvar_tuple_index = i
8487
if info.typeddict_type:
8588
info.typeddict_type.accept(self.type_fixer)
8689
info.update_typeddict_type(info.typeddict_type)
8790
if info.special_alias:
8891
info.special_alias.alias_tvars = list(info.defn.type_vars)
92+
for i, t in enumerate(info.defn.type_vars):
93+
if isinstance(t, TypeVarTupleType):
94+
info.special_alias.tvar_tuple_index = i
8995
if info.declared_metaclass:
9096
info.declared_metaclass.accept(self.type_fixer)
9197
if info.metaclass_type:
@@ -166,11 +172,7 @@ def visit_decorator(self, d: Decorator) -> None:
166172

167173
def visit_class_def(self, c: ClassDef) -> None:
168174
for v in c.type_vars:
169-
if isinstance(v, TypeVarType):
170-
for value in v.values:
171-
value.accept(self.type_fixer)
172-
v.upper_bound.accept(self.type_fixer)
173-
v.default.accept(self.type_fixer)
175+
v.accept(self.type_fixer)
174176

175177
def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
176178
for value in tv.values:
@@ -184,6 +186,7 @@ def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
184186

185187
def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
186188
tv.upper_bound.accept(self.type_fixer)
189+
tv.tuple_fallback.accept(self.type_fixer)
187190
tv.default.accept(self.type_fixer)
188191

189192
def visit_var(self, v: Var) -> None:
@@ -314,6 +317,7 @@ def visit_param_spec(self, p: ParamSpecType) -> None:
314317
p.default.accept(self)
315318

316319
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
320+
t.tuple_fallback.accept(self)
317321
t.upper_bound.accept(self)
318322
t.default.accept(self)
319323

@@ -336,9 +340,6 @@ def visit_union_type(self, ut: UnionType) -> None:
336340
for it in ut.items:
337341
it.accept(self)
338342

339-
def visit_void(self, o: Any) -> None:
340-
pass # Nothing to descend into.
341-
342343
def visit_type_type(self, t: TypeType) -> None:
343344
t.item.accept(self)
344345

mypy/join.py

+148-6
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@
4343
UninhabitedType,
4444
UnionType,
4545
UnpackType,
46+
find_unpack_in_list,
4647
get_proper_type,
4748
get_proper_types,
49+
split_with_prefix_and_suffix,
4850
)
4951

5052

@@ -67,7 +69,25 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
6769
args: list[Type] = []
6870
# N.B: We use zip instead of indexing because the lengths might have
6971
# mismatches during daemon reprocessing.
70-
for ta, sa, type_var in zip(t.args, s.args, t.type.defn.type_vars):
72+
if t.type.has_type_var_tuple_type:
73+
# We handle joins of variadic instances by simply creating correct mapping
74+
# for type arguments and compute the individual joins same as for regular
75+
# instances. All the heavy lifting is done in the join of tuple types.
76+
assert s.type.type_var_tuple_prefix is not None
77+
assert s.type.type_var_tuple_suffix is not None
78+
prefix = s.type.type_var_tuple_prefix
79+
suffix = s.type.type_var_tuple_suffix
80+
tvt = s.type.defn.type_vars[prefix]
81+
assert isinstance(tvt, TypeVarTupleType)
82+
fallback = tvt.tuple_fallback
83+
s_prefix, s_middle, s_suffix = split_with_prefix_and_suffix(s.args, prefix, suffix)
84+
t_prefix, t_middle, t_suffix = split_with_prefix_and_suffix(t.args, prefix, suffix)
85+
s_args = s_prefix + (TupleType(list(s_middle), fallback),) + s_suffix
86+
t_args = t_prefix + (TupleType(list(t_middle), fallback),) + t_suffix
87+
else:
88+
t_args = t.args
89+
s_args = s.args
90+
for ta, sa, type_var in zip(t_args, s_args, t.type.defn.type_vars):
7191
ta_proper = get_proper_type(ta)
7292
sa_proper = get_proper_type(sa)
7393
new_type: Type | None = None
@@ -93,6 +113,18 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
93113
# If the types are different but equivalent, then an Any is involved
94114
# so using a join in the contravariant case is also OK.
95115
new_type = join_types(ta, sa, self)
116+
elif isinstance(type_var, TypeVarTupleType):
117+
new_type = get_proper_type(join_types(ta, sa, self))
118+
# Put the joined arguments back into instance in the normal form:
119+
# a) Tuple[X, Y, Z] -> [X, Y, Z]
120+
# b) tuple[X, ...] -> [*tuple[X, ...]]
121+
if isinstance(new_type, Instance):
122+
assert new_type.type.fullname == "builtins.tuple"
123+
new_type = UnpackType(new_type)
124+
else:
125+
assert isinstance(new_type, TupleType)
126+
args.extend(new_type.items)
127+
continue
96128
else:
97129
# ParamSpec type variables behave the same, independent of variance
98130
if not is_equivalent(ta, sa):
@@ -440,6 +472,113 @@ def visit_overloaded(self, t: Overloaded) -> ProperType:
440472
return join_types(t, call)
441473
return join_types(t.fallback, s)
442474

475+
def join_tuples(self, s: TupleType, t: TupleType) -> list[Type] | None:
476+
"""Join two tuple types while handling variadic entries.
477+
478+
This is surprisingly tricky, and we don't handle some tricky corner cases.
479+
Most of the trickiness comes from the variadic tuple items like *tuple[X, ...]
480+
since they can have arbitrary partial overlaps (while *Ts can't be split).
481+
"""
482+
s_unpack_index = find_unpack_in_list(s.items)
483+
t_unpack_index = find_unpack_in_list(t.items)
484+
if s_unpack_index is None and t_unpack_index is None:
485+
if s.length() == t.length():
486+
items: list[Type] = []
487+
for i in range(t.length()):
488+
items.append(join_types(t.items[i], s.items[i]))
489+
return items
490+
return None
491+
if s_unpack_index is not None and t_unpack_index is not None:
492+
# The most complex case: both tuples have an upack item.
493+
s_unpack = s.items[s_unpack_index]
494+
assert isinstance(s_unpack, UnpackType)
495+
s_unpacked = get_proper_type(s_unpack.type)
496+
t_unpack = t.items[t_unpack_index]
497+
assert isinstance(t_unpack, UnpackType)
498+
t_unpacked = get_proper_type(t_unpack.type)
499+
if s.length() == t.length() and s_unpack_index == t_unpack_index:
500+
# We can handle a case where arity is perfectly aligned, e.g.
501+
# join(Tuple[X1, *tuple[Y1, ...], Z1], Tuple[X2, *tuple[Y2, ...], Z2]).
502+
# We can essentially perform the join elementwise.
503+
prefix_len = t_unpack_index
504+
suffix_len = t.length() - t_unpack_index - 1
505+
items = []
506+
for si, ti in zip(s.items[:prefix_len], t.items[:prefix_len]):
507+
items.append(join_types(si, ti))
508+
joined = join_types(s_unpacked, t_unpacked)
509+
if isinstance(joined, TypeVarTupleType):
510+
items.append(UnpackType(joined))
511+
elif isinstance(joined, Instance) and joined.type.fullname == "builtins.tuple":
512+
items.append(UnpackType(joined))
513+
else:
514+
if isinstance(t_unpacked, Instance):
515+
assert t_unpacked.type.fullname == "builtins.tuple"
516+
tuple_instance = t_unpacked
517+
else:
518+
assert isinstance(t_unpacked, TypeVarTupleType)
519+
tuple_instance = t_unpacked.tuple_fallback
520+
items.append(
521+
UnpackType(
522+
tuple_instance.copy_modified(
523+
args=[object_from_instance(tuple_instance)]
524+
)
525+
)
526+
)
527+
if suffix_len:
528+
for si, ti in zip(s.items[-suffix_len:], t.items[-suffix_len:]):
529+
items.append(join_types(si, ti))
530+
return items
531+
if s.length() == 1 or t.length() == 1:
532+
# Another case we can handle is when one of tuple is purely variadic
533+
# (i.e. a non-normalized form of tuple[X, ...]), in this case the join
534+
# will be again purely variadic.
535+
if not (isinstance(s_unpacked, Instance) and isinstance(t_unpacked, Instance)):
536+
return None
537+
assert s_unpacked.type.fullname == "builtins.tuple"
538+
assert t_unpacked.type.fullname == "builtins.tuple"
539+
mid_joined = join_types(s_unpacked.args[0], t_unpacked.args[0])
540+
t_other = [a for i, a in enumerate(t.items) if i != t_unpack_index]
541+
s_other = [a for i, a in enumerate(s.items) if i != s_unpack_index]
542+
other_joined = join_type_list(s_other + t_other)
543+
mid_joined = join_types(mid_joined, other_joined)
544+
return [UnpackType(s_unpacked.copy_modified(args=[mid_joined]))]
545+
# TODO: are there other case we can handle (e.g. both prefix/suffix are shorter)?
546+
return None
547+
if s_unpack_index is not None:
548+
variadic = s
549+
unpack_index = s_unpack_index
550+
fixed = t
551+
else:
552+
assert t_unpack_index is not None
553+
variadic = t
554+
unpack_index = t_unpack_index
555+
fixed = s
556+
# Case where one tuple has variadic item and the other one doesn't. The join will
557+
# be variadic, since fixed tuple is a subtype of variadic, but not vice versa.
558+
unpack = variadic.items[unpack_index]
559+
assert isinstance(unpack, UnpackType)
560+
unpacked = get_proper_type(unpack.type)
561+
if not isinstance(unpacked, Instance):
562+
return None
563+
if fixed.length() < variadic.length() - 1:
564+
# There are no non-trivial types that are supertype of both.
565+
return None
566+
prefix_len = unpack_index
567+
suffix_len = variadic.length() - prefix_len - 1
568+
prefix, middle, suffix = split_with_prefix_and_suffix(
569+
tuple(fixed.items), prefix_len, suffix_len
570+
)
571+
items = []
572+
for fi, vi in zip(prefix, variadic.items[:prefix_len]):
573+
items.append(join_types(fi, vi))
574+
mid_joined = join_type_list(list(middle))
575+
mid_joined = join_types(mid_joined, unpacked.args[0])
576+
items.append(UnpackType(unpacked.copy_modified(args=[mid_joined])))
577+
if suffix_len:
578+
for fi, vi in zip(suffix, variadic.items[-suffix_len:]):
579+
items.append(join_types(fi, vi))
580+
return items
581+
443582
def visit_tuple_type(self, t: TupleType) -> ProperType:
444583
# When given two fixed-length tuples:
445584
# * If they have the same length, join their subtypes item-wise:
@@ -452,19 +591,22 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
452591
# Tuple[int, bool] + Tuple[bool, ...] becomes Tuple[int, ...]
453592
# * Joining with any Sequence also returns a Sequence:
454593
# Tuple[int, bool] + List[bool] becomes Sequence[int]
455-
if isinstance(self.s, TupleType) and self.s.length() == t.length():
594+
if isinstance(self.s, TupleType):
456595
if self.instance_joiner is None:
457596
self.instance_joiner = InstanceJoiner()
458597
fallback = self.instance_joiner.join_instances(
459598
mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t)
460599
)
461600
assert isinstance(fallback, Instance)
462-
if self.s.length() == t.length():
463-
items: list[Type] = []
464-
for i in range(t.length()):
465-
items.append(join_types(t.items[i], self.s.items[i]))
601+
items = self.join_tuples(self.s, t)
602+
if items is not None:
466603
return TupleType(items, fallback)
467604
else:
605+
# TODO: should this be a default fallback behaviour like for meet?
606+
if is_proper_subtype(self.s, t):
607+
return t
608+
if is_proper_subtype(t, self.s):
609+
return self.s
468610
return fallback
469611
else:
470612
return join_types(self.s, mypy.typeops.tuple_fallback(t))

0 commit comments

Comments
 (0)