|
| 1 | +""" |
| 2 | + Copyright 2025 Google LLC |
| 3 | +
|
| 4 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + you may not use this file except in compliance with the License. |
| 6 | + You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | + Unless required by applicable law or agreed to in writing, software |
| 11 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + See the License for the specific language governing permissions and |
| 14 | + limitations under the License. |
| 15 | + """ |
| 16 | + |
| 17 | +import datetime |
| 18 | +import sys |
| 19 | +import os |
| 20 | + |
| 21 | +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
| 22 | +sys.path.append(parent_dir) |
| 23 | + |
| 24 | +import recipes.args_helper as helper |
| 25 | +import maxtext_trillium_model_configs as model_configs |
| 26 | +import maxtext_xpk_runner as mxr |
| 27 | +from xpk_configs import XpkClusterConfig |
| 28 | + |
| 29 | +# Cluster Params |
| 30 | +CLUSTER = "v6e-256-cluster" |
| 31 | +PROJECT = "tpu-prod-env-cluster" |
| 32 | +ZONE = "us-east5-b" |
| 33 | +REGION = "us-east5" |
| 34 | +COUNTRY = "us" |
| 35 | +DEVICE_TYPE = "v6e-256" |
| 36 | + |
| 37 | +# Other parameters (MUST BE SET BY USER) |
| 38 | +XPK_PATH = os.path.join("~", "xpk") |
| 39 | +USER = os.environ["USER"] |
| 40 | +BASE_OUTPUT_DIRECTORY = ( |
| 41 | + f"gs://{USER}-{PROJECT}-{COUNTRY}/mcjax_long_run/" |
| 42 | +) |
| 43 | +# Generate your own runner image from MaxText repo. |
| 44 | +RUNNER = f"gcr.io/{PROJECT}/{USER}_latest" |
| 45 | + |
| 46 | +MAX_RESTARTS = 10_000 |
| 47 | +BENCHMARK_STEPS=10_000_000 |
| 48 | + |
| 49 | + |
| 50 | +def main() -> int: |
| 51 | + # V6e cluster config |
| 52 | + cluster_config = XpkClusterConfig( |
| 53 | + cluster_name=CLUSTER, |
| 54 | + project=PROJECT, |
| 55 | + zone=ZONE, |
| 56 | + device_type=DEVICE_TYPE, |
| 57 | + ) |
| 58 | + |
| 59 | + # Handle command line arguments using args_helper |
| 60 | + should_continue = helper.handle_cmd_args( |
| 61 | + cluster_config, helper.DELETE, xpk_path=XPK_PATH |
| 62 | + ) |
| 63 | + |
| 64 | + if not should_continue: |
| 65 | + return 0 |
| 66 | + |
| 67 | + model_list = [ |
| 68 | + # model_configs.llama3_1_70b_8192_pw_lr_real_data, |
| 69 | + # model_configs.llama3_1_8b_8192, |
| 70 | + model_configs.llama3_1_70b_8192_iter_synth_data_and_checkpointing, |
| 71 | + # model_configs.llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds, |
| 72 | + ] |
| 73 | + num_slices_list = [ |
| 74 | + 2 |
| 75 | + ] |
| 76 | + |
| 77 | + xpk_workload_cmds = [] |
| 78 | + xpk_workload_names = [] |
| 79 | + |
| 80 | + for model in model_list: |
| 81 | + # Run workloads on the below clusters |
| 82 | + for cluster_config in [ |
| 83 | + cluster_config, |
| 84 | + ]: |
| 85 | + |
| 86 | + # Make modifications to the model config here to add in any additional |
| 87 | + # flags or changes to the model config. |
| 88 | + model.tuning_params["use_vertex_tensorboard"] = True |
| 89 | + model.tuning_params["vertex_tensorboard_project"] = PROJECT |
| 90 | + model.tuning_params["vertex_tensorboard_region"] = REGION |
| 91 | + |
| 92 | + # Run workloads in the following slice configurations |
| 93 | + for num_slices in num_slices_list: |
| 94 | + wl_config = mxr.WorkloadConfig( |
| 95 | + model=model, |
| 96 | + num_slices=num_slices, |
| 97 | + device_type=cluster_config.device_type, |
| 98 | + base_output_directory=BASE_OUTPUT_DIRECTORY, |
| 99 | + max_restarts=MAX_RESTARTS, |
| 100 | + libtpu_type=mxr.LibTpuType.MAXTEXT, |
| 101 | + libtpu_nightly_version="", |
| 102 | + base_docker_image=RUNNER, |
| 103 | + xpk_path=XPK_PATH, |
| 104 | + num_steps=BENCHMARK_STEPS, |
| 105 | + priority="medium", |
| 106 | + ) |
| 107 | + command, name = mxr.generate_xpk_workload_cmd( |
| 108 | + cluster_config=cluster_config, wl_config=wl_config |
| 109 | + ) |
| 110 | + |
| 111 | + print(f"Name of the workload is: {name} \n") |
| 112 | + xpk_workload_names.append(name) |
| 113 | + |
| 114 | + print(f"XPK command to be used is: {command} \n") |
| 115 | + xpk_workload_cmds.append(command) |
| 116 | + |
| 117 | + for xpk_workload_name, xpk_workload_cmd in zip( |
| 118 | + xpk_workload_names, xpk_workload_cmds |
| 119 | + ): |
| 120 | + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
| 121 | + print( |
| 122 | + f"[{timestamp}] Running workload: {xpk_workload_name} with command:" |
| 123 | + f" {xpk_workload_cmd}" |
| 124 | + ) |
| 125 | + return_code = mxr.run_command_with_updates( |
| 126 | + xpk_workload_cmd, xpk_workload_name |
| 127 | + ) |
| 128 | + if return_code != 0: |
| 129 | + print(f"Unable to run xpk workload: {xpk_workload_name}") |
| 130 | + |
| 131 | + |
| 132 | +if __name__ == "__main__": |
| 133 | + main() |
0 commit comments