|
46 | 46 | OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" |
47 | 47 |
|
48 | 48 |
|
| 49 | +def get_input_data_sharding(config, mesh): |
| 50 | + max_logging.log( |
| 51 | + "WARNING: Function maxtext_utils.get_input_data_sharding is deprecated. Please use sharding.get_input_data_sharding." |
| 52 | + ) |
| 53 | + return sharding.get_input_data_sharding(config, mesh) |
| 54 | + |
| 55 | + |
| 56 | +def assert_params_sufficiently_sharded(params, mesh, tolerance): |
| 57 | + max_logging.log( |
| 58 | + "WARNING: Function maxtext_utils.assert_params_sufficiently_sharded is deprecated." |
| 59 | + "Please use sharding.assert_params_sufficiently_sharded." |
| 60 | + ) |
| 61 | + return sharding.assert_params_sufficiently_sharded(params, mesh, tolerance) |
| 62 | + |
| 63 | + |
| 64 | +def add_data_to_sharding(mesh, path, aval, shardings): |
| 65 | + max_logging.log( |
| 66 | + "WARNING: Function maxtext_utils.add_data_to_sharding is deprecated. Please use sharding.add_data_to_sharding." |
| 67 | + ) |
| 68 | + return sharding.add_data_to_sharding(mesh, path, aval, shardings) |
| 69 | + |
| 70 | + |
| 71 | +def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): |
| 72 | + max_logging.log( |
| 73 | + "WARNING: Function maxtext_utils.maybe_update_params_sharding_with_opt is deprecated." |
| 74 | + "Please use sharding.maybe_update_params_sharding_with_opt." |
| 75 | + ) |
| 76 | + return sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) |
| 77 | + |
| 78 | + |
| 79 | +def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules): |
| 80 | + max_logging.log( |
| 81 | + "WARNING: Function maxtext_utils.all_gather_over_fsdp is deprecated. Please use sharding.all_gather_over_fsdp." |
| 82 | + ) |
| 83 | + return sharding.all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules) |
| 84 | + |
| 85 | + |
49 | 86 | def get_functional_train_with_signature( |
50 | 87 | train_step, data_sharding, state_mesh_shardings, model, config, params_shardings=None |
51 | 88 | ): |
|
0 commit comments