|
13 | 13 | from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs |
14 | 14 |
|
15 | 15 |
|
16 | | -def propagate_tensor_meta(op, user_args, out_strat): |
17 | | - out_t = op(*user_args) |
| 16 | +def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): |
| 17 | + out_t = op(*user_args, **user_kwargs) |
18 | 18 |
|
19 | 19 | if isinstance(out_t, torch.Tensor): |
20 | 20 | new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) |
@@ -85,7 +85,7 @@ def fill_missing_redistribute_cost(op, specs, out_strat): |
85 | 85 | strat.redistribute_cost = redistribute_costs |
86 | 86 |
|
87 | 87 |
|
88 | | -def get_placement_options(mesh, op, specs, user_args): |
| 88 | +def get_placement_options(mesh, op, specs, user_args, user_kwargs): |
89 | 89 | # print(op) |
90 | 90 |
|
91 | 91 | if op in _op_rules: |
@@ -118,7 +118,7 @@ def get_placement_options(mesh, op, specs, user_args): |
118 | 118 | op_schema |
119 | 119 | ) |
120 | 120 |
|
121 | | - propagate_tensor_meta(op, user_args, out_strat) |
| 121 | + propagate_tensor_meta(op, user_args, user_kwargs, out_strat) |
122 | 122 | fill_missing_redistribute_cost(op, specs, out_strat) |
123 | 123 | out_strat = remove_invalid_configs(out_strat, mesh) |
124 | 124 |
|
|
0 commit comments