2 changes: 1 addition & 1 deletion benchmarks/globals.py
@@ -17,7 +17,7 @@
import os.path

# This is the MaxText root: with "max_utils.py"; &etc. TODO: Replace `os.path.basename` with `os.path.abspath`
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "MaxText")
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/MaxText")

# This is the maxtext repo root: with ".git" folder; "README.md"; "pyproject.toml"; &etc.
MAXTEXT_REPO_ROOT = os.environ.get(
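For orientation, here is a minimal sketch (not part of the diff) of how the new default resolves; the "src/MaxText" value comes from this change, while the example path in the comment is illustrative only.

import os

# Defaults to the new "src/MaxText" package location unless the caller
# overrides it via the MAXTEXT_PKG_DIR environment variable.
pkg_dir = os.environ.get("MAXTEXT_PKG_DIR", "src/MaxText")
print(os.path.abspath(pkg_dir))  # e.g. /home/user/maxtext/src/MaxText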
10 changes: 8 additions & 2 deletions benchmarks/maxtext_xpk_runner.py
@@ -583,6 +583,8 @@ def generate_xpk_workload_cmd(
cluster_config: XpkClusterConfig,
wl_config: WorkloadConfig,
workload_name=None,
user=os.environ['USER'],
temp_key=None,
exp_name=None,
):
"""Generates a command to run a maxtext model on XPK."""
@@ -592,12 +594,16 @@ def generate_xpk_workload_cmd(

time.localtime()
length_of_random_str = 3
temp_post_fix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length_of_random_str))
# Use the caller-supplied temp_key as the suffix so the DAG can reconstruct the workload name for deletion, instead of generating a random ID.
if temp_key is not None:
temp_post_fix = temp_key
else:
temp_post_fix = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length_of_random_str))

truncate_model_name = 10
truncate_prefix = 3
post_fix = f"-{wl_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}"
common_prefix = os.environ["USER"]
common_prefix = user
pw_prefix = "pw-"

if workload_name is None: # Generate name if not provided
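The new user and temp_key parameters make the generated workload name reproducible, so an external DAG can later delete the workload by name instead of guessing a random suffix. A minimal, self-contained sketch of that naming behavior follows; the helper name build_suffix is hypothetical, and only the suffix format and the temp_key-vs-random choice mirror the diff.

import random
import string
import time

def build_suffix(num_slices: int, temp_key=None) -> str:
  """Returns the workload-name suffix; deterministic when temp_key is given."""
  if temp_key is not None:
    post_fix = temp_key  # caller-chosen key, reusable later for deletion
  else:
    post_fix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(3))
  return f"-{num_slices}-{time.strftime('%m%d%H', time.localtime())}-{post_fix}"

print(build_suffix(2, temp_key="abc"))  # e.g. "-2-021514-abc"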
30 changes: 4 additions & 26 deletions benchmarks/recipes/args_helper.py
@@ -66,41 +66,19 @@ def handle_delete_specific_workload(cluster_config: XpkClusterConfig, workload_n
os.system(f"yes | {delete_command}")


def handle_cmd_args(cluster_config: XpkClusterConfig, *actions: str, **kwargs) -> bool:
def handle_cmd_args(cluster_config: XpkClusterConfig, is_delete: bool, user: str, **kwargs) -> bool:
Comment (Collaborator Author): Since the original recipe accepted only the --delete flag, this section was modified so it can accept multiple flags.

"""Parses command-line arguments and executes the specified actions.

Args:
cluster_config: Contains Cluster configuration information that's helpful
for running the actions.
*actions: Variable number of string arguments representing the actions to
be performed.
is_delete: A boolean indicating whether the delete action should be
performed.
**kwargs: Optional keyword arguments to be passed to action handlers.

Raises:
ValueError: If an unsupported action is provided or if unknown arguments are
passed.
"""

parser = argparse.ArgumentParser()

if DELETE in actions:
parser.add_argument(
"--delete",
action="store_true",
help="Delete workloads starting with the user's first five characters.",
)

known_args, unknown_args = parser.parse_known_args()

if unknown_args:
raise ValueError(f"Unrecognized arguments: {unknown_args}")

# Get user
user = os.environ["USER"]

# Handle actions
should_continue = True
if DELETE in actions and known_args.delete:
if is_delete:
_handle_delete(cluster_config, user, **kwargs)
should_continue = False

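With flag parsing moved out of args_helper, callers now parse their own flags and pass plain values in. A minimal sketch of the new call pattern; the flag values are examples, and the handle_cmd_args call is commented out because it needs a real XpkClusterConfig and xpk_path from the recipe's configuration.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--delete", action="store_true")
parser.add_argument("--user", type=str, default="alice")
args = parser.parse_args(["--delete"])

# should_continue = helper.handle_cmd_args(
#     cluster_config, is_delete=args.delete, user=args.user, xpk_path="~/xpk"
# )
print(args.delete, args.user)  # True alice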
182 changes: 182 additions & 0 deletions benchmarks/recipes/parser_utils.py
@@ -0,0 +1,182 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module provides utility functions for custom argument parsing and defines a comprehensive set of command-line arguments for configuring a machine learning workload.
"""

import argparse

def parse_int_list(arg):
"""Parses a string with comma-separated values into a list of integers."""
return [int(x) for x in arg.split(',')]

def parse_str_list(arg):
"""Parses a string with space-separated values into a list of strings."""
return [s.strip() for s in arg.split(',')]

def str2bool(v):
"""Parses a string representation of a boolean value into a Python boolean."""
if isinstance(v, bool):
return v
if v.lower() == 'true':
return True
elif v.lower() == 'false':
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected (e.g., True or False).')

def add_arguments(parser: argparse.ArgumentParser):
"""Add arguments to arg parsers that need it.
Args:
parser: parser to add shared arguments to.
"""
# Add the arguments for each parser.
# GCP Configuration
parser.add_argument(
'--user',
type=str,
default='user_name',
help='GCP user name.')
parser.add_argument(
'--cluster_name',
type=str,
default='test-v5e-32-cluster',
help='Name of the TPU cluster.')
parser.add_argument(
'--project',
type=str,
default='cloud-tpu-cluster',
help='GCP project ID.')
parser.add_argument(
'--zone',
type=str,
default='us-south1-a',
help='GCP zone for the cluster.')
parser.add_argument(
'--device_type',
type=str,
default='v5litepod-32',
help='Type of TPU device (e.g., v5litepod-32).')
parser.add_argument(
'--priority',
type=str,
choices=['low', 'medium', 'high', 'very high'],
default='medium',
help='Priority of the job.')

# Image Configuration
parser.add_argument(
'--server_image',
type=str,
default='us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server',
help='Docker image for the Pathways server.')
parser.add_argument(
'--proxy_image',
type=str,
default='us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server',
help='Docker image for the Pathways proxy server.')
parser.add_argument(
'--runner',
type=str,
default='us-docker.pkg.dev/path/to/maxtext_runner',
help='Docker image for the runner.')
parser.add_argument(
'--colocated_python_image',
type=str,
default=None,
help='Colocated Python image.')
parser.add_argument(
'--worker_flags',
type=str,
default='',
help='Worker flags.')
parser.add_argument(
'--proxy_flags',
type=str,
default='',
help='Proxy flags.')
parser.add_argument(
'--server_flags',
type=str,
default='',
help='Server flags.')

# Model Configuration
parser.add_argument(
'--benchmark_steps',
type=int,
default=20,
help='Number of benchmark steps.')
parser.add_argument(
'--headless',
action=argparse.BooleanOptionalAction,
default=False,
help='Run in headless mode.')
parser.add_argument(
'--selected_model_framework',
type=parse_str_list,
default=['pathways'],
help='List of model frameworks (e.g., pathways, mcjax).')
parser.add_argument(
'--selected_model_names',
type=parse_str_list,
default=['llama3_1_8b_8192_v5e_256'],
help='List of model names (e.g., llama3_1_8b_8192_v5e_256, llama2-7b-v5e-256).')
parser.add_argument(
'--num_slices_list',
type=parse_int_list,
default=[2],
help='List of slice counts to benchmark.')

# BigQuery configuration
Comment (Collaborator Author): Waiting for the BQ PR to be merged.

parser.add_argument(
'--bq_enable',
type=str2bool,
default=False,
help='Enable BigQuery logging. Must be True or False. Defaults to False.')

parser.add_argument(
'--bq_db_project',
type=str,
default='',
help='BigQuery project ID where the logging dataset resides.')

parser.add_argument(
'--bq_db_dataset',
type=str,
default='',
help='BigQuery dataset name where metrics will be written.')

# Other configurations
parser.add_argument(
'--xpk_path',
type=str,
default='~/xpk',
help='Path to xpk.')
parser.add_argument(
'--delete',
action='store_true',
help='Delete the cluster workloads.')
parser.add_argument(
'--max_restarts',
type=int,
default=0,
help='Maximum number of restarts.')
parser.add_argument(
'--temp_key',
type=str,
default=None,
help='Fixed workload-name suffix; overrides the random suffix so the workload can be found and deleted later.')
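A minimal usage sketch for the new module; the import path assumes benchmarks and recipes are importable as packages, and the flag values are illustrative.

import argparse

from benchmarks.recipes import parser_utils

parser = argparse.ArgumentParser(description="demo")
parser_utils.add_arguments(parser)
args = parser.parse_args(
    ["--user", "alice", "--num_slices_list", "1,2,4", "--bq_enable", "false"]
)
print(args.num_slices_list)  # [1, 2, 4]
print(args.bq_enable)        # False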
65 changes: 46 additions & 19 deletions benchmarks/recipes/pw_mcjax_benchmark_recipe.py
@@ -15,33 +15,60 @@
"""Used to perf benchmarks between Pathways and McJax."""
import os
import sys

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from . import args_helper as helper
from . import user_configs
from .user_configs import UserConfig
from .user_configs import USER_CONFIG
from .runner_utils import generate_and_run_workloads
from . import parser_utils
import argparse
from google.cloud import storage
from .pw_utils import check_and_create_bucket


def main() -> int:
def main(user_config) -> int:
"""Main program entry point"""
user_configs.USER_CONFIG.headless = False
should_continue = helper.handle_cmd_args(
user_configs.USER_CONFIG.cluster_config, helper.DELETE, xpk_path=user_configs.USER_CONFIG.xpk_path
)

if not should_continue:
return 0

return_code = generate_and_run_workloads(
user_configs.USER_CONFIG,
user_configs.USER_CONFIG.num_slices_list,
user_configs.USER_CONFIG.benchmark_steps,
user_configs.USER_CONFIG.priority,
)
storage_client = storage.Client(project=user_config.project)
check_and_create_bucket(storage_client, user_config.base_output_directory[5:].split('/')[0], user_config.region)
return_code = generate_and_run_workloads(user_config, user_config.num_slices_list, user_config.benchmark_steps, user_config.priority)

return return_code


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser(description="Used to perf benchmarks between Pathways and McJax.")
parser_utils.add_arguments(parser)
args = parser.parse_args()

if len(sys.argv) > 2:
Comment (Collaborator Author): This check determines whether multiple command-line flags were passed.

print("Multiple command line arguments detected. Custom configuration will be used.")
user_config = UserConfig(**vars(args))
should_continue = helper.handle_cmd_args(
user_config.cluster_config,
is_delete=user_config.delete,
user=user_config.user,
xpk_path=user_config.xpk_path
)
if not should_continue:
sys.exit(0)

print(f"configuration used:{user_config}")
return_code = main(user_config)
sys.exit(return_code)

else:
print("No command line or only a single --delete argument was detected. The default configuration will be used.")
user_config = USER_CONFIG
if "--delete" in sys.argv:
user_config.delete = True
should_continue = helper.handle_cmd_args(
user_config.cluster_config,
is_delete=user_config.delete,
user=user_config.user,
xpk_path=user_config.xpk_path
)
if not should_continue:
sys.exit(0)
print(f"configuration used:{user_config}")
return_code= main(user_config)
sys.exit(return_code)
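To make the dispatch above concrete, the length of sys.argv selects between the default and custom configurations; a small sketch with example invocations (script name and flags are illustrative).

import sys

# python pw_mcjax_benchmark_recipe.py                  -> len(sys.argv) == 1: default USER_CONFIG
# python pw_mcjax_benchmark_recipe.py --delete         -> len(sys.argv) == 2: default config, delete path
# python pw_mcjax_benchmark_recipe.py --user alice ... -> len(sys.argv) > 2: custom UserConfig(**vars(args))
use_custom = len(sys.argv) > 2
print("custom configuration" if use_custom else "default configuration")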