@@ -226,25 +226,22 @@ def infer_constraints_for_callable(
226
226
actual_type = mapper .expand_actual_type (
227
227
actual_arg_type , arg_kinds [actual ], callee .arg_names [i ], callee .arg_kinds [i ]
228
228
)
229
- if (
230
- param_spec
231
- and callee .arg_kinds [i ] in (ARG_STAR , ARG_STAR2 )
232
- and not incomplete_star_mapping
233
- ):
229
+ if param_spec and callee .arg_kinds [i ] in (ARG_STAR , ARG_STAR2 ):
234
230
# If actual arguments are mapped to ParamSpec type, we can't infer individual
235
231
# constraints, instead store them and infer single constraint at the end.
236
232
# It is impossible to map actual kind to formal kind, so use some heuristic.
237
233
# This inference is used as a fallback, so relying on heuristic should be OK.
238
- param_spec_arg_types .append (
239
- mapper .expand_actual_type (
240
- actual_arg_type , arg_kinds [actual ], None , arg_kinds [actual ]
234
+ if not incomplete_star_mapping :
235
+ param_spec_arg_types .append (
236
+ mapper .expand_actual_type (
237
+ actual_arg_type , arg_kinds [actual ], None , arg_kinds [actual ]
238
+ )
241
239
)
242
- )
243
- actual_kind = arg_kinds [actual ]
244
- param_spec_arg_kinds .append (
245
- ARG_POS if actual_kind not in (ARG_STAR , ARG_STAR2 ) else actual_kind
246
- )
247
- param_spec_arg_names .append (arg_names [actual ] if arg_names else None )
240
+ actual_kind = arg_kinds [actual ]
241
+ param_spec_arg_kinds .append (
242
+ ARG_POS if actual_kind not in (ARG_STAR , ARG_STAR2 ) else actual_kind
243
+ )
244
+ param_spec_arg_names .append (arg_names [actual ] if arg_names else None )
248
245
else :
249
246
c = infer_constraints (callee .arg_types [i ], actual_type , SUPERTYPE_OF )
250
247
constraints .extend (c )
@@ -267,6 +264,9 @@ def infer_constraints_for_callable(
267
264
),
268
265
)
269
266
)
267
+ if any (isinstance (v , ParamSpecType ) for v in callee .variables ):
268
+ # As a perf optimization filter imprecise constraints only when we can have them.
269
+ constraints = filter_imprecise_kinds (constraints )
270
270
return constraints
271
271
272
272
@@ -1094,29 +1094,18 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
1094
1094
)
1095
1095
1096
1096
param_spec_target : Type | None = None
1097
- skip_imprecise = (
1098
- any (c .type_var == param_spec .id for c in res ) and cactual .imprecise_arg_kinds
1099
- )
1100
1097
if not cactual_ps :
1101
1098
max_prefix_len = len ([k for k in cactual .arg_kinds if k in (ARG_POS , ARG_OPT )])
1102
1099
prefix_len = min (prefix_len , max_prefix_len )
1103
- # This logic matches top-level callable constraint exception, if we managed
1104
- # to get other constraints for ParamSpec, don't infer one with imprecise kinds
1105
- if not skip_imprecise :
1106
- param_spec_target = Parameters (
1107
- arg_types = cactual .arg_types [prefix_len :],
1108
- arg_kinds = cactual .arg_kinds [prefix_len :],
1109
- arg_names = cactual .arg_names [prefix_len :],
1110
- variables = cactual .variables
1111
- if not type_state .infer_polymorphic
1112
- else [],
1113
- imprecise_arg_kinds = cactual .imprecise_arg_kinds ,
1114
- )
1100
+ param_spec_target = Parameters (
1101
+ arg_types = cactual .arg_types [prefix_len :],
1102
+ arg_kinds = cactual .arg_kinds [prefix_len :],
1103
+ arg_names = cactual .arg_names [prefix_len :],
1104
+ variables = cactual .variables if not type_state .infer_polymorphic else [],
1105
+ imprecise_arg_kinds = cactual .imprecise_arg_kinds ,
1106
+ )
1115
1107
else :
1116
- if (
1117
- len (param_spec .prefix .arg_types ) <= len (cactual_ps .prefix .arg_types )
1118
- and not skip_imprecise
1119
- ):
1108
+ if len (param_spec .prefix .arg_types ) <= len (cactual_ps .prefix .arg_types ):
1120
1109
param_spec_target = cactual_ps .copy_modified (
1121
1110
prefix = Parameters (
1122
1111
arg_types = cactual_ps .prefix .arg_types [prefix_len :],
@@ -1611,3 +1600,24 @@ def infer_callable_arguments_constraints(
1611
1600
infer_directed_arg_constraints (left_by_name .typ , right_by_name .typ , direction )
1612
1601
)
1613
1602
return res
1603
+
1604
+
1605
+ def filter_imprecise_kinds (cs : list [Constraint ]) -> list [Constraint ]:
1606
+ """For each ParamSpec remove all imprecise constraints, if at least one precise available."""
1607
+ have_precise = set ()
1608
+ for c in cs :
1609
+ if not isinstance (c .origin_type_var , ParamSpecType ):
1610
+ continue
1611
+ if (
1612
+ isinstance (c .target , ParamSpecType )
1613
+ or isinstance (c .target , Parameters )
1614
+ and not c .target .imprecise_arg_kinds
1615
+ ):
1616
+ have_precise .add (c .type_var )
1617
+ new_cs = []
1618
+ for c in cs :
1619
+ if not isinstance (c .origin_type_var , ParamSpecType ) or c .type_var not in have_precise :
1620
+ new_cs .append (c )
1621
+ if not isinstance (c .target , Parameters ) or not c .target .imprecise_arg_kinds :
1622
+ new_cs .append (c )
1623
+ return new_cs
0 commit comments