diff --git a/benchmarks/globals.py b/benchmarks/globals.py
index ba3a625b7..90528eadf 100644
--- a/benchmarks/globals.py
+++ b/benchmarks/globals.py
@@ -17,7 +17,7 @@ import os.path
 
 # This is the MaxText root: with "max_utils.py"; &etc. TODO: Replace `os.path.basename` with `os.path.abspath`
-MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "MaxText")
+MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/MaxText")
 
 # This is the maxtext repo root: with ".git" folder; "README.md"; "pyproject.toml"; &etc.
 MAXTEXT_REPO_ROOT = os.environ.get(
diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py
index d968d0490..bebcd8893 100644
--- a/benchmarks/maxtext_xpk_runner.py
+++ b/benchmarks/maxtext_xpk_runner.py
@@ -583,6 +583,8 @@ def generate_xpk_workload_cmd(
     cluster_config: XpkClusterConfig,
     wl_config: WorkloadConfig,
     workload_name=None,
+    user=os.environ.get("USER", ""),
+    temp_key=None,
     exp_name=None,
 ):
   """Generates a command to run a maxtext model on XPK."""
@@ -592,12 +594,16 @@ def generate_xpk_workload_cmd(
   time.localtime()
   length_of_random_str = 3
-  temp_post_fix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length_of_random_str))
+  # Use a caller-provided key so a DAG can reconstruct the workload name for deletion, instead of a random ID.
+  if temp_key is not None:
+    temp_post_fix = temp_key
+  else:
+    temp_post_fix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length_of_random_str))
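+  # Sketch of the behavior (hypothetical values): with temp_key="ab1" the
+  # suffix below is always "-<num_slices>-<MMDDHH>-ab1", so an external DAG
+  # can reconstruct the workload name and delete it; without temp_key the
+  # last component is a random 3-character string such as "x7q".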
""" - - parser = argparse.ArgumentParser() - - if DELETE in actions: - parser.add_argument( - "--delete", - action="store_true", - help="Delete workloads starting with the user's first five characters.", - ) - - known_args, unknown_args = parser.parse_known_args() - - if unknown_args: - raise ValueError(f"Unrecognized arguments: {unknown_args}") - - # Get user - user = os.environ["USER"] - # Handle actions should_continue = True - if DELETE in actions and known_args.delete: + if is_delete: _handle_delete(cluster_config, user, **kwargs) should_continue = False diff --git a/benchmarks/recipes/parser_utils.py b/benchmarks/recipes/parser_utils.py new file mode 100644 index 000000000..789eced47 --- /dev/null +++ b/benchmarks/recipes/parser_utils.py @@ -0,0 +1,182 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides utility functions for custom argument parsing and defines a comprehensive set of command-line arguments for configuring a machine learning workload. +""" + +import argparse + +def parse_int_list(arg): + """Parses a string with comma-separated values into a list of integers.""" + return [int(x) for x in arg.split(',')] + +def parse_str_list(arg): + """Parses a string with space-separated values into a list of strings.""" + return [s.strip() for s in arg.split(',')] + +def str2bool(v): + """Parses a string representation of a boolean value into a Python boolean.""" + if isinstance(v, bool): + return v + if v.lower() in ('true'): + return True + elif v.lower() in ('false'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected (e.g., True or False).') + +def add_arguments(parser: argparse.ArgumentParser): + """Add arguments to arg parsers that need it. + + Args: + parser: parser to add shared arguments to. + """ + # Add the arguments for each parser. 
+
+def add_arguments(parser: argparse.ArgumentParser):
+  """Adds the shared workload-configuration arguments to the given parser.
+
+  Args:
+    parser: parser to add shared arguments to.
+  """
+  # GCP Configuration
+  parser.add_argument(
+      '--user',
+      type=str,
+      default='user_name',
+      help='GCP user name.')
+  parser.add_argument(
+      '--cluster_name',
+      type=str,
+      default='test-v5e-32-cluster',
+      help='Name of the TPU cluster.')
+  parser.add_argument(
+      '--project',
+      type=str,
+      default='cloud-tpu-cluster',
+      help='GCP project ID.')
+  parser.add_argument(
+      '--zone',
+      type=str,
+      default='us-south1-a',
+      help='GCP zone for the cluster.')
+  parser.add_argument(
+      '--device_type',
+      type=str,
+      default='v5litepod-32',
+      help='Type of TPU device (e.g., v5litepod-32).')
+  parser.add_argument(
+      '--priority',
+      type=str,
+      choices=['low', 'medium', 'high', 'very high'],
+      default='medium',
+      help='Priority of the job.')
+
+  # Image Configuration
+  parser.add_argument(
+      '--server_image',
+      type=str,
+      default='us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server',
+      help='Docker image for the Pathways server.')
+  parser.add_argument(
+      '--proxy_image',
+      type=str,
+      default='us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server',
+      help='Docker image for the Pathways proxy server.')
+  parser.add_argument(
+      '--runner',
+      type=str,
+      default='us-docker.pkg.dev/path/to/maxtext_runner',
+      help='Docker image for the MaxText runner.')
+  parser.add_argument(
+      '--colocated_python_image',
+      type=str,
+      default=None,
+      help='Colocated Python image.')
+  parser.add_argument(
+      '--worker_flags',
+      type=str,
+      default='',
+      help='Additional flags for the workers.')
+  parser.add_argument(
+      '--proxy_flags',
+      type=str,
+      default='',
+      help='Additional flags for the proxy.')
+  parser.add_argument(
+      '--server_flags',
+      type=str,
+      default='',
+      help='Additional flags for the server.')
+
+  # Model Configuration
+  parser.add_argument(
+      '--benchmark_steps',
+      type=int,
+      default=20,
+      help='Number of benchmark steps.')
+  parser.add_argument(
+      '--headless',
+      action=argparse.BooleanOptionalAction,
+      default=False,
+      help='Run in headless mode.')
+  parser.add_argument(
+      '--selected_model_framework',
+      type=parse_str_list,
+      default=['pathways'],
+      help='List of model frameworks (e.g., pathways, mcjax).')
+  parser.add_argument(
+      '--selected_model_names',
+      type=parse_str_list,
+      default=['llama3_1_8b_8192_v5e_256'],
+      help='List of model names (e.g., llama3_1_8b_8192_v5e_256, llama2-7b-v5e-256).')
+  parser.add_argument(
+      '--num_slices_list',
+      type=parse_int_list,
+      default=[2],
+      help='List of slice counts to benchmark (e.g., 1,2,4).')
+
+  # BigQuery configuration
+  parser.add_argument(
+      '--bq_enable',
+      type=str2bool,
+      default=False,
+      help='Enable BigQuery logging. Must be True or False. Defaults to False.')
+  parser.add_argument(
+      '--bq_db_project',
+      type=str,
+      default='',
+      help='BigQuery project ID where the logging dataset resides.')
+  parser.add_argument(
+      '--bq_db_dataset',
+      type=str,
+      default='',
+      help='BigQuery dataset name where metrics will be written.')
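+  # For example (hypothetical IDs), BigQuery logging could be enabled with:
+  #   --bq_enable True --bq_db_project my-gcp-project --bq_db_dataset benchmark_metrics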
+
+  # Other configurations
+  parser.add_argument(
+      '--xpk_path',
+      type=str,
+      default='~/xpk',
+      help='Path to the xpk checkout.')
+  parser.add_argument(
+      '--delete',
+      action='store_true',
+      help='Delete existing workloads instead of launching new ones.')
+  parser.add_argument(
+      '--max_restarts',
+      type=int,
+      default=0,
+      help='Maximum number of workload restarts.')
+  parser.add_argument(
+      '--temp_key',
+      type=str,
+      default=None,
+      help='Fixed workload-name suffix so the workload can be found and '
+           'deleted later; a random suffix is generated if unset.')
\ No newline at end of file
diff --git a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py
index 83fcc5269..7248e6dfe 100644
--- a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py
+++ b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py
@@ -15,33 +15,60 @@
 """Used to perf benchmarks between Pathways and McJax."""
 import os
 import sys
-
 parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.append(parent_dir)
 
 from . import args_helper as helper
-from . import user_configs
+from .user_configs import UserConfig
+from .user_configs import USER_CONFIG
 from .runner_utils import generate_and_run_workloads
+from . import parser_utils
+import argparse
+from google.cloud import storage
+from .pw_utils import check_and_create_bucket
 
-
-def main() -> int:
+def main(user_config) -> int:
   """Main program entry point"""
-  user_configs.USER_CONFIG.headless = False
-  should_continue = helper.handle_cmd_args(
-      user_configs.USER_CONFIG.cluster_config, helper.DELETE, xpk_path=user_configs.USER_CONFIG.xpk_path
-  )
-
-  if not should_continue:
-    return 0
-
-  return_code = generate_and_run_workloads(
-      user_configs.USER_CONFIG,
-      user_configs.USER_CONFIG.num_slices_list,
-      user_configs.USER_CONFIG.benchmark_steps,
-      user_configs.USER_CONFIG.priority,
-  )
+  storage_client = storage.Client(project=user_config.project)
+  # base_output_directory is a "gs://<bucket>/..." URI; strip the scheme to get the bucket name.
+  bucket_name = user_config.base_output_directory[5:].split('/')[0]
+  check_and_create_bucket(storage_client, bucket_name, user_config.region)
+  return_code = generate_and_run_workloads(
+      user_config, user_config.num_slices_list, user_config.benchmark_steps, user_config.priority
+  )
   return return_code
 
 
 if __name__ == "__main__":
-  main()
+  parser = argparse.ArgumentParser(description="Used to perf benchmarks between Pathways and McJax.")
+  parser_utils.add_arguments(parser)
+  args = parser.parse_args()
+
+  if len(sys.argv) > 2:
+    # argv[0] is the script name, so more than two entries means at least one
+    # flag with a value was passed: build a custom configuration from the args.
+    print("Multiple command-line arguments detected. A custom configuration will be used.")
+    user_config = UserConfig(**vars(args))
+    should_continue = helper.handle_cmd_args(
+        user_config.cluster_config,
+        is_delete=user_config.delete,
+        user=user_config.user,
+        xpk_path=user_config.xpk_path,
+    )
+    if not should_continue:
+      sys.exit(0)
+
+    print(f"Configuration used: {user_config}")
+    return_code = main(user_config)
+    sys.exit(return_code)
+
+  else:
+    print("No command-line arguments (or only --delete) were detected. The default configuration will be used.")
+    user_config = USER_CONFIG
+    if "--delete" in sys.argv:
+      user_config.delete = True
+    should_continue = helper.handle_cmd_args(
+        user_config.cluster_config,
+        is_delete=user_config.delete,
+        user=user_config.user,
+        xpk_path=user_config.xpk_path,
+    )
+    if not should_continue:
+      sys.exit(0)
+    print(f"Configuration used: {user_config}")
+    return_code = main(user_config)
+    sys.exit(return_code)
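+
+# Example invocation (hypothetical values). The recipe uses relative imports,
+# so it would be run as a module from the repository root, e.g.:
+#   python3 -m benchmarks.recipes.pw_mcjax_benchmark_recipe \
+#       --user alice --project my-gcp-project --zone us-south1-a \
+#       --cluster_name my-v5e-cluster --device_type v5litepod-32 \
+#       --num_slices_list 1,2
+# Pass --delete to remove existing workloads instead of running benchmarks.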
diff --git a/benchmarks/recipes/pw_utils.py b/benchmarks/recipes/pw_utils.py
index 7a8ece7bd..90cb9ccfc 100644
--- a/benchmarks/recipes/pw_utils.py
+++ b/benchmarks/recipes/pw_utils.py
@@ -14,14 +14,13 @@
 """
 This module provides utility functions for Pathways-related benchmark recipes.
-
-It includes helpers for building lists of model configurations based on user
-selections and for generating `XpkClusterConfig` and `PathwaysConfig` objects.
 """
 
 import typing
 
 import maxtext_xpk_runner as mxr
+from google.cloud import storage
+from google.cloud.exceptions import NotFound
 
 
 def build_user_models(
@@ -108,3 +107,42 @@ def get_pathways_config(
       worker_flags=worker_flags,
   )
   return pathways_config
+
+def check_and_create_bucket(storage_client, bucket_name, region):
+  """Checks whether the GCS bucket exists, creating it in `region` if missing.
+
+  Returns the bucket on success; prints an error and returns None otherwise.
+  """
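+  # Example call (hypothetical names):
+  #   check_and_create_bucket(storage.Client(project="my-project"),
+  #                           "my-output-bucket", "us-south1")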
+  print(f"Checking GCS bucket: {bucket_name}...")
+
+  try:
+    # Attempt to retrieve the bucket's metadata.
+    bucket = storage_client.get_bucket(bucket_name)
+
+    # If the call succeeds, the bucket exists.
+    print(f"GCS bucket '{bucket_name}' already exists. No creation needed.")
+    return bucket
+
+  except NotFound:
+    # A NotFound error means the bucket does not exist yet.
+    print(f"GCS bucket '{bucket_name}' not found. Attempting to create...")
+    try:
+      # GCS bucket names must be globally unique. `location` selects the
+      # geographical region (e.g., 'us-central1', 'asia-east1'); if omitted,
+      # the project's default location is used.
+      new_bucket = storage_client.create_bucket(
+          bucket_or_name=bucket_name,
+          location=region,  # Use the provided region as the location.
+      )
+      print(f"Successfully created GCS bucket '{new_bucket.name}' in region {new_bucket.location}.")
+      return new_bucket
+
+    except Exception as e:
+      # Other errors during creation (e.g., permission denied or name already taken).
+      print(f"Failed to create GCS bucket! Error message: {e}")
+
+  except Exception as e:
+    # Other errors while checking (e.g., authentication failure).
+    print(f"An error occurred while checking GCS bucket! Error message: {e}")
\ No newline at end of file
diff --git a/benchmarks/recipes/runner_utils.py b/benchmarks/recipes/runner_utils.py
index 031158b59..74ab17eb0 100644
--- a/benchmarks/recipes/runner_utils.py
+++ b/benchmarks/recipes/runner_utils.py
@@ -68,7 +68,9 @@ def generate_and_run_workloads(user_config, num_slices_list, num_steps, priority
       )
 
       # Generate XPK command
-      command, name = mxr.generate_xpk_workload_cmd(cluster_config=user_config.cluster_config, wl_config=wl_config)
+      command, name = mxr.generate_xpk_workload_cmd(
+          cluster_config=user_config.cluster_config, wl_config=wl_config, user=user_config.user, temp_key=user_config.temp_key
+      )
 
       logging.info("Generated workload: %s", name)
       logging.info("Generated command: %s", command)
diff --git a/benchmarks/recipes/user_configs.py b/benchmarks/recipes/user_configs.py
index 6c01c0dfc..ab0f84e5c 100644
--- a/benchmarks/recipes/user_configs.py
+++ b/benchmarks/recipes/user_configs.py
@@ -55,8 +55,8 @@ class UserConfig:
   priority: str = "medium"
 
   # Images for env
-  server_image: str = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
-  proxy_image: str = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
+  server_image: str = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
+  proxy_image: str = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
   runner: str = "us-docker.pkg.dev/path/to/maxtext_runner"
   colocated_python_image: str = None
   worker_flags: str = ""
@@ -70,9 +70,16 @@ class UserConfig:
   selected_model_names: list[str] = dataclasses.field(default_factory=lambda: ["llama3_1_8b_8192"])
   num_slices_list: list[int] = dataclasses.field(default_factory=lambda: [2])
 
+  # BigQuery configuration
+  bq_enable: bool = False
+  bq_db_project: str = ""
+  bq_db_dataset: str = ""
+
   # other configuration
   xpk_path: str = "~/xpk"
+  delete: bool = False
   max_restarts: int = 0
+  temp_key: str = None
 
   def __post_init__(self):
     """Automatically generate derived attributes after the object is created."""