Skip to content

Commit 854a9f8

Browse files
ilevkivskyiIvan LevkivskyiAlexWaygood
authored
Allow None vs TypeVar overlap for overloads (#15846)
Fixes #8881 This is technically unsafe, and I remember we explicitly discussed this a while ago, but related use cases turn out to be more common than I expected (judging by how popular the issue is). Also the fix is really simple. --------- Co-authored-by: Ivan Levkivskyi <[email protected]> Co-authored-by: Alex Waygood <[email protected]>
1 parent a1fcad5 commit 854a9f8

File tree

4 files changed

+135
-29
lines changed

4 files changed

+135
-29
lines changed

mypy/checker.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -7216,22 +7216,32 @@ def is_unsafe_overlapping_overload_signatures(
72167216
#
72177217
# This discrepancy is unfortunately difficult to get rid of, so we repeat the
72187218
# checks twice in both directions for now.
7219+
#
7220+
# Note that we ignore possible overlap between type variables and None. This
7221+
# is technically unsafe, but unsafety is tiny and this prevents some common
7222+
# use cases like:
7223+
# @overload
7224+
# def foo(x: None) -> None: ..
7225+
# @overload
7226+
# def foo(x: T) -> Foo[T]: ...
72197227
return is_callable_compatible(
72207228
signature,
72217229
other,
7222-
is_compat=is_overlapping_types_no_promote_no_uninhabited,
7230+
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
72237231
is_compat_return=lambda l, r: not is_subtype_no_promote(l, r),
72247232
ignore_return=False,
72257233
check_args_covariantly=True,
72267234
allow_partial_overlap=True,
7235+
no_unify_none=True,
72277236
) or is_callable_compatible(
72287237
other,
72297238
signature,
7230-
is_compat=is_overlapping_types_no_promote_no_uninhabited,
7239+
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
72317240
is_compat_return=lambda l, r: not is_subtype_no_promote(r, l),
72327241
ignore_return=False,
72337242
check_args_covariantly=False,
72347243
allow_partial_overlap=True,
7244+
no_unify_none=True,
72357245
)
72367246

72377247

@@ -7717,12 +7727,18 @@ def is_subtype_no_promote(left: Type, right: Type) -> bool:
77177727
return is_subtype(left, right, ignore_promotions=True)
77187728

77197729

7720-
def is_overlapping_types_no_promote_no_uninhabited(left: Type, right: Type) -> bool:
7730+
def is_overlapping_types_no_promote_no_uninhabited_no_none(left: Type, right: Type) -> bool:
77217731
# For the purpose of unsafe overload checks we consider list[<nothing>] and list[int]
77227732
# non-overlapping. This is consistent with how we treat list[int] and list[str] as
77237733
# non-overlapping, despite [] belongs to both. Also this will prevent false positives
77247734
# for failed type inference during unification.
7725-
return is_overlapping_types(left, right, ignore_promotions=True, ignore_uninhabited=True)
7735+
return is_overlapping_types(
7736+
left,
7737+
right,
7738+
ignore_promotions=True,
7739+
ignore_uninhabited=True,
7740+
prohibit_none_typevar_overlap=True,
7741+
)
77267742

77277743

77287744
def is_private(node_name: str) -> bool:

mypy/checkexpr.py

+69-17
Original file line numberDiff line numberDiff line change
@@ -2409,6 +2409,11 @@ def check_overload_call(
24092409
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
24102410
erased_targets: list[CallableType] | None = None
24112411
unioned_result: tuple[Type, Type] | None = None
2412+
2413+
# Determine whether we need to encourage union math. This should be generally safe,
2414+
# as union math infers better results in the vast majority of cases, but it is very
2415+
# computationally intensive.
2416+
none_type_var_overlap = self.possible_none_type_var_overlap(arg_types, plausible_targets)
24122417
union_interrupted = False # did we try all union combinations?
24132418
if any(self.real_union(arg) for arg in arg_types):
24142419
try:
@@ -2421,6 +2426,7 @@ def check_overload_call(
24212426
arg_names,
24222427
callable_name,
24232428
object_type,
2429+
none_type_var_overlap,
24242430
context,
24252431
)
24262432
except TooManyUnions:
@@ -2453,8 +2459,10 @@ def check_overload_call(
24532459
# If any of checks succeed, stop early.
24542460
if inferred_result is not None and unioned_result is not None:
24552461
# Both unioned and direct checks succeeded, choose the more precise type.
2456-
if is_subtype(inferred_result[0], unioned_result[0]) and not isinstance(
2457-
get_proper_type(inferred_result[0]), AnyType
2462+
if (
2463+
is_subtype(inferred_result[0], unioned_result[0])
2464+
and not isinstance(get_proper_type(inferred_result[0]), AnyType)
2465+
and not none_type_var_overlap
24582466
):
24592467
return inferred_result
24602468
return unioned_result
@@ -2504,7 +2512,8 @@ def check_overload_call(
25042512
callable_name=callable_name,
25052513
object_type=object_type,
25062514
)
2507-
if union_interrupted:
2515+
# Do not show the extra error if the union math was forced.
2516+
if union_interrupted and not none_type_var_overlap:
25082517
self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context)
25092518
return result
25102519

@@ -2659,6 +2668,44 @@ def overload_erased_call_targets(
26592668
matches.append(typ)
26602669
return matches
26612670

2671+
def possible_none_type_var_overlap(
2672+
self, arg_types: list[Type], plausible_targets: list[CallableType]
2673+
) -> bool:
2674+
"""Heuristic to determine whether we need to try forcing union math.
2675+
2676+
This is needed to avoid greedy type variable match in situations like this:
2677+
@overload
2678+
def foo(x: None) -> None: ...
2679+
@overload
2680+
def foo(x: T) -> list[T]: ...
2681+
2682+
x: int | None
2683+
foo(x)
2684+
we want this call to infer list[int] | None, not list[int | None].
2685+
"""
2686+
if not plausible_targets or not arg_types:
2687+
return False
2688+
has_optional_arg = False
2689+
for arg_type in get_proper_types(arg_types):
2690+
if not isinstance(arg_type, UnionType):
2691+
continue
2692+
for item in get_proper_types(arg_type.items):
2693+
if isinstance(item, NoneType):
2694+
has_optional_arg = True
2695+
break
2696+
if not has_optional_arg:
2697+
return False
2698+
2699+
min_prefix = min(len(c.arg_types) for c in plausible_targets)
2700+
for i in range(min_prefix):
2701+
if any(
2702+
isinstance(get_proper_type(c.arg_types[i]), NoneType) for c in plausible_targets
2703+
) and any(
2704+
isinstance(get_proper_type(c.arg_types[i]), TypeVarType) for c in plausible_targets
2705+
):
2706+
return True
2707+
return False
2708+
26622709
def union_overload_result(
26632710
self,
26642711
plausible_targets: list[CallableType],
@@ -2668,6 +2715,7 @@ def union_overload_result(
26682715
arg_names: Sequence[str | None] | None,
26692716
callable_name: str | None,
26702717
object_type: Type | None,
2718+
none_type_var_overlap: bool,
26712719
context: Context,
26722720
level: int = 0,
26732721
) -> list[tuple[Type, Type]] | None:
@@ -2707,20 +2755,23 @@ def union_overload_result(
27072755

27082756
# Step 3: Try a direct match before splitting to avoid unnecessary union splits
27092757
# and save performance.
2710-
with self.type_overrides_set(args, arg_types):
2711-
direct = self.infer_overload_return_type(
2712-
plausible_targets,
2713-
args,
2714-
arg_types,
2715-
arg_kinds,
2716-
arg_names,
2717-
callable_name,
2718-
object_type,
2719-
context,
2720-
)
2721-
if direct is not None and not isinstance(get_proper_type(direct[0]), (UnionType, AnyType)):
2722-
# We only return non-unions soon, to avoid greedy match.
2723-
return [direct]
2758+
if not none_type_var_overlap:
2759+
with self.type_overrides_set(args, arg_types):
2760+
direct = self.infer_overload_return_type(
2761+
plausible_targets,
2762+
args,
2763+
arg_types,
2764+
arg_kinds,
2765+
arg_names,
2766+
callable_name,
2767+
object_type,
2768+
context,
2769+
)
2770+
if direct is not None and not isinstance(
2771+
get_proper_type(direct[0]), (UnionType, AnyType)
2772+
):
2773+
# We only return non-unions soon, to avoid greedy match.
2774+
return [direct]
27242775

27252776
# Step 4: Split the first remaining union type in arguments into items and
27262777
# try to match each item individually (recursive).
@@ -2738,6 +2789,7 @@ def union_overload_result(
27382789
arg_names,
27392790
callable_name,
27402791
object_type,
2792+
none_type_var_overlap,
27412793
context,
27422794
level + 1,
27432795
)

mypy/subtypes.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,7 @@ def is_callable_compatible(
12991299
check_args_covariantly: bool = False,
13001300
allow_partial_overlap: bool = False,
13011301
strict_concatenate: bool = False,
1302+
no_unify_none: bool = False,
13021303
) -> bool:
13031304
"""Is the left compatible with the right, using the provided compatibility check?
13041305
@@ -1415,7 +1416,9 @@ def g(x: int) -> int: ...
14151416
# (below) treats type variables on the two sides as independent.
14161417
if left.variables:
14171418
# Apply generic type variables away in left via type inference.
1418-
unified = unify_generic_callable(left, right, ignore_return=ignore_return)
1419+
unified = unify_generic_callable(
1420+
left, right, ignore_return=ignore_return, no_unify_none=no_unify_none
1421+
)
14191422
if unified is None:
14201423
return False
14211424
left = unified
@@ -1427,7 +1430,9 @@ def g(x: int) -> int: ...
14271430
# So, we repeat the above checks in the opposite direction. This also
14281431
# lets us preserve the 'symmetry' property of allow_partial_overlap.
14291432
if allow_partial_overlap and right.variables:
1430-
unified = unify_generic_callable(right, left, ignore_return=ignore_return)
1433+
unified = unify_generic_callable(
1434+
right, left, ignore_return=ignore_return, no_unify_none=no_unify_none
1435+
)
14311436
if unified is not None:
14321437
right = unified
14331438

@@ -1687,6 +1692,8 @@ def unify_generic_callable(
16871692
target: NormalizedCallableType,
16881693
ignore_return: bool,
16891694
return_constraint_direction: int | None = None,
1695+
*,
1696+
no_unify_none: bool = False,
16901697
) -> NormalizedCallableType | None:
16911698
"""Try to unify a generic callable type with another callable type.
16921699
@@ -1708,6 +1715,10 @@ def unify_generic_callable(
17081715
type.ret_type, target.ret_type, return_constraint_direction
17091716
)
17101717
constraints.extend(c)
1718+
if no_unify_none:
1719+
constraints = [
1720+
c for c in constraints if not isinstance(get_proper_type(c.target), NoneType)
1721+
]
17111722
inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints)
17121723
if None in inferred_vars:
17131724
return None

test-data/unit/check-overloading.test

+33-6
Original file line numberDiff line numberDiff line change
@@ -2185,36 +2185,63 @@ def bar2(*x: int) -> int: ...
21852185
[builtins fixtures/tuple.pyi]
21862186

21872187
[case testOverloadDetectsPossibleMatchesWithGenerics]
2188-
from typing import overload, TypeVar, Generic
2188+
# flags: --strict-optional
2189+
from typing import overload, TypeVar, Generic, Optional, List
21892190

21902191
T = TypeVar('T')
2192+
# The examples below are unsafe, but it is a quite common pattern
2193+
# so we ignore the possibility of type variables taking value `None`
2194+
# for the purpose of overload overlap checks.
21912195

21922196
@overload
2193-
def foo(x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
2197+
def foo(x: None, y: None) -> str: ...
21942198
@overload
21952199
def foo(x: T, y: T) -> int: ...
21962200
def foo(x): ...
21972201

2202+
oi: Optional[int]
2203+
reveal_type(foo(None, None)) # N: Revealed type is "builtins.str"
2204+
reveal_type(foo(None, 42)) # N: Revealed type is "builtins.int"
2205+
reveal_type(foo(42, 42)) # N: Revealed type is "builtins.int"
2206+
reveal_type(foo(oi, None)) # N: Revealed type is "Union[builtins.int, builtins.str]"
2207+
reveal_type(foo(oi, 42)) # N: Revealed type is "builtins.int"
2208+
reveal_type(foo(oi, oi)) # N: Revealed type is "Union[builtins.int, builtins.str]"
2209+
2210+
@overload
2211+
def foo_list(x: None) -> None: ...
2212+
@overload
2213+
def foo_list(x: T) -> List[T]: ...
2214+
def foo_list(x): ...
2215+
2216+
reveal_type(foo_list(oi)) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
2217+
21982218
# What if 'T' is 'object'?
21992219
@overload
2200-
def bar(x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
2220+
def bar(x: None, y: int) -> str: ...
22012221
@overload
22022222
def bar(x: T, y: T) -> int: ...
22032223
def bar(x, y): ...
22042224

22052225
class Wrapper(Generic[T]):
22062226
@overload
2207-
def foo(self, x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
2227+
def foo(self, x: None, y: None) -> str: ...
22082228
@overload
22092229
def foo(self, x: T, y: None) -> int: ...
22102230
def foo(self, x): ...
22112231

22122232
@overload
2213-
def bar(self, x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
2233+
def bar(self, x: None, y: int) -> str: ...
22142234
@overload
22152235
def bar(self, x: T, y: T) -> int: ...
22162236
def bar(self, x, y): ...
22172237

2238+
@overload
2239+
def baz(x: str, y: str) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
2240+
@overload
2241+
def baz(x: T, y: T) -> int: ...
2242+
def baz(x): ...
2243+
[builtins fixtures/tuple.pyi]
2244+
22182245
[case testOverloadFlagsPossibleMatches]
22192246
from wrapper import *
22202247
[file wrapper.pyi]
@@ -3996,7 +4023,7 @@ T = TypeVar('T')
39964023

39974024
class FakeAttribute(Generic[T]):
39984025
@overload
3999-
def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
4026+
def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ...
40004027
@overload
40014028
def dummy(self, instance: T, owner: Type[T]) -> int: ...
40024029
def dummy(self, instance: Optional[T], owner: Type[T]) -> Union['FakeAttribute[T]', int]: ...

0 commit comments

Comments
 (0)