Skip to content

Commit 3e6b552

Browse files
authored
Make imprecise constraints handling more robust (#16502)
Fixes #16485 My initial implementation of imprecise constraints fallback was really fragile and ad-hoc, and I now see several edge case scenarios where we may end up using imprecise constraints for a `ParamSpec` while some precise ones are available. So I re-organized it: now we just infer everything as normally, and filter out imprecise (if needed) at the very end, when we have the full picture. I also fix an accidental omission in `expand_type()`.
1 parent a3e488d commit 3e6b552

File tree

3 files changed

+67
-33
lines changed

3 files changed

+67
-33
lines changed

mypy/constraints.py

+43-33
Original file line numberDiff line numberDiff line change
@@ -226,25 +226,22 @@ def infer_constraints_for_callable(
226226
actual_type = mapper.expand_actual_type(
227227
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
228228
)
229-
if (
230-
param_spec
231-
and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2)
232-
and not incomplete_star_mapping
233-
):
229+
if param_spec and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
234230
# If actual arguments are mapped to ParamSpec type, we can't infer individual
235231
# constraints, instead store them and infer single constraint at the end.
236232
# It is impossible to map actual kind to formal kind, so use some heuristic.
237233
# This inference is used as a fallback, so relying on heuristic should be OK.
238-
param_spec_arg_types.append(
239-
mapper.expand_actual_type(
240-
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
234+
if not incomplete_star_mapping:
235+
param_spec_arg_types.append(
236+
mapper.expand_actual_type(
237+
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
238+
)
241239
)
242-
)
243-
actual_kind = arg_kinds[actual]
244-
param_spec_arg_kinds.append(
245-
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
246-
)
247-
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
240+
actual_kind = arg_kinds[actual]
241+
param_spec_arg_kinds.append(
242+
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
243+
)
244+
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
248245
else:
249246
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
250247
constraints.extend(c)
@@ -267,6 +264,9 @@ def infer_constraints_for_callable(
267264
),
268265
)
269266
)
267+
if any(isinstance(v, ParamSpecType) for v in callee.variables):
268+
# As a perf optimization filter imprecise constraints only when we can have them.
269+
constraints = filter_imprecise_kinds(constraints)
270270
return constraints
271271

272272

@@ -1094,29 +1094,18 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
10941094
)
10951095

10961096
param_spec_target: Type | None = None
1097-
skip_imprecise = (
1098-
any(c.type_var == param_spec.id for c in res) and cactual.imprecise_arg_kinds
1099-
)
11001097
if not cactual_ps:
11011098
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
11021099
prefix_len = min(prefix_len, max_prefix_len)
1103-
# This logic matches top-level callable constraint exception, if we managed
1104-
# to get other constraints for ParamSpec, don't infer one with imprecise kinds
1105-
if not skip_imprecise:
1106-
param_spec_target = Parameters(
1107-
arg_types=cactual.arg_types[prefix_len:],
1108-
arg_kinds=cactual.arg_kinds[prefix_len:],
1109-
arg_names=cactual.arg_names[prefix_len:],
1110-
variables=cactual.variables
1111-
if not type_state.infer_polymorphic
1112-
else [],
1113-
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
1114-
)
1100+
param_spec_target = Parameters(
1101+
arg_types=cactual.arg_types[prefix_len:],
1102+
arg_kinds=cactual.arg_kinds[prefix_len:],
1103+
arg_names=cactual.arg_names[prefix_len:],
1104+
variables=cactual.variables if not type_state.infer_polymorphic else [],
1105+
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
1106+
)
11151107
else:
1116-
if (
1117-
len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types)
1118-
and not skip_imprecise
1119-
):
1108+
if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types):
11201109
param_spec_target = cactual_ps.copy_modified(
11211110
prefix=Parameters(
11221111
arg_types=cactual_ps.prefix.arg_types[prefix_len:],
@@ -1611,3 +1600,24 @@ def infer_callable_arguments_constraints(
16111600
infer_directed_arg_constraints(left_by_name.typ, right_by_name.typ, direction)
16121601
)
16131602
return res
1603+
1604+
1605+
def filter_imprecise_kinds(cs: list[Constraint]) -> list[Constraint]:
1606+
"""For each ParamSpec remove all imprecise constraints, if at least one precise available."""
1607+
have_precise = set()
1608+
for c in cs:
1609+
if not isinstance(c.origin_type_var, ParamSpecType):
1610+
continue
1611+
if (
1612+
isinstance(c.target, ParamSpecType)
1613+
or isinstance(c.target, Parameters)
1614+
and not c.target.imprecise_arg_kinds
1615+
):
1616+
have_precise.add(c.type_var)
1617+
new_cs = []
1618+
for c in cs:
1619+
if not isinstance(c.origin_type_var, ParamSpecType) or c.type_var not in have_precise:
1620+
new_cs.append(c)
1621+
if not isinstance(c.target, Parameters) or not c.target.imprecise_arg_kinds:
1622+
new_cs.append(c)
1623+
return new_cs

mypy/expandtype.py

+1
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
253253
t.prefix.arg_kinds + repl.arg_kinds,
254254
t.prefix.arg_names + repl.arg_names,
255255
variables=[*t.prefix.variables, *repl.variables],
256+
imprecise_arg_kinds=repl.imprecise_arg_kinds,
256257
)
257258
else:
258259
# We could encode Any as trivial parameters etc., but it would be too verbose.

test-data/unit/check-parameter-specification.test

+23
Original file line numberDiff line numberDiff line change
@@ -2163,3 +2163,26 @@ def func2(arg: T) -> List[Union[T, str]]:
21632163
reveal_type(func2) # N: Revealed type is "def [T] (arg: T`-1) -> Union[T`-1, builtins.str]"
21642164
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
21652165
[builtins fixtures/paramspec.pyi]
2166+
2167+
[case testParamSpecPreciseKindsUsedIfPossible]
2168+
from typing import Callable, Generic
2169+
from typing_extensions import ParamSpec
2170+
2171+
P = ParamSpec('P')
2172+
2173+
class Case(Generic[P]):
2174+
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
2175+
pass
2176+
2177+
def _test(a: int, b: int = 0) -> None: ...
2178+
2179+
def parametrize(
2180+
func: Callable[P, None], *cases: Case[P], **named_cases: Case[P]
2181+
) -> Callable[[], None]:
2182+
...
2183+
2184+
parametrize(_test, Case(1, 2), Case(3, 4))
2185+
parametrize(_test, Case(1, b=2), Case(3, b=4))
2186+
parametrize(_test, Case(1, 2), Case(3))
2187+
parametrize(_test, Case(1, 2), Case(3, b=4))
2188+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)