Skip to content

Commit 9ed9830

Browse files
committed
Tests for Concatenate
1 parent d202d1e commit 9ed9830

File tree

3 files changed

+128
-4
lines changed

3 files changed

+128
-4
lines changed

mypy/constraints.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -586,9 +586,16 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
586586
else:
587587
# TODO: Direction
588588
# TODO: Deal with arguments that come before param spec ones?
589+
# TODO: check the prefixes match
590+
prefix = param_spec.prefix
591+
prefix_len = len(prefix.arg_types)
589592
res.append(Constraint(param_spec.id,
590593
SUBTYPE_OF,
591-
cactual.copy_modified(ret_type=NoneType())))
594+
cactual.copy_modified(
595+
arg_types=cactual.arg_types[prefix_len:],
596+
arg_kinds=cactual.arg_kinds[prefix_len:],
597+
arg_names=cactual.arg_names[prefix_len:],
598+
ret_type=NoneType())))
592599

593600
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
594601
if template.type_guard is not None:

mypy/typeanal.py

+39
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,41 @@ def analyze_callable_args_for_paramspec(
797797
fallback=fallback,
798798
)
799799

800+
def analyze_callable_args_for_concatenate(
801+
self,
802+
callable_args: Type,
803+
ret_type: Type,
804+
fallback: Instance,
805+
) -> Optional[CallableType]:
806+
"""Construct a 'Callable[C, RET]', where C is Concatenate[..., P], return None if we cannot."""
807+
if not isinstance(callable_args, UnboundType):
808+
return None
809+
sym = self.lookup_qualified(callable_args.name, callable_args)
810+
if sym is None:
811+
return None
812+
if sym.node.fullname not in ("typing_extensions.Concatenate", "typing.Concatenate"):
813+
return None
814+
815+
tvar_def = self.anal_type(callable_args, allow_param_spec=True)
816+
if not isinstance(tvar_def, ParamSpecType):
817+
return None
818+
819+
# TODO: Use tuple[...] or Mapping[..] instead?
820+
obj = self.named_type('builtins.object')
821+
# ick, CallableType should take ParamSpecType
822+
prefix = tvar_def.prefix
823+
return CallableType(
824+
[*prefix.arg_types,
825+
ParamSpecType(tvar_def.name, tvar_def.fullname, tvar_def.id, ParamSpecFlavor.ARGS,
826+
upper_bound=obj, prefix=tvar_def.prefix),
827+
ParamSpecType(tvar_def.name, tvar_def.fullname, tvar_def.id, ParamSpecFlavor.KWARGS,
828+
upper_bound=obj)],
829+
[*prefix.arg_kinds, nodes.ARG_STAR, nodes.ARG_STAR2],
830+
[*prefix.arg_names, None, None],
831+
ret_type=ret_type,
832+
fallback=fallback,
833+
)
834+
800835
def analyze_callable_type(self, t: UnboundType) -> Type:
801836
fallback = self.named_type('builtins.function')
802837
if len(t.args) == 0:
@@ -828,6 +863,10 @@ def analyze_callable_type(self, t: UnboundType) -> Type:
828863
callable_args,
829864
ret_type,
830865
fallback
866+
) or self.analyze_callable_args_for_concatenate(
867+
callable_args,
868+
ret_type,
869+
fallback
831870
)
832871
if maybe_ret is None:
833872
# Callable[?, RET] (where ? is something invalid)

test-data/unit/check-parameter-specification.test

+81-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ def foo1(x: Callable[P, int]) -> Callable[P, str]: ...
1515
def foo2(x: P) -> P: ... # E: Invalid location for ParamSpec "P" \
1616
# N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]'
1717

18-
# TODO(PEP612): uncomment once we have support for Concatenate
19-
# def foo3(x: Concatenate[int, P]) -> int: ... $ E: Invalid location for Concatenate
18+
# TODO: Better error message
19+
def foo3(x: Concatenate[int, P]) -> int: ... # E: Invalid location for ParamSpec "P" \
20+
# N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]'
2021

