Skip to content

Commit 151fa9f

Browse files
Merge pull request #2572 from AI-Hypercomputer:chengnuojin-sharding-fix
PiperOrigin-RevId: 826515273
2 parents 57e0ece + 18fcddc commit 151fa9f

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

src/MaxText/maxtext_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,43 @@
4646
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
4747

4848

49+
def get_input_data_sharding(config, mesh):
50+
max_logging.log(
51+
"WARNING: Function maxtext_utils.get_input_data_sharding is deprecated. Please use sharding.get_input_data_sharding."
52+
)
53+
return sharding.get_input_data_sharding(config, mesh)
54+
55+
56+
def assert_params_sufficiently_sharded(params, mesh, tolerance):
57+
max_logging.log(
58+
"WARNING: Function maxtext_utils.assert_params_sufficiently_sharded is deprecated."
59+
"Please use sharding.assert_params_sufficiently_sharded."
60+
)
61+
return sharding.assert_params_sufficiently_sharded(params, mesh, tolerance)
62+
63+
64+
def add_data_to_sharding(mesh, path, aval, shardings):
65+
max_logging.log(
66+
"WARNING: Function maxtext_utils.add_data_to_sharding is deprecated. Please use sharding.add_data_to_sharding."
67+
)
68+
return sharding.add_data_to_sharding(mesh, path, aval, shardings)
69+
70+
71+
def maybe_update_params_sharding_with_opt(config, state_mesh_shardings):
72+
max_logging.log(
73+
"WARNING: Function maxtext_utils.maybe_update_params_sharding_with_opt is deprecated."
74+
"Please use sharding.maybe_update_params_sharding_with_opt."
75+
)
76+
return sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
77+
78+
79+
def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules):
80+
max_logging.log(
81+
"WARNING: Function maxtext_utils.all_gather_over_fsdp is deprecated. Please use sharding.all_gather_over_fsdp."
82+
)
83+
return sharding.all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules)
84+
85+
4986
def get_functional_train_with_signature(
5087
train_step, data_sharding, state_mesh_shardings, model, config, params_shardings=None
5188
):

0 commit comments

Comments
 (0)