diff --git a/pytype/overlays/functools_overlay.py b/pytype/overlays/functools_overlay.py index bddb67709..dfc8e044e 100644 --- a/pytype/overlays/functools_overlay.py +++ b/pytype/overlays/functools_overlay.py @@ -55,7 +55,12 @@ def new_slot( self.ctx, node, new.to_variable(node), - function.Args((cls, *args), kwargs), + function.Args( + (cls, *args), + kwargs, + call_context.starargs, + call_context.starstarargs, + ), fallback_to_unsolvable=False, ) [specialized_obj] = specialized_obj.data @@ -66,7 +71,7 @@ def new_slot( return node, obj.to_variable(node) def get_own_new(self, node, value) -> tuple[cfg.CFGNode, cfg.Variable]: - new = abstract.NativeFunction("__new__", self.new_slot, self.ctx) + new = NativeFunction("__new__", self.new_slot, self.ctx) return node, new.to_variable(node) @@ -76,6 +81,8 @@ def bind_partial(node, cls, args, kwargs, ctx) -> BoundPartial: obj.underlying = args[0] obj.args = args[1:] obj.kwargs = kwargs + obj.starargs = call_context.starargs + obj.starstarargs = call_context.starstarargs return obj @@ -146,6 +153,8 @@ class BoundPartial(abstract.Instance, mixin.HasSlots): underlying: cfg.Variable args: tuple[cfg.Variable, ...] kwargs: dict[str, cfg.Variable] + starargs: cfg.Variable | None + starstarargs: cfg.Variable | None def __init__(self, ctx, cls, container=None): super().__init__(cls, ctx, container) @@ -156,7 +165,9 @@ def __init__(self, ctx, cls, container=None): def get_signatures(self) -> Sequence[function.Signature]: sigs = [] - args = function.Args(posargs=self.args, namedargs=self.kwargs) + args = function.Args( + self.args, self.kwargs, self.starargs, self.starstarargs + ) for data in self.underlying.data: for sig in function.get_signatures(data): # Use the partial arguments as defaults in the signature, making them @@ -196,6 +207,25 @@ def get_signatures(self) -> Sequence[function.Signature]: return sigs def call_slot(self, node: cfg.CFGNode, *args, **kwargs): + if self.starargs and call_context.starargs: + combined_starargs = self.ctx.convert.build_tuple( + node, + ( + abstract.Splat(self.ctx, self.starargs).to_variable(node), + abstract.Splat(self.ctx, call_context.starargs).to_variable(node), + ), + ) + else: + combined_starargs = call_context.starargs or self.starargs + + if self.starstarargs and call_context.starstarargs: + d = abstract.Dict(self.ctx) + d.update(node, self.starstarargs.data[0]) # pytype: disable=attribute-error + d.update(node, call_context.starstarargs.data[0]) + combined_starstarargs = d.to_variable(node) + else: + combined_starstarargs = call_context.starstarargs or self.starstarargs + return function.call_function( self.ctx, node, @@ -203,8 +233,8 @@ def call_slot(self, node: cfg.CFGNode, *args, **kwargs): function.Args( (*self.args, *args), {**self.kwargs, **kwargs}, - call_context.starargs, - call_context.starstarargs, + combined_starargs, + combined_starstarargs, ), fallback_to_unsolvable=False, ) diff --git a/pytype/tests/test_functions1.py b/pytype/tests/test_functions1.py index 3824e4d24..560f1ab76 100644 --- a/pytype/tests/test_functions1.py +++ b/pytype/tests/test_functions1.py @@ -1079,7 +1079,18 @@ def f(a, b=None): partial_f(0) """) - def test_functools_partial_with_starstar(self): + def test_functools_partial_starstar(self): + self.Check(""" + import functools + def f(*, a: str, b: int): + pass + + def test(**kwargs): + partial_f = functools.partial(f, **kwargs) + partial_f() + """) + + def test_functools_partial_called_with_starstar(self): self.Check(""" import functools def f(a: str, b: int, c: list): @@ -1090,6 +1101,28 @@ def test(**kwargs): partial_f(42, **kwargs) """) + def test_functools_star_everywhere(self): + self.Check(""" + import functools + def f(a: str, b: int): + pass + + def test(args, extra_args): + partial_f = functools.partial(f, *args) + partial_f(*extra_args) + """) + + def test_functools_starstar_everywhere(self): + self.Check(""" + import functools + def f(*, a: str, b: int): + pass + + def test(**kwargs): + partial_f = functools.partial(f, **kwargs) + partial_f(**kwargs) + """) + def test_functools_partial_overloaded(self): self.Check(""" import functools