2122
def foo4(x: List[P]) -> None: ... # E: Invalid location for ParamSpec "P" \
2223
# N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]'
@@ -455,5 +456,82 @@ reveal_type(kb(n)) # N: Revealed type is "__main__.Z[[builtins.str]]" \
455456
# TODO(PEP612): fancy "aesthetic" syntax defined in PEP
456457
# n2: Z[bytes]
457458
#
458-
# reveal_type(kb(n2)) # N: Revealed type is "__main__.Z[[builtins.str]]"
459+
# reveal_type(kb(n2)) $ N: Revealed type is "__main__.Z[[builtins.str]]"
459460
[builtins fixtures/tuple.pyi]
461+
462+
[case testParamSpecConcatenateFromPep]
463+
from typing_extensions import ParamSpec, Concatenate
464+
from typing import Callable, TypeVar, Generic
465+
466+
P = ParamSpec("P")
467+
R = TypeVar("R")
468+
469+
# CASE 1
470+
class Request:
471+
...
472+
473+
def with_request(f: Callable[Concatenate[Request, P], R]) -> Callable[P, R]:
474+
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
475+
return f(Request(), *args, **kwargs)
476+
return inner
477+
478+
@with_request
479+
def takes_int_str(request: Request, x: int, y: str) -> int:
480+
# use request
481+
return x + 7
482+
483+
reveal_type(takes_int_str) # N: Revealed type is "def (x: builtins.int, y: builtins.str) -> builtins.int*"
484+
485+
takes_int_str(1, "A") # Accepted
486+
takes_int_str("B", 2) # E: Argument 1 to "takes_int_str" has incompatible type "str"; expected "int" \
487+
# E: Argument 2 to "takes_int_str" has incompatible type "int"; expected "str"
488+
489+
# CASE 2
490+
T = TypeVar("T")
491+
P_2 = ParamSpec("P_2")
492+
493+
class X(Generic[T, P]):
494+
f: Callable[P, int]
495+
x: T
496+
497+
def f1(x: X[int, P_2]) -> str: ... # Accepted
498+
def f2(x: X[int, Concatenate[int, P_2]]) -> str: ... # Accepted
499+
def f3(x: X[int, [int, bool]]) -> str: ... # Accepted
500+
# Is ellipsis allowed by PEP? This shows up:
501+
# def f4(x: X[int, ...]) -> str: ... # Accepted
502+
# TODO: this is not rejected:
503+
# def f5(x: X[int, int]) -> str: ... # Rejected
504+
505+
# CASE 3
506+
def bar(x: int, *args: bool) -> int: ...
507+
def add(x: Callable[P, int]) -> Callable[Concatenate[str, P], bool]: ...
508+
509+
reveal_type(add(bar)) # N: Revealed type is "def (builtins.str, x: builtins.int, *args: builtins.bool) -> builtins.bool"
510+
511+
def remove(x: Callable[Concatenate[int, P], int]) -> Callable[P, bool]: ...
512+
513+
reveal_type(remove(bar)) # N: Revealed type is "def (*args: builtins.bool) -> builtins.bool"
514+
515+
def transform(
516+
x: Callable[Concatenate[int, P], int]
517+
) -> Callable[Concatenate[str, P], bool]: ...
518+
519+
# In the PEP, "__a" appears. What is that? Autogenerated names? To what spec?
520+
reveal_type(transform(bar)) # N: Revealed type is "def (builtins.str, *args: builtins.bool) -> builtins.bool"
521+
522+
# CASE 4
523+
def expects_int_first(x: Callable[Concatenate[int, P], int]) -> None: ...
524+
525+
@expects_int_first # E: Argument 1 to "expects_int_first" has incompatible type "Callable[[str], int]"; expected "Callable[[int], int]"
526+
def one(x: str) -> int: ...
527+
528+
@expects_int_first # E: Argument 1 to "expects_int_first" has incompatible type "Callable[[NamedArg(int, 'x')], int]"; expected "Callable[[int], int]"
529+
def two(*, x: int) -> int: ...
530+
531+
@expects_int_first # E: Argument 1 to "expects_int_first" has incompatible type "Callable[[KwArg(int)], int]"; expected "Callable[[int], int]"
532+
def three(**kwargs: int) -> int: ...
533+
534+
@expects_int_first # Accepted
535+
def four(*args: int) -> int: ...
536+
[builtins fixtures/tuple.pyi]
537+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)