Skip to content

Commit 7051084

Browse files
superbobrycopybara-github
authored andcommitted
Forward *args/**kwargs from functools.partial to the wrapper
PiperOrigin-RevId: 821720894
1 parent 68ece50 commit 7051084

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

pytype/overlays/functools_overlay.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,12 @@ def new_slot(
5555
self.ctx,
5656
node,
5757
new.to_variable(node),
58-
function.Args((cls, *args), kwargs),
58+
function.Args(
59+
(cls, *args),
60+
kwargs,
61+
call_context.starargs,
62+
call_context.starstarargs,
63+
),
5964
fallback_to_unsolvable=False,
6065
)
6166
[specialized_obj] = specialized_obj.data
@@ -66,7 +71,7 @@ def new_slot(
6671
return node, obj.to_variable(node)
6772

6873
def get_own_new(self, node, value) -> tuple[cfg.CFGNode, cfg.Variable]:
69-
new = abstract.NativeFunction("__new__", self.new_slot, self.ctx)
74+
new = NativeFunction("__new__", self.new_slot, self.ctx)
7075
return node, new.to_variable(node)
7176

7277

@@ -76,6 +81,8 @@ def bind_partial(node, cls, args, kwargs, ctx) -> BoundPartial:
7681
obj.underlying = args[0]
7782
obj.args = args[1:]
7883
obj.kwargs = kwargs
84+
obj.starargs = call_context.starargs
85+
obj.starstarargs = call_context.starstarargs
7986
return obj
8087

8188

@@ -146,6 +153,8 @@ class BoundPartial(abstract.Instance, mixin.HasSlots):
146153
underlying: cfg.Variable
147154
args: tuple[cfg.Variable, ...]
148155
kwargs: dict[str, cfg.Variable]
156+
starargs: cfg.Variable | None
157+
starstarargs: cfg.Variable | None
149158

150159
def __init__(self, ctx, cls, container=None):
151160
super().__init__(cls, ctx, container)
@@ -156,7 +165,9 @@ def __init__(self, ctx, cls, container=None):
156165

157166
def get_signatures(self) -> Sequence[function.Signature]:
158167
sigs = []
159-
args = function.Args(posargs=self.args, namedargs=self.kwargs)
168+
args = function.Args(
169+
self.args, self.kwargs, self.starargs, self.starstarargs
170+
)
160171
for data in self.underlying.data:
161172
for sig in function.get_signatures(data):
162173
# Use the partial arguments as defaults in the signature, making them
@@ -196,15 +207,34 @@ def get_signatures(self) -> Sequence[function.Signature]:
196207
return sigs
197208

198209
def call_slot(self, node: cfg.CFGNode, *args, **kwargs):
210+
if self.starargs and call_context.starargs:
211+
combined_starargs = self.ctx.convert.build_tuple(
212+
node,
213+
(
214+
abstract.Splat(self.ctx, self.starargs).to_variable(node),
215+
abstract.Splat(self.ctx, call_context.starargs).to_variable(node),
216+
),
217+
)
218+
else:
219+
combined_starargs = call_context.starargs or self.starargs
220+
221+
if self.starstarargs and call_context.starstarargs:
222+
d = abstract.Dict(self.ctx)
223+
d.update(node, self.starstarargs.data[0]) # pytype: disable=attribute-error
224+
d.update(node, call_context.starstarargs.data[0])
225+
combined_starstarargs = d.to_variable(node)
226+
else:
227+
combined_starstarargs = call_context.starstarargs or self.starstarargs
228+
199229
return function.call_function(
200230
self.ctx,
201231
node,
202232
self.underlying,
203233
function.Args(
204234
(*self.args, *args),
205235
{**self.kwargs, **kwargs},
206-
call_context.starargs,
207-
call_context.starstarargs,
236+
combined_starargs,
237+
combined_starstarargs,
208238
),
209239
fallback_to_unsolvable=False,
210240
)

pytype/tests/test_functions1.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,18 @@ def f(a, b=None):
10791079
partial_f(0)
10801080
""")
10811081

1082-
def test_functools_partial_with_starstar(self):
1082+
def test_functools_partial_starstar(self):
1083+
self.Check("""
1084+
import functools
1085+
def f(*, a: str, b: int):
1086+
pass
1087+
1088+
def test(**kwargs):
1089+
partial_f = functools.partial(f, **kwargs)
1090+
partial_f()
1091+
""")
1092+
1093+
def test_functools_partial_called_with_starstar(self):
10831094
self.Check("""
10841095
import functools
10851096
def f(a: str, b: int, c: list):
@@ -1090,6 +1101,28 @@ def test(**kwargs):
10901101
partial_f(42, **kwargs)
10911102
""")
10921103

1104+
def test_functools_star_everywhere(self):
1105+
self.Check("""
1106+
import functools
1107+
def f(a: str, b: int):
1108+
pass
1109+
1110+
def test(args, extra_args):
1111+
partial_f = functools.partial(f, *args)
1112+
partial_f(*extra_args)
1113+
""")
1114+
1115+
def test_functools_starstar_everywhere(self):
1116+
self.Check("""
1117+
import functools
1118+
def f(*, a: str, b: int):
1119+
pass
1120+
1121+
def test(**kwargs):
1122+
partial_f = functools.partial(f, **kwargs)
1123+
partial_f(**kwargs)
1124+
""")
1125+
10931126
def test_functools_partial_overloaded(self):
10941127
self.Check("""
10951128
import functools

0 commit comments

Comments
 (0)