Skip to content

Commit fc811ae

Browse files
authored
Fix polymorphic application for callback protocols (#16514)
Fixes #16512 The problems were caused if same callback protocol appeared multiple times in a signature. Previous logic confused this with a recursive callback protocol.
1 parent 3e6b552 commit fc811ae

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

mypy/checkexpr.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -6209,11 +6209,16 @@ class PolyTranslator(TypeTranslator):
62096209
See docstring for apply_poly() for details.
62106210
"""
62116211

6212-
def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
6212+
def __init__(
6213+
self,
6214+
poly_tvars: Iterable[TypeVarLikeType],
6215+
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
6216+
seen_aliases: frozenset[TypeInfo] = frozenset(),
6217+
) -> None:
62136218
self.poly_tvars = set(poly_tvars)
62146219
# This is a simplified version of TypeVarScope used during semantic analysis.
6215-
self.bound_tvars: set[TypeVarLikeType] = set()
6216-
self.seen_aliases: set[TypeInfo] = set()
6220+
self.bound_tvars = bound_tvars
6221+
self.seen_aliases = seen_aliases
62176222

62186223
def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
62196224
found_vars = []
@@ -6289,10 +6294,11 @@ def visit_instance(self, t: Instance) -> Type:
62896294
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
62906295
if t.type in self.seen_aliases:
62916296
raise PolyTranslationError()
6292-
self.seen_aliases.add(t.type)
62936297
call = find_member("__call__", t, t, is_operator=True)
62946298
assert call is not None
6295-
return call.accept(self)
6299+
return call.accept(
6300+
PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
6301+
)
62966302
return super().visit_instance(t)
62976303

62986304

test-data/unit/check-inference.test

+25
Original file line numberDiff line numberDiff line change
@@ -3788,3 +3788,28 @@ def func2(arg: T) -> List[Union[T, str]]:
37883788
reveal_type(func2) # N: Revealed type is "def [S] (S`4) -> Union[S`4, builtins.str]"
37893789
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
37903790
[builtins fixtures/list.pyi]
3791+
3792+
[case testInferenceAgainstGenericCallbackProtoMultiple]
3793+
from typing import Callable, Protocol, TypeVar
3794+
from typing_extensions import Concatenate, ParamSpec
3795+
3796+
V_co = TypeVar("V_co", covariant=True)
3797+
class Metric(Protocol[V_co]):
3798+
def __call__(self) -> V_co: ...
3799+
3800+
T = TypeVar("T")
3801+
P = ParamSpec("P")
3802+
def simple_metric(func: Callable[Concatenate[int, P], T]) -> Callable[P, T]: ...
3803+
3804+
@simple_metric
3805+
def Negate(count: int, /, metric: Metric[float]) -> float: ...
3806+
@simple_metric
3807+
def Combine(count: int, m1: Metric[T], m2: Metric[T], /, *more: Metric[T]) -> T: ...
3808+
3809+
reveal_type(Negate) # N: Revealed type is "def (metric: __main__.Metric[builtins.float]) -> builtins.float"
3810+
reveal_type(Combine) # N: Revealed type is "def [T] (def () -> T`4, def () -> T`4, *more: def () -> T`4) -> T`4"
3811+
3812+
def m1() -> float: ...
3813+
def m2() -> float: ...
3814+
reveal_type(Combine(m1, m2)) # N: Revealed type is "builtins.float"
3815+
[builtins fixtures/list.pyi]

0 commit comments

Comments
 (0)