Skip to content

Commit 43ea203

Browse files
authored
Infer correct types with overloads of Type[Guard | Is] (#17678)
Closes #17579 Consider this as a prototype, because I understand that there might be a lot of extra work to get this right. However, this does solve this problem in the original issue.
1 parent 42a97bb commit 43ea203

File tree

4 files changed

+268
-14
lines changed

4 files changed

+268
-14
lines changed

Diff for: mypy/checker.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -6036,15 +6036,31 @@ def find_isinstance_check_helper(
60366036
# considered "always right" (i.e. even if the types are not overlapping).
60376037
# Also note that a care must be taken to unwrap this back at read places
60386038
# where we use this to narrow down declared type.
6039-
if node.callee.type_guard is not None:
6040-
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
6039+
with self.msg.filter_errors(), self.local_type_map():
6040+
# `node.callee` can be an `overload`ed function,
6041+
# we need to resolve the real `overload` case.
6042+
_, real_func = self.expr_checker.check_call(
6043+
get_proper_type(self.lookup_type(node.callee)),
6044+
node.args,
6045+
node.arg_kinds,
6046+
node,
6047+
node.arg_names,
6048+
)
6049+
real_func = get_proper_type(real_func)
6050+
if not isinstance(real_func, CallableType) or not (
6051+
real_func.type_guard or real_func.type_is
6052+
):
6053+
return {}, {}
6054+
6055+
if real_func.type_guard is not None:
6056+
return {expr: TypeGuardedType(real_func.type_guard)}, {}
60416057
else:
6042-
assert node.callee.type_is is not None
6058+
assert real_func.type_is is not None
60436059
return conditional_types_to_typemaps(
60446060
expr,
60456061
*self.conditional_types_with_intersection(
60466062
self.lookup_type(expr),
6047-
[TypeRange(node.callee.type_is, is_upper_bound=False)],
6063+
[TypeRange(real_func.type_is, is_upper_bound=False)],
60486064
expr,
60496065
),
60506066
)

Diff for: mypy/checkexpr.py

+73-10
Original file line numberDiff line numberDiff line change
@@ -2906,16 +2906,37 @@ def infer_overload_return_type(
29062906
elif all_same_types([erase_type(typ) for typ in return_types]):
29072907
self.chk.store_types(type_maps[0])
29082908
return erase_type(return_types[0]), erase_type(inferred_types[0])
2909-
else:
2910-
return self.check_call(
2911-
callee=AnyType(TypeOfAny.special_form),
2912-
args=args,
2913-
arg_kinds=arg_kinds,
2914-
arg_names=arg_names,
2915-
context=context,
2916-
callable_name=callable_name,
2917-
object_type=object_type,
2918-
)
2909+
return self.check_call(
2910+
callee=AnyType(TypeOfAny.special_form),
2911+
args=args,
2912+
arg_kinds=arg_kinds,
2913+
arg_names=arg_names,
2914+
context=context,
2915+
callable_name=callable_name,
2916+
object_type=object_type,
2917+
)
2918+
elif not all_same_type_narrowers(matches):
2919+
# This is an example of how overloads can be:
2920+
#
2921+
# @overload
2922+
# def is_int(obj: float) -> TypeGuard[float]: ...
2923+
# @overload
2924+
# def is_int(obj: int) -> TypeGuard[int]: ...
2925+
#
2926+
# x: Any
2927+
# if is_int(x):
2928+
# reveal_type(x) # N: int | float
2929+
#
2930+
# So, we need to check that special case.
2931+
return self.check_call(
2932+
callee=self.combine_function_signatures(cast("list[ProperType]", matches)),
2933+
args=args,
2934+
arg_kinds=arg_kinds,
2935+
arg_names=arg_names,
2936+
context=context,
2937+
callable_name=callable_name,
2938+
object_type=object_type,
2939+
)
29192940
else:
29202941
# Success! No ambiguity; return the first match.
29212942
self.chk.store_types(type_maps[0])
@@ -3130,6 +3151,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
31303151
new_args: list[list[Type]] = [[] for _ in range(len(callables[0].arg_types))]
31313152
new_kinds = list(callables[0].arg_kinds)
31323153
new_returns: list[Type] = []
3154+
new_type_guards: list[Type] = []
3155+
new_type_narrowers: list[Type] = []
31333156

31343157
too_complex = False
31353158
for target in callables:
@@ -3156,8 +3179,25 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
31563179
for i, arg in enumerate(target.arg_types):
31573180
new_args[i].append(arg)
31583181
new_returns.append(target.ret_type)
3182+
if target.type_guard:
3183+
new_type_guards.append(target.type_guard)
3184+
if target.type_is:
3185+
new_type_narrowers.append(target.type_is)
3186+
3187+
if new_type_guards and new_type_narrowers:
3188+
# They cannot be definined at the same time,
3189+
# declaring this function as too complex!
3190+
too_complex = True
3191+
union_type_guard = None
3192+
union_type_is = None
3193+
else:
3194+
union_type_guard = make_simplified_union(new_type_guards) if new_type_guards else None
3195+
union_type_is = (
3196+
make_simplified_union(new_type_narrowers) if new_type_narrowers else None
3197+
)
31593198

31603199
union_return = make_simplified_union(new_returns)
3200+
31613201
if too_complex:
31623202
any = AnyType(TypeOfAny.special_form)
31633203
return callables[0].copy_modified(
@@ -3167,6 +3207,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
31673207
ret_type=union_return,
31683208
variables=variables,
31693209
implicit=True,
3210+
type_guard=union_type_guard,
3211+
type_is=union_type_is,
31703212
)
31713213

31723214
final_args = []
@@ -3180,6 +3222,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
31803222
ret_type=union_return,
31813223
variables=variables,
31823224
implicit=True,
3225+
type_guard=union_type_guard,
3226+
type_is=union_type_is,
31833227
)
31843228

31853229
def erased_signature_similarity(
@@ -6520,6 +6564,25 @@ def all_same_types(types: list[Type]) -> bool:
65206564
return all(is_same_type(t, types[0]) for t in types[1:])
65216565

65226566

6567+
def all_same_type_narrowers(types: list[CallableType]) -> bool:
6568+
if len(types) <= 1:
6569+
return True
6570+
6571+
type_guards: list[Type] = []
6572+
type_narrowers: list[Type] = []
6573+
6574+
for typ in types:
6575+
if typ.type_guard:
6576+
type_guards.append(typ.type_guard)
6577+
if typ.type_is:
6578+
type_narrowers.append(typ.type_is)
6579+
if type_guards and type_narrowers:
6580+
# Some overloads declare `TypeGuard` and some declare `TypeIs`,
6581+
# we cannot handle this in a union.
6582+
return False
6583+
return all_same_types(type_guards) and all_same_types(type_narrowers)
6584+
6585+
65236586
def merge_typevars_in_callables_by_name(
65246587
callables: Sequence[CallableType],
65256588
) -> tuple[list[CallableType], list[TypeVarType]]:

Diff for: test-data/unit/check-typeguard.test

+56
Original file line numberDiff line numberDiff line change
@@ -730,3 +730,59 @@ x: object
730730
assert a(x=x)
731731
reveal_type(x) # N: Revealed type is "builtins.int"
732732
[builtins fixtures/tuple.pyi]
733+
734+
[case testTypeGuardInOverloads]
735+
from typing import Any, overload, Union
736+
from typing_extensions import TypeGuard
737+
738+
@overload
739+
def func1(x: str) -> TypeGuard[str]:
740+
...
741+
742+
@overload
743+
def func1(x: int) -> TypeGuard[int]:
744+
...
745+
746+
def func1(x: Any) -> Any:
747+
return True
748+
749+
def func2(val: Any):
750+
if func1(val):
751+
reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]"
752+
else:
753+
reveal_type(val) # N: Revealed type is "Any"
754+
755+
def func3(val: Union[int, str]):
756+
if func1(val):
757+
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
758+
else:
759+
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
760+
761+
def func4(val: int):
762+
if func1(val):
763+
reveal_type(val) # N: Revealed type is "builtins.int"
764+
else:
765+
reveal_type(val) # N: Revealed type is "builtins.int"
766+
[builtins fixtures/tuple.pyi]
767+
768+
[case testTypeIsInOverloadsSameReturn]
769+
from typing import Any, overload, Union
770+
from typing_extensions import TypeGuard
771+
772+
@overload
773+
def func1(x: str) -> TypeGuard[str]:
774+
...
775+
776+
@overload
777+
def func1(x: int) -> TypeGuard[str]:
778+
...
779+
780+
def func1(x: Any) -> Any:
781+
return True
782+
783+
def func2(val: Union[int, str]):
784+
if func1(val):
785+
reveal_type(val) # N: Revealed type is "builtins.str"
786+
else:
787+
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
788+
[builtins fixtures/tuple.pyi]

Diff for: test-data/unit/check-typeis.test

+119
Original file line numberDiff line numberDiff line change
@@ -817,3 +817,122 @@ accept_typeguard(typeis) # E: Argument 1 to "accept_typeguard" has incompatible
817817
accept_typeguard(typeguard)
818818

819819
[builtins fixtures/tuple.pyi]
820+
821+
[case testTypeIsInOverloads]
822+
from typing import Any, overload, Union
823+
from typing_extensions import TypeIs
824+
825+
@overload
826+
def func1(x: str) -> TypeIs[str]:
827+
...
828+
829+
@overload
830+
def func1(x: int) -> TypeIs[int]:
831+
...
832+
833+
def func1(x: Any) -> Any:
834+
return True
835+
836+
def func2(val: Any):
837+
if func1(val):
838+
reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]"
839+
else:
840+
reveal_type(val) # N: Revealed type is "Any"
841+
842+
def func3(val: Union[int, str]):
843+
if func1(val):
844+
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
845+
else:
846+
reveal_type(val)
847+
848+
def func4(val: int):
849+
if func1(val):
850+
reveal_type(val) # N: Revealed type is "builtins.int"
851+
else:
852+
reveal_type(val)
853+
[builtins fixtures/tuple.pyi]
854+
855+
[case testTypeIsInOverloadsSameReturn]
856+
from typing import Any, overload, Union
857+
from typing_extensions import TypeIs
858+
859+
@overload
860+
def func1(x: str) -> TypeIs[str]:
861+
...
862+
863+
@overload
864+
def func1(x: int) -> TypeIs[str]: # type: ignore
865+
...
866+
867+
def func1(x: Any) -> Any:
868+
return True
869+
870+
def func2(val: Union[int, str]):
871+
if func1(val):
872+
reveal_type(val) # N: Revealed type is "builtins.str"
873+
else:
874+
reveal_type(val) # N: Revealed type is "builtins.int"
875+
[builtins fixtures/tuple.pyi]
876+
877+
[case testTypeIsInOverloadsUnionizeError]
878+
from typing import Any, overload, Union
879+
from typing_extensions import TypeIs, TypeGuard
880+
881+
@overload
882+
def func1(x: str) -> TypeIs[str]:
883+
...
884+
885+
@overload
886+
def func1(x: int) -> TypeGuard[int]:
887+
...
888+
889+
def func1(x: Any) -> Any:
890+
return True
891+
892+
def func2(val: Union[int, str]):
893+
if func1(val):
894+
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
895+
else:
896+
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
897+
[builtins fixtures/tuple.pyi]
898+
899+
[case testTypeIsInOverloadsUnionizeError2]
900+
from typing import Any, overload, Union
901+
from typing_extensions import TypeIs, TypeGuard
902+
903+
@overload
904+
def func1(x: int) -> TypeGuard[int]:
905+
...
906+
907+
@overload
908+
def func1(x: str) -> TypeIs[str]:
909+
...
910+
911+
def func1(x: Any) -> Any:
912+
return True
913+
914+
def func2(val: Union[int, str]):
915+
if func1(val):
916+
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
917+
else:
918+
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
919+
[builtins fixtures/tuple.pyi]
920+
921+
[case testTypeIsLikeIsDataclass]
922+
from typing import Any, overload, Union, Type
923+
from typing_extensions import TypeIs
924+
925+
class DataclassInstance: ...
926+
927+
@overload
928+
def is_dataclass(obj: type) -> TypeIs[Type[DataclassInstance]]: ...
929+
@overload
930+
def is_dataclass(obj: object) -> TypeIs[Union[DataclassInstance, Type[DataclassInstance]]]: ...
931+
932+
def is_dataclass(obj: Union[type, object]) -> bool:
933+
return False
934+
935+
def func(arg: Any) -> None:
936+
if is_dataclass(arg):
937+
reveal_type(arg) # N: Revealed type is "Union[Type[__main__.DataclassInstance], __main__.DataclassInstance]"
938+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)