From 3d5d501cc59d45b194ccc0496ed0ea2394d901c1 Mon Sep 17 00:00:00 2001 From: A5rocks Date: Wed, 1 Jan 2025 17:46:03 +0900 Subject: [PATCH] Allow passing a callable with type vars in self types --- mypy/solve.py | 13 ++++++------- mypy/typeops.py | 27 +++++++++++++++++++++++---- test-data/unit/check-selftype.test | 15 +++++++++++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/mypy/solve.py b/mypy/solve.py index cac1a23c5a33..45f72914e113 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -73,13 +73,6 @@ def solve_constraints( # Constraints inferred from unions require special handling in polymorphic inference. constraints = skip_reverse_union_constraints(constraints) - # Collect a list of constraints for each type variable. - cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars} - for con in constraints: - if con.type_var in vars + extra_vars: - cmap[con.type_var].append(con) - - if allow_polymorphic: if constraints: solutions, free_vars = solve_with_dependent( vars + extra_vars, constraints, vars, originals @@ -88,6 +81,12 @@ def solve_constraints( solutions = {} free_vars = [] else: + # Collect a list of constraints for each type variable. + cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars} + for con in constraints: + if con.type_var in vars + extra_vars: + cmap[con.type_var].append(con) + solutions = {} free_vars = [] for tv, cs in cmap.items(): diff --git a/mypy/typeops.py b/mypy/typeops.py index 4a269f725cef..4dbc67597cc6 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -352,7 +352,9 @@ class B(A): pass if func.variables and supported_self_type( self_param_type, allow_callable=allow_callable, allow_instances=not ignore_instances ): + from mypy.constraints import SUPERTYPE_OF, infer_constraints from mypy.infer import infer_type_arguments + from mypy.solve import solve_constraints if original_type is None: # TODO: type check method override (see #7861). @@ -364,9 +366,9 @@ class B(A): pass self_vars = [tv for tv in func.variables if tv.id in self_ids] # Solve for these type arguments using the actual class or instance type. - typeargs = infer_type_arguments( - self_vars, self_param_type, original_type, is_supertype=True - ) + constraints = infer_constraints(self_param_type, original_type, SUPERTYPE_OF) + typeargs, free_vars = solve_constraints(self_vars, constraints, allow_polymorphic=True) + if ( is_classmethod and any(isinstance(get_proper_type(t), UninhabitedType) for t in typeargs) @@ -376,12 +378,29 @@ class B(A): pass typeargs = infer_type_arguments( self_vars, self_param_type, TypeType(original_type), is_supertype=True ) + free_vars = [] # Update the method signature with the solutions found. # Technically, some constraints might be unsolvable, make them Never. to_apply = [t if t is not None else UninhabitedType() for t in typeargs] + + # Try to push in any type vars where the self type was the only location. e.g.: + # [T] () -> (T) -> T should return () -> [T] (T) -> T + outer_tvs = set() + for arg in func.arg_types[1:]: + outer_tvs |= set(get_all_type_vars(arg)) & set(free_vars) + + inner_tvs = [v for v in free_vars if v not in outer_tvs] + result_type = get_proper_type(func.ret_type) + if isinstance(result_type, CallableType): + func = func.copy_modified( + ret_type=result_type.copy_modified( + variables=list(result_type.variables) + inner_tvs + ) + ) + func = expand_type(func, {tv.id: arg for tv, arg in zip(self_vars, to_apply)}) - variables = [v for v in func.variables if v not in self_vars] + variables = [v for v in func.variables if v not in self_vars or v in outer_tvs] else: variables = func.variables diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test index fa853ac48e5a..54643a3ca6f5 100644 --- a/test-data/unit/check-selftype.test +++ b/test-data/unit/check-selftype.test @@ -2214,3 +2214,18 @@ class Test2: reveal_type(Test2().method) # N: Revealed type is "def (foo: builtins.int, *, bar: builtins.str) -> builtins.bytes" [builtins fixtures/tuple.pyi] + +[case testCallableWithTypeVarInSelfType] +from typing import Generic, TypeVar, Callable + +T = TypeVar("T") +V = TypeVar("V") + +class X(Generic[T]): + def f(self: X[Callable[[V], None]]) -> Callable[[V], V]: + def inner_f(v: V) -> V: + return v + + return inner_f + +reveal_type(X[Callable[[T], None]]().f()) # N: Revealed type is "def [V] (V`4) -> V`4"