Skip to content

Commit 09fab9b

Browse files
authored
Pass kwargs as well to function (#30)
This should be merged directly in main
1 parent 63bdf94 commit 09fab9b

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

autoparallel/optimize_sharding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,11 @@ def build_sharding_metadata(self):
6363
user_args = tree_map_only(
6464
torch.fx.Node, lambda x: x.meta["val"], node.args
6565
)
66+
user_kwargs = tree_map_only(
67+
torch.fx.Node, lambda x: x.meta["val"], node.kwargs
68+
)
6669
strat = get_placement_options(
67-
self.mesh, node.target, user_strats, user_args
70+
self.mesh, node.target, user_strats, user_args, user_kwargs
6871
)
6972
strats[node] = strat
7073
elif node.op == "output":

autoparallel/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
1414

1515

16-
def propagate_tensor_meta(op, user_args, out_strat):
17-
out_t = op(*user_args)
16+
def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
17+
out_t = op(*user_args, **user_kwargs)
1818

1919
if isinstance(out_t, torch.Tensor):
2020
new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype)
@@ -85,7 +85,7 @@ def fill_missing_redistribute_cost(op, specs, out_strat):
8585
strat.redistribute_cost = redistribute_costs
8686

8787

88-
def get_placement_options(mesh, op, specs, user_args):
88+
def get_placement_options(mesh, op, specs, user_args, user_kwargs):
8989
# print(op)
9090

9191
if op in _op_rules:
@@ -118,7 +118,7 @@ def get_placement_options(mesh, op, specs, user_args):
118118
op_schema
119119
)
120120

121-
propagate_tensor_meta(op, user_args, out_strat)
121+
propagate_tensor_meta(op, user_args, user_kwargs, out_strat)
122122
fill_missing_redistribute_cost(op, specs, out_strat)
123123
out_strat = remove_invalid_configs(out_strat, mesh)
124124

0 commit comments

Comments
 (0